@@ -885,6 +885,9 @@ class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass):
     """
     Fuse transpose or permute op pairs to a single view op.
     (transpose or permutation) -> (quant or dequant) -> (transpose or permutation)
+    This happens when op2(op1) == identity, modulo unitary dimensions.
+    'unitary dimensions' example: a tensor of shape [1, 5, 30] is equivalent (in memory) to one of shape [5, 30, 1],
+    so transpose(1, 2) followed by transpose(0, 2) is a pseudo-identity and should be fused.
     """

     # A list of ops that can be bypassed when looking for a
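
A quick standalone illustration of the docstring's claim, using plain PyTorch outside the pass (the tensor and shapes are just the docstring's example, not code from this diff):

```python
import torch

# Only size-1 ("unitary") dims move, so the two transposes compose
# to a pseudo-identity over the underlying memory layout.
x = torch.arange(150).reshape(1, 5, 30)
y = x.transpose(1, 2).transpose(0, 2)  # shape (5, 30, 1), not (1, 5, 30)
# The element order is untouched, so a single view restores the original:
assert torch.equal(y.reshape(x.shape), x)
```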
@@ -908,7 +911,7 @@ def can_fuse_for_chain(
         if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets):
             return False

-        # checking that permut2(permut1(identify)) == identity
+        # checking that permute2(permute1(identity)) == identity, modulo unitary dimensions
         input_shape = cast(torch.fx.Node, producer.args[0]).meta["val"].shape
         ident_dims = list(range(len(input_shape)))
         # this mapping helps handle both transposes and permutations
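
Condensing the check above (and the unitary-dim filtering completed in the next hunk) into a standalone sketch; the helper name is hypothetical, not the pass's actual target-to-function mapping:

```python
def composes_to_pseudo_identity(input_shape, transposes):
    # Apply each transpose's dim swap to the identity dim order...
    dims = list(range(len(input_shape)))
    out = list(dims)
    for d0, d1 in transposes:
        out[d0], out[d1] = out[d1], out[d0]
    # ...then compare, ignoring unitary (size-1) dimensions.
    non_unit = lambda ds: [d for d in ds if input_shape[d] != 1]
    return non_unit(out) == non_unit(dims)

# The docstring's example is a pseudo-identity thanks to the size-1 dim:
assert composes_to_pseudo_identity([1, 5, 30], [(1, 2), (0, 2)])
# With no unitary dimension, the same pair is not an identity:
assert not composes_to_pseudo_identity([2, 5, 30], [(1, 2), (0, 2)])
```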
@@ -918,14 +921,20 @@ def can_fuse_for_chain(
         }
         in_dims = f[producer.target](producer, ident_dims)
         out_dims = f[consumer.target](consumer, in_dims)
-        return out_dims == ident_dims
+        # Filter out unitary dimensions: size-1 dims do not affect the memory layout
+        non_unit_ident_dims = [dim for dim in ident_dims if input_shape[dim] != 1]
+        non_unit_out_dims = [dim for dim in out_dims if input_shape[dim] != 1]
+        return non_unit_out_dims == non_unit_ident_dims

     def get_fused_node(
         self,
         producer: torch.fx.Node,
         consumer: torch.fx.Node,
         graph_module: torch.fx.GraphModule,
     ) -> torch.fx.Node:
+        # This step is important: with unitary dimensions we may fuse transpositions
+        # that are not exact inverses of one another, so the fused view op must use
+        # the consumer's output shape, which can differ from the producer's input shape.
         output_shape = consumer.meta["val"].shape
         with graph_module.graph.inserting_after(consumer):
             view = graph_module.graph.call_function(
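
For context, a hedged sketch of what the truncated `get_fused_node` body likely does with that view node; the aten target and the surrounding function are assumptions for illustration, not taken from the diff:

```python
import torch

def fuse_with_view(producer, consumer, graph_module):
    # Because the pair is only a pseudo-identity, the consumer's output shape
    # can differ from the producer's input shape by moved size-1 dims, so the
    # replacement must be an explicit view rather than a plain no-op.
    output_shape = consumer.meta["val"].shape
    with graph_module.graph.inserting_after(consumer):
        view = graph_module.graph.call_function(
            torch.ops.aten.view.default,  # assumed target; truncated in the diff
            args=(producer.args[0], list(output_shape)),
        )
    return view
```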