Commit a71c684

add TORCH_VERSION_CHECK for _register_meta (#2575)
1 parent 2eb4f97 commit a71c684

2 files changed: +14, -2 lines

torchao/quantization/quant_primitives.py

Lines changed: 3 additions & 2 deletions
@@ -20,6 +20,7 @@
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
     _register_custom_op,
+    _register_meta_op,
 )

 __all__ = [
@@ -2292,7 +2293,7 @@ def _quantize_affine_float8(
     return fp8_tensor


-@torch.library.impl(quant_lib, "quantize_affine_float8", "Meta")
+@_register_meta_op(quant_lib, "quantize_affine_float8")
 def _quantize_affine_float8_meta(
     tensor: torch.Tensor,
     scale: torch.Tensor,
@@ -2319,7 +2320,7 @@ def _dequantize_affine_float8(
     return hp_tensor.to(output_dtype)


-@torch.library.impl(quant_lib, "dequantize_affine_float8", "Meta")
+@_register_meta_op(quant_lib, "dequantize_affine_float8")
 def _dequantize_affine_float8_meta(
     tensor: torch.Tensor,
     scale: torch.Tensor,
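
The swap above replaces the direct @torch.library.impl(quant_lib, ..., "Meta") decoration with the new version-gated _register_meta_op helper added in torchao/utils.py below. For context, here is a minimal self-contained sketch of what a "Meta" registration accomplishes; the namespace, op, and functions are illustrative only and not part of this commit:

import torch

# Illustrative library and op, used only for this sketch.
_demo_lib = torch.library.Library("demo_ns", "FRAGMENT")
_demo_lib.define("to_fp8(Tensor t, Tensor scale) -> Tensor")

def _to_fp8_impl(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Eager implementation: scale and cast to float8.
    return (t * scale).to(torch.float8_e4m3fn)

def _to_fp8_meta(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Meta implementation: shape/dtype propagation only, no real data,
    # which lets fake-tensor tracing (torch.compile / torch.export) run the op.
    return torch.empty_like(t, dtype=torch.float8_e4m3fn)

_demo_lib.impl("to_fp8", _to_fp8_impl, "CompositeExplicitAutograd")
_demo_lib.impl("to_fp8", _to_fp8_meta, "Meta")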

torchao/utils.py

Lines changed: 11 additions & 0 deletions
@@ -237,6 +237,17 @@ def decorator(fn):
     return decorator


+def _register_meta_op(lib, op_name):
+    def decorator(fn):
+        if TORCH_VERSION_AT_LEAST_2_5:
+            op = lib.impl(op_name, fn, "Meta")
+            return op
+        else:
+            return fn
+
+    return decorator
+
+
 def get_model_size_in_bytes(model, ignore_embeddings=False):
     """
     Returns the model size in bytes. The option to ignore embeddings
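
The new helper mirrors the decorator-factory pattern already used in this file: on torch >= 2.5 it registers the decorated function as the "Meta" dispatch implementation via lib.impl(op_name, fn, "Meta"); on older torch it returns the function unchanged and skips the registration. A hedged usage sketch, with an illustrative library and op that are not part of this commit:

import torch
from torchao.utils import _register_meta_op

# Illustrative library and op, used only for this sketch.
_sketch_lib = torch.library.Library("sketch_ns", "FRAGMENT")
_sketch_lib.define("double(Tensor t) -> Tensor")
_sketch_lib.impl("double", lambda t: t * 2, "CompositeExplicitAutograd")

# On torch >= 2.5 this registers the Meta kernel for sketch_ns::double;
# on older torch the decorator simply returns the function untouched.
@_register_meta_op(_sketch_lib, "double")
def _double_meta(t: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(t)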
