77
88
99# This file contains all the functions that replace one op with another in the
10- # graph. The functions replacing ops for models deployed with Jarvis are grouped
11- # together in class 'ReplaceOpsInGraph'. Some examples of functions in the class are
12- # 1. functions that replace an ATen op with a custom op that accepts extra arguments
13- # 2. functions that replace in-place variants of ATen ops with out-of-place version.
14- # 3. functions that replace an ATen op with another semantically equivalent ATen op.
15- # 4. functions that concretize optional args.
10+ # graph.
1611
1712# pyre-unsafe
1813
5449from torch .fx .node import Argument
5550
5651# A map to represent ops that:
57- # (a) are functionally equivalent wrt. Jarvis ; and
52+ # (a) are functionally equivalent; and
5853# (b) have identical arguments
5954# An op whose target is 'key' in this dict can be replaced by the functionally euivalent
6055# op whose target is 'value'. The replacement would just involve changing the op target.
@@ -650,7 +645,7 @@ def call_operator(self, op, args, kwargs, meta):
650645
651646# Make that pass runnable standalone at opt level 0.
652647@register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
653- class ReplaceAtenConvolutionWithJarvisConvolutionPass (ExportPass ):
648+ class ReplaceAtenConvolutionWithCadenceConvolutionPass (ExportPass ):
654649 """
655650 Replace aten convolution op with jarvis-specific convolution op, since the
656651 aten version is not supported by jarvis.
@@ -784,7 +779,7 @@ class ReplaceConvWithChannelLastConv:
784779 tensors. However, if the input and output to the convolution op are originally
785780 in NWHC layout, and are then permuted to conform to NCHW layout, we can fuse
786781 the two permute ops with the convolution op, and call the NHWC layout
787- convolution op in Jarvis .
782+ convolution op.
788783 """
789784
790785 def __init__ (self ):
@@ -821,7 +816,7 @@ def conv_layout_is_nhwc(self, node: torch.fx.Node) -> bool:
821816 out_shape = get_shape (self .graph_module , node )
822817 assert out_shape is not None
823818 out_dims = len (out_shape )
824- assert out_dims in {3 , 4 }, "Jarvis only supports conv1d and conv2d"
819+ assert out_dims in {3 , 4 }, "Only supports conv1d and conv2d"
825820 conv1d = out_dims == 3
826821
827822 # Get the possible targets for the nodes in pt_nodes. Since conv1d has
@@ -951,7 +946,7 @@ class ReplaceConvWithChannelLastConvPass(ExportPass):
951946 """
952947
953948 def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
954- result = ReplaceAtenConvolutionWithJarvisConvolutionPass ()(graph_module )
949+ result = ReplaceAtenConvolutionWithCadenceConvolutionPass ()(graph_module )
955950 assert result is not None
956951 ReplaceConvWithChannelLastConv ()(result .graph_module )
957952 return result
@@ -1871,9 +1866,9 @@ def call_operator(self, op, args, kwargs, meta):
18711866
18721867
18731868@register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
1874- class ReplaceAtenAvgPoolWithJarvisAvgPoolPass (ExportPass ):
1869+ class ReplaceAtenAvgPoolWithCadenceAvgPoolPass (ExportPass ):
18751870 """
1876- Replace the aten avg_pool op with the jarvis custom avg_pool2d op.
1871+ Replace the aten avg_pool op with the cadence custom avg_pool2d op.
18771872 """
18781873
18791874 def call_operator (self , op , args , kwargs , meta ):
@@ -2435,7 +2430,7 @@ class CadenceReplaceOpsInGraph:
24352430 ReplacePadWithCatPass ,
24362431 ReplaceConstantPadNdWithSlicePass ,
24372432 ReplaceConvWithChannelLastConvPass ,
2438- ReplaceAtenConvolutionWithJarvisConvolutionPass ,
2433+ ReplaceAtenConvolutionWithCadenceConvolutionPass ,
24392434 ForceChannelLastForConvPass ,
24402435 ReplaceTrivialConvWithLinear ,
24412436 ReplaceConvWithIm2RowAndLinear ,
@@ -2454,7 +2449,7 @@ class CadenceReplaceOpsInGraph:
24542449 ReplacePT2DequantWithCadenceDequantPass ,
24552450 ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass ,
24562451 ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass ,
2457- ReplaceAtenAvgPoolWithJarvisAvgPoolPass ,
2452+ ReplaceAtenAvgPoolWithCadenceAvgPoolPass ,
24582453 ReplaceWhereWithFullArgsWithWhereScalar ,
24592454 ReplaceAtenApproxGeluWithApproxGeluPass ,
24602455 ReplaceSplitWithSlicePass ,
0 commit comments