2 files changed: +14 −2 lines changed
@@ -20,6 +20,7 @@
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
     _register_custom_op,
+    _register_meta_op,
 )
 
 __all__ = [
@@ -2292,7 +2293,7 @@ def _quantize_affine_float8(
     return fp8_tensor
 
 
-@torch.library.impl(quant_lib, "quantize_affine_float8", "Meta")
+@_register_meta_op(quant_lib, "quantize_affine_float8")
 def _quantize_affine_float8_meta(
     tensor: torch.Tensor,
     scale: torch.Tensor,
@@ -2319,7 +2320,7 @@ def _dequantize_affine_float8(
     return hp_tensor.to(output_dtype)
 
 
-@torch.library.impl(quant_lib, "dequantize_affine_float8", "Meta")
+@_register_meta_op(quant_lib, "dequantize_affine_float8")
 def _dequantize_affine_float8_meta(
     tensor: torch.Tensor,
     scale: torch.Tensor,
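For context, the meta overloads re-registered above never compute real values; they only report output shape and dtype so FakeTensor tracing and torch.compile can reason about the op without running it. A minimal sketch of what such a meta kernel does is below; the float8_dtype parameter and its default are assumptions, since the diff truncates the signatures:

import torch

def _quantize_affine_float8_meta_sketch(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    float8_dtype: torch.dtype = torch.float8_e4m3fn,  # assumed parameter/default
) -> torch.Tensor:
    # Meta kernels never touch data: they only describe the output's
    # shape, dtype, and device so tracing can proceed on fake tensors.
    return torch.empty_like(tensor, dtype=float8_dtype)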
@@ -237,6 +237,17 @@ def decorator(fn):
     return decorator
 
 
+def _register_meta_op(lib, op_name):
+    def decorator(fn):
+        if TORCH_VERSION_AT_LEAST_2_5:
+            op = lib.impl(op_name, fn, "Meta")
+            return op
+        else:
+            return fn
+
+    return decorator
+
+
 def get_model_size_in_bytes(model, ignore_embeddings=False):
     """
     Returns the model size in bytes. The option to ignore embeddings
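As a usage sketch, the new _register_meta_op helper applies like the existing _register_custom_op decorator: on torch >= 2.5 it registers the function as the op's Meta-dispatch kernel via Library.impl, and on older versions it leaves the function as-is. The namespace, op, and import path below are illustrative assumptions, not part of this diff:

import torch
# assumption: the helper lives alongside TORCH_VERSION_AT_LEAST_2_5
# in torchao's utils module
from torchao.utils import _register_meta_op

# Hypothetical namespace and op, purely for illustration.
_lib = torch.library.Library("demo", "DEF")
_lib.define("double_it(Tensor x) -> Tensor")

@_register_meta_op(_lib, "double_it")
def _double_it_meta(x: torch.Tensor) -> torch.Tensor:
    # Runs for meta-device inputs: propagate shape/dtype only.
    return torch.empty_like(x)

# On torch >= 2.5 the op now has a Meta kernel, so this traces cleanly:
#   torch.ops.demo.double_it(torch.empty(4, device="meta"))  # -> meta tensor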