File tree Expand file tree Collapse file tree 1 file changed +8
-1
lines changed Expand file tree Collapse file tree 1 file changed +8
-1
lines changed Original file line number Diff line number Diff line change @@ -84,6 +84,13 @@ def get_inverse_transform_indices(
8484 return permuted_tile_indices
8585
8686
87+ # torch.compiler.is_compiling() is available only in torch >= 2.3
88+ if hasattr (torch .compiler , "is_compiling" ):
89+ _is_compiling = torch .compiler .is_compiling
90+ else :
91+ _is_compiling = torch ._dynamo .is_compiling
92+
93+
8794@deprecated (
8895 "This function is deprecated and will be removed in a future release." ,
8996 category = FutureWarning ,
@@ -174,7 +181,7 @@ def forward(
174181 input_shape = A .shape
175182
176183 # Cast A to fp16
177- if A .dtype != torch .float16 :
184+ if A .dtype != torch .float16 and not _is_compiling () :
178185 warnings .warn (f"MatMul8bitLt: inputs will be cast from { A .dtype } to float16 during quantization" )
179186
180187 if len (A .shape ) == 3 :
You can’t perform that action at this time.
0 commit comments