@@ -885,6 +885,9 @@ class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass):
885885 """
886886 Fuse transpose or permute op pairs to a single view op.
887887 (transpose or permutation) -> (quant or dequant) -> (transpose or permutation)
888+ This happens when op2(op1) == identity, modulo unitary dimensions.
889+ 'unitary dimensions' example: a tensor of shape [1, 5, 30] is equivalent (in memory) to [5, 1, 30]
890+ so transpose(1, 2) then transpose(0, 2) is a pseudo identity and should be fused.
888891 """
889892
890893 # A list of ops that can be bypassed when looking for a
@@ -908,7 +911,7 @@ def can_fuse_for_chain(
908911 if not super ().can_fuse_for_chain (producer , consumer , consumer_op_packets ):
909912 return False
910913
911- # checking that permut2(permut1(identify )) == identity
914+ # checking that permut2(permut1(identity )) == identity, modulo unitary dimensions
912915 input_shape = cast (torch .fx .Node , producer .args [0 ]).meta ["val" ].shape
913916 ident_dims = list (range (len (input_shape )))
914917 # this mapping helps to handle both transpose and permutations
@@ -918,14 +921,20 @@ def can_fuse_for_chain(
918921 }
919922 in_dims = f [producer .target ](producer , ident_dims )
920923 out_dims = f [consumer .target ](consumer , in_dims )
921- return out_dims == ident_dims
924+ # Filtering out unitary dimensions
925+ non_unit_ident_dims = [dim for dim in ident_dims if input_shape [dim ] != 1 ]
926+ non_unit_out_dims = [dim for dim in out_dims if input_shape [dim ] != 1 ]
927+ return non_unit_out_dims == non_unit_ident_dims
922928
923929 def get_fused_node (
924930 self ,
925931 producer : torch .fx .Node ,
926932 consumer : torch .fx .Node ,
927933 graph_module : torch .fx .GraphModule ,
928934 ) -> torch .fx .Node :
935+ # This step is important because of how we can fuse transpositions that are not perfectly
936+ # reverse one of another but will be fused if there are unitary dimensions.
937+ # The fused operation must have the same output shape as the consumer.
929938 output_shape = consumer .meta ["val" ].shape
930939 with graph_module .graph .inserting_after (consumer ):
931940 view = graph_module .graph .call_function (
0 commit comments