@@ -71,7 +71,7 @@ def move_to(obj, device):
     return obj
 
 
-def mockbmm(mat1, mat2):
+def mockbmm(mat1, mat2, default_to_torch=False):
     """
     This function is used to mock the behavior of the bmm function in PyTorch.
     It is used to work around the fact that the bmm function in PyTorch is not
@@ -86,20 +86,23 @@ def mockbmm(mat1, mat2):
8686 """
8787 cf = sys ._getframe ()
8888 qbmm_mod = None
89+ qbmm_lineno = cf .f_back .f_lineno
8990 while cf .f_back and qbmm_mod is None :
9091 # First frame is QBmm's forward itself, can start searching from previous stack
9192 cf = cf .f_back
92- if "forward" in cf .f_code .co_name or "_attn" in cf .f_code .co_name :
93+ if (
94+ "forward" in cf .f_code .co_name or "_attn" in cf .f_code .co_name
95+ ) and "self" in cf .f_locals :
9396 mod_calling_bmm_function = cf .f_locals ["self" ]
9497 # If not found -> default to torch.matmul
95- qbmm_mod = getattr (
96- mod_calling_bmm_function , "QBmm" + str (cf .f_lineno ), torch .matmul
97- )
98+ qbmm_mod = getattr (mod_calling_bmm_function , f"QBmm{ qbmm_lineno } " , None )
9899 del cf
100+ if qbmm_mod is None and default_to_torch :
101+ qbmm_mod = torch .matmul
99102 return qbmm_mod (mat1 , mat2 )
100103
101104
102- def mockmatmul (mat1 , mat2 ):
105+ def mockmatmul (mat1 , mat2 , default_to_torch = False ):
103106 """
104107 Patches torch.matmul() with QBmm( torch.bmm() )
105108
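For context, the lookup above resolves a quantized bmm module through the attribute name QBmm<lineno>, where <lineno> is the line number of the original torch.bmm/torch.matmul call site, captured one frame back before the stack walk. Below is a minimal sketch of how such an attribute might be attached; the QBmmStub class and the hard-coded line number 123 are illustrative assumptions, not this repo's API.

```python
# Illustrative sketch only: QBmmStub and the hard-coded "QBmm123" name are assumptions,
# not this repo's API. It shows the naming convention mockbmm() relies on.
import torch


class QBmmStub(torch.nn.Module):
    """Stand-in for a quantized batched-matmul module."""

    def forward(self, m1, m2):
        return torch.bmm(m1, m2)


class ToyAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Suppose the prep step found a torch.bmm call on line 123 of this module's file;
        # it would attach the quantizer under the attribute name "QBmm123".
        setattr(self, "QBmm123", QBmmStub())

    def forward(self, scores, values):
        # Once torch.bmm is patched to mockbmm, this call walks the stack, reads the
        # call-site line number, and resolves getattr(self, f"QBmm{lineno}").
        return torch.bmm(scores, values)
```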
@@ -109,31 +112,37 @@ def mockmatmul(mat1, mat2):
 
     Returns:
         torch.Tensor: The result of the mock matrix multiplication.
+    NOTE:
+        1. First frame is mockmatmul itself. One frame back (cf.f_back) is where torch.matmul
+           happened, whose line number is the one used for QBmm<xxx>.
+        2. The QBmm module may not be attached to the immediate frame where torch.matmul happened.
+           Need to trace back and find the frame with both "forward" in its code name and "self"
+           in its locals, i.e. an nn.Module whose method is named "forward" or similar.
+        3. Keep default_to_torch=False unless really needed; otherwise, if something goes wrong
+           with QBmm detection, it could fall back silently, which would be very difficult to debug.
     """
     cf = sys._getframe()
     qbmm_mod = None
+    qbmm_lineno = cf.f_back.f_lineno
     while cf.f_back and qbmm_mod is None:
-        # First frame is QBmm's forward itself, can start searching from previous stack
         cf = cf.f_back
         if (
             "forward" in cf.f_code.co_name or "_attn" in cf.f_code.co_name
         ) and "self" in cf.f_locals:
             mod_calling_bmm_function = cf.f_locals["self"]
             # If not found -> default to torch.bmm
-            qbmm_mod = getattr(
-                mod_calling_bmm_function, "QBmm" + str(cf.f_lineno), torch.bmm
-            )
+            qbmm_mod = getattr(mod_calling_bmm_function, f"QBmm{qbmm_lineno}", None)
     del cf
 
     # Didn't find the corresponding QBmm, default the call to torch.bmm
-    if qbmm_mod == torch.bmm:
+    if qbmm_mod is None and default_to_torch:
         org_batch_header = mat1.shape[:2]
         # Need to double check m1/m2 are 3d, otherwise reshape
         if len(mat1.shape) > 3:
             mat1 = mat1.reshape([-1, mat1.shape[-2], mat1.shape[-1]])
         if len(mat2.shape) > 3:
             mat2 = mat2.reshape([-1, mat2.shape[-2], mat2.shape[-1]])
-        output = qbmm_mod(mat1, mat2)
+        output = torch.bmm(mat1, mat2)
         output = output.reshape([*org_batch_header, *output.shape[1:]])
         return output
     return qbmm_mod(mat1, mat2)
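The default_to_torch fallback above flattens inputs with more than three dimensions because torch.bmm only accepts 3-d tensors, then restores the leading batch dimensions. A small self-contained sketch of that reshape pattern (the helper name is hypothetical), checked against torch.matmul:

```python
# Self-contained sketch of the fallback reshape path; bmm_with_batch_flatten is a
# hypothetical helper name. torch.bmm only takes 3-d inputs, so 4-d attention-style
# inputs are flattened to 3-d and the leading (batch, heads) dims are restored after.
import torch


def bmm_with_batch_flatten(mat1, mat2):
    org_batch_header = mat1.shape[:2]  # e.g. (batch, heads)
    if mat1.dim() > 3:
        mat1 = mat1.reshape(-1, mat1.shape[-2], mat1.shape[-1])
    if mat2.dim() > 3:
        mat2 = mat2.reshape(-1, mat2.shape[-2], mat2.shape[-1])
    out = torch.bmm(mat1, mat2)  # (batch*heads, m, n)
    return out.reshape(*org_batch_header, *out.shape[1:])  # back to (batch, heads, m, n)


# Quick check against torch.matmul on 4-d inputs:
q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 16, 8)
assert torch.allclose(bmm_with_batch_flatten(q, k), torch.matmul(q, k))
```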
@@ -149,6 +158,9 @@ def patch_torch_bmm(qcfg):
     if qcfg is not None:
         # could be 'torch.bmm', 'torch.matmul', or None
         ops_to_patch = qcfg.get("which2patch_contextmanager", None)
+        # if qcfg["bmm_prep"]["bmm_only_in_self_attn"] is False, may need to enable default_to_torch
+        # in the mock functions, e.g. partial(mockmatmul, default_to_torch=True).
+        # This is in case a model uses extra matmuls, and QBmmXXX is not found or attached properly.
         new_target = (
             mockbmm
             if ops_to_patch == "torch.bmm"
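As the new comment suggests, the fallback can be enabled by pre-binding the flag with functools.partial when building the patch target. A hedged sketch of that wiring, using unittest.mock.patch as a stand-in for the repo's actual context manager; patched_matmul_context and enable_fallback are illustrative names, not the real patch_torch_bmm API.

```python
# Hedged sketch of the wiring the comment describes; patched_matmul_context and
# enable_fallback are illustrative names, not the repo's actual patch_torch_bmm API.
# Assumes mockmatmul (defined in the patched module above) is available in scope.
from functools import partial
from unittest.mock import patch


def patched_matmul_context(enable_fallback=False):
    # With enable_fallback=True, an unmatched QBmm lookup silently degrades to torch.bmm,
    # which is why the NOTE above recommends keeping it off unless really needed.
    target = partial(mockmatmul, default_to_torch=True) if enable_fallback else mockmatmul
    return patch("torch.matmul", new=target)


# Usage (illustrative):
# with patched_matmul_context(enable_fallback=True):
#     out = model(inputs)
```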