@@ -772,186 +772,6 @@ def call_operator(self, op, args, kwargs, meta):
         return super().call_operator(target, new_args, kwargs, meta)


-# TODO(matthiascremon): this is a fuse op, not a replace op
-class ReplaceConvWithChannelLastConv:
-    """
-    Convolution op in pytorch expects NCHW layout for input, weight, and output
-    tensors. However, if the input and output to the convolution op are originally
-    in NHWC layout, and are then permuted to conform to NCHW layout, we can fuse
-    the two permute ops with the convolution op, and call the NHWC layout
-    convolution op.
-    """
784-
785- def __init__ (self ):
786- self .counter = 0
787- self .graph_module = None
788-
789- def __call__ (self , graph_module : torch .fx .GraphModule ):
790- self .replace_conv_with_nhwc_conv (graph_module )
791-
-    def conv_layout_is_nhwc(self, node: torch.fx.Node) -> bool:
-        """
-        Return true if the convolution input and output are connected to permute
-        ops, and the input/output to/from the permute ops is an NHWC layout tensor.
-        """
-        # There must only be a single user of the output node (which must be a
-        # permute/transpose op). The input of the convolution must be connected
-        # to a permute op, and that permute op should have a single user.
-        conv_inp = node.args[0]
-        assert isinstance(conv_inp, torch.fx.Node)
-        if len(node.users) != 1 or len(conv_inp.users) != 1:
-            return False
-
-        # Get the input and output (permute/transpose) nodes of the convolution.
-        conv_user = list(node.users.keys())[0]
-        assert isinstance(conv_user, torch.fx.Node)
-        pt_nodes: Set[torch.fx.Node] = {conv_inp, conv_user}
-
-        # None of the nodes in pt_nodes may be a placeholder.
-        if contains_placeholder_or_param(pt_nodes):
-            return False
-
-        # Determine if the convolution is 1d or 2d. The output tensor must be
-        # 3- or 4-dimensional.
-        out_shape = get_shape(self.graph_module, node)
-        assert out_shape is not None
-        out_dims = len(out_shape)
-        assert out_dims in {3, 4}, "Only supports conv1d and conv2d"
-        conv1d = out_dims == 3
-
-        # Get the possible targets for the nodes in pt_nodes. Since conv1d has
-        # 3-dimensional input and output tensors, the nodes in pt_nodes could
-        # be either permute or transpose ops. For conv2d, the nodes in pt_nodes
-        # must be permute ops.
-        p_target = exir_ops.edge.aten.permute_copy.default
-        t_target = exir_ops.edge.aten.transpose_copy.int
-        pt_targets = [p_target] + ([t_target] if conv1d else [])
-
-        # If any node in pt_nodes is not a permute op (or a transpose op for
-        # conv1d), bail.
-        if any(x.target not in pt_targets for x in pt_nodes):
-            return False
-
-        # Now we need to determine the dimension permutations:
-        # If the input had NHWC layout, which was then permuted/transposed
-        # by a permute/transpose op to NCHW layout, the permutation must be
-        # [0, 3, 1, 2] (or [0, 2, 1] for conv1d).
-        # If the output had NCHW layout, and was then permuted to NHWC layout,
-        # the permutation must be [0, 2, 3, 1] (or [0, 2, 1] for conv1d).
-        nhwc_permute_order = {
-            node.args[0]: [0, 2, 1] if conv1d else [0, 3, 1, 2],
-            list(node.users.keys())[0]: [0, 2, 1] if conv1d else [0, 2, 3, 1],
-        }
-        for x in pt_nodes:
-            order = (
-                x.args[1]
-                if x.target == p_target
-                else get_transposed_dims(x, list(range(out_dims)))
-            )
-            if order != nhwc_permute_order[x]:
-                return False
-
-        return True
-
-    def replace_conv_with_nhwc_conv(self, graph_module: torch.fx.GraphModule):
-        self.graph_module = graph_module
-        graph = graph_module.graph
-        for node in graph.nodes:
-            # We are only interested in convolution nodes that have NHWC layout.
-            if node.target not in {
-                exir_ops.edge.cadence.quantized_conv.default,
-                exir_ops.edge.cadence.convolution.default,
-                exir_ops.edge.cadence.quantized_transposed_conv.default,
-                exir_ops.edge.cadence.transposed_convolution.default,
-            } or not self.conv_layout_is_nhwc(node):
-                continue
-
-            # Get the args of the convolution op.
-            args = list(node.args)
-            # The input is connected to a permute/transpose op that converts the
-            # NHWC layout to NCHW layout. The input of the permute op will become
-            # this convolution op's input.
-            in_tp = args[0]
-            args[0] = in_tp.args[0]
-            # The weight is in NCHW layout. Permute it to NHWC layout.
-            weight_tensor = get_tensor_from_attr(graph_module, args[1])
-            assert isinstance(weight_tensor, torch.Tensor)
-            # We cannot directly permute a per-channel quantized tensor. We will
-            # dequantize it, permute the fp32 tensor, and then requantize the
-            # permuted tensor.
-            if (
-                is_quantized_tensor(weight_tensor)
-                and weight_tensor.qscheme() == torch.per_channel_affine
-            ):
-                # We already asserted while quantizing the conv op that the
-                # quantization axis is 0.
-                dequant_weight = weight_tensor.dequantize()
-                dequant_weight = (
-                    dequant_weight.permute([0, 2, 1])
-                    if dequant_weight.dim() == 3
-                    else dequant_weight.permute([0, 2, 3, 1])
-                )
-                weight_tensor = torch.quantize_per_channel(
-                    dequant_weight.contiguous(),
-                    weight_tensor.q_per_channel_scales(),
-                    weight_tensor.q_per_channel_zero_points(),
-                    0,
-                    weight_tensor.dtype,
-                )
-            else:
-                weight_tensor = (
-                    weight_tensor.permute([0, 2, 1])
-                    if weight_tensor.dim() == 3
-                    else weight_tensor.permute([0, 2, 3, 1])
-                )
-            # Make the weight tensor contiguous, since we have permuted it.
-            weight_tensor = weight_tensor.contiguous()
-            # Add the permuted weight into the graph, and update the weight in
-            # args.
-            with graph.inserting_before(node):
-                weight_name = f"_weight_nhwc_{self.counter}"
-                graph_module.register_buffer(weight_name, weight_tensor)
-                weight = graph.get_attr(weight_name)
-            args[1] = weight
-
-            # Set the 'channel_last' arg, which is the last arg, to True.
-            args[-1] = True
-            # Now update the convolution node args to mark it as an NHWC
-            # convolution.
-            node.args = tuple(args)
-
-            # Replace all the uses of the permute op connected to the output
-            # with this convolution.
-            out_tp = list(node.users.keys())[0]
-            out_tp.replace_all_uses_with(node)
-            node.meta = out_tp.meta
-
-            # Erase the permute ops connected to the input and output of the
-            # convolution op.
-            graph.erase_node(in_tp)
-            graph.erase_node(out_tp)
-            self.counter += 1
-
-        graph_module.recompile()
-
-
-
-# This pass needs to be reworked to be compatible with PT2. It is an optimization
-# pass anyway, so move it to opt level 2.
-# TODO: T213724613 update and improve this pass.
-# @register_cadence_pass(CadencePassAttribute(opt_level=2))
-class ReplaceConvWithChannelLastConvPass(ExportPass):
-    """
-    Replace the ATen convolution op with a custom conv op whose input tensors
-    use NCHW or NHWC layout, depending on the presence of permute/transpose
-    ops connected to the input tensor.
-    """
-
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        result = ReplaceAtenConvolutionWithCadenceConvolutionPass()(graph_module)
-        assert result is not None
-        ReplaceConvWithChannelLastConv()(result.graph_module)
-        return result
-
-
 @register_cadence_pass(CadencePassAttribute(opt_level=2))
 class ReplaceTrivialConvWithLinear(ExportPass):
     """
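
For context, the fusion removed above relied on a layout identity: a convolution wrapped in an NHWC-to-NCHW permute on the input and an NCHW-to-NHWC permute on the output equals a convolution that reads channel-last data directly. A minimal standalone sketch of that identity (not part of this diff; plain PyTorch's `channels_last` memory format stands in for the Cadence NHWC conv op):

```python
import torch

x_nhwc = torch.randn(2, 8, 8, 4)  # activation as it originally arrives: N, H, W, C
w = torch.randn(16, 4, 3, 3)      # conv2d weight in the default O, I, kH, kW layout

# Pattern the removed pass matched: permute NHWC -> NCHW ([0, 3, 1, 2]),
# convolve, then permute the result back to NHWC ([0, 2, 3, 1]).
y_ref = torch.nn.functional.conv2d(x_nhwc.permute(0, 3, 1, 2), w).permute(0, 2, 3, 1)

# The fused form computes the same convolution over channel-last data; here we
# emulate it with channels_last (NHWC physical layout, NCHW logical view).
x_cl = x_nhwc.permute(0, 3, 1, 2).contiguous(memory_format=torch.channels_last)
w_cl = w.contiguous(memory_format=torch.channels_last)
y_cl = torch.nn.functional.conv2d(x_cl, w_cl)

assert torch.allclose(y_ref, y_cl.permute(0, 2, 3, 1), atol=1e-5)
```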
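
The removed code also handled per-channel quantized weights by dequantizing, permuting, and requantizing, since a per-channel quantized tensor cannot be permuted directly. A sketch of that round trip under the same assumption the pass asserted (quantization axis 0, which the permute leaves in place); the values here are illustrative:

```python
import torch

w = torch.randn(16, 4, 3, 3)
scales = torch.full((16,), 0.1)
zero_points = torch.zeros(16, dtype=torch.int64)
qw = torch.quantize_per_channel(w, scales, zero_points, 0, torch.qint8)

# Dequantize, permute O, I, kH, kW -> O, kH, kW, I, then requantize; axis 0
# (the output channel) is unmoved, so the per-channel params still line up.
dq = qw.dequantize().permute(0, 2, 3, 1)
qw_nhwc = torch.quantize_per_channel(
    dq.contiguous(),
    qw.q_per_channel_scales(),
    qw.q_per_channel_zero_points(),
    0,
    torch.qint8,
)
```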
@@ -1127,7 +947,7 @@ def transpose_dims(


 @register_cadence_pass(CadencePassAttribute(opt_level=3))
-class ForceChannelLastForConvPass(ExportPassWithTransposeHelper):
+class ReplaceConvWithChannelLastConvPass(ExportPassWithTransposeHelper):
     def change_nchw_to_nhwc(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue:
         shape = proxy.to_tensor().shape
         if len(shape) == 3:
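
The `len(shape) == 3` check above distinguishes conv1d activations (N, C, L) from conv2d activations (N, C, H, W), which take different permute orders. A minimal sketch of those orders; `nchw_to_nhwc_order` is a hypothetical name, not a helper in this file:

```python
def nchw_to_nhwc_order(rank: int) -> list[int]:
    # conv1d activations are 3-d (N, C, L); conv2d activations are 4-d (N, C, H, W).
    if rank == 3:
        return [0, 2, 1]     # NCL -> NLC
    if rank == 4:
        return [0, 2, 3, 1]  # NCHW -> NHWC
    raise ValueError("only conv1d/conv2d activations are supported")

# The inverse direction (channel-last back to channel-first) is [0, 2, 1]
# for rank 3 and [0, 3, 1, 2] for rank 4.
assert nchw_to_nhwc_order(4) == [0, 2, 3, 1]
```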
@@ -2429,9 +2249,8 @@ class CadenceReplaceOpsInGraph:
         ReplaceRepeatWithCatPass,
         ReplacePadWithCatPass,
         ReplaceConstantPadNdWithSlicePass,
-        ReplaceConvWithChannelLastConvPass,
         ReplaceAtenConvolutionWithCadenceConvolutionPass,
-        ForceChannelLastForConvPass,
+        ReplaceConvWithChannelLastConvPass,
         ReplaceTrivialConvWithLinear,
         ReplaceConvWithIm2RowAndLinear,
         ReplaceTransposedConvWithLinearPass,
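
Ordering in this list matters: `ReplaceAtenConvolutionWithCadenceConvolutionPass` must lower ATen convolutions to the Cadence op before `ReplaceConvWithChannelLastConvPass` looks for them. A minimal sketch of how such an ordered list could be driven (a hypothetical driver, not the actual `CadenceReplaceOpsInGraph` wiring):

```python
from typing import Iterable, Type

import torch

def run_replace_passes(
    graph_module: torch.fx.GraphModule, passes: Iterable[Type]
) -> torch.fx.GraphModule:
    # Each pass instance is callable on a GraphModule and returns a PassResult
    # carrying the (possibly rewritten) graph module; apply them in order.
    for pass_cls in passes:
        result = pass_cls()(graph_module)
        if result is not None:
            graph_module = result.graph_module
    return graph_module
```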