Commit 7b27971

Andrew Grebenisan authored and facebook-github-bot committed
Cadence: Warning if reference kernels not implemented for registered ops
Summary: I ran into a problem where some new ops were checked into ops_registrations.py without an associated reference implementation. For now, I am warning in all of these situations; once I can run without warnings, I will error out instead.

Differential Revision: D84650725
1 parent 8fa1e38 · commit 7b27971
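The mechanism in this commit is small: wrap the registration decorator so every registered op name is recorded in a set, then cross-check that set against the reference-implementation registry once both modules are imported, warning on any gaps. Below is a minimal, standalone sketch of that same pattern with simplified names; it is not the ExecuTorch code (the actual diff follows), and the op cadence::my_new_op is purely hypothetical.

import logging

# Ops recorded by the wrapped decorator.
_REGISTERED: set[str] = set()
# Stand-in for the reference-implementation registry (names without namespace).
_REF_IMPLS: set[str] = {"quantize_per_tensor"}


def register_fake(op_name: str):
    """Record the op name, then hand off to the real registration (a no-op here)."""
    _REGISTERED.add(op_name)

    def decorator(fn):
        return fn

    return decorator


@register_fake("cadence::quantize_per_tensor")
def quantize_per_tensor_meta(*args):
    ...


@register_fake("cadence::my_new_op")  # hypothetical op with no reference implementation
def my_new_op_meta(*args):
    ...


def validate() -> None:
    missing = [op for op in _REGISTERED if op.split("::")[-1] not in _REF_IMPLS]
    if missing:
        # Warn rather than error so existing gaps don't break module import yet.
        logging.warning(
            "Missing reference implementations:\n%s",
            "\n".join(f" - {op}" for op in missing),
        )


validate()  # warns about cadence::my_new_op only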

2 files changed: +136, -53 lines changed


backends/cadence/aot/ops_registrations.py

Lines changed: 68 additions & 3 deletions
@@ -6,8 +6,9 @@

 # pyre-strict

+import logging
 from math import prod
-from typing import Optional, Tuple
+from typing import Callable, Optional, Tuple

 import torch
 from executorch.backends.cadence.aot.utils import (
@@ -21,6 +22,61 @@

 lib = Library("cadence", "DEF")

+# Track meta kernels that have been registered
+_REGISTERED_META_KERNELS: set[str] = set()
+
+
+# Original register_fake function to use for registrations
+_register_fake_original = register_fake
+
+_OUTPUTS_TYPE = torch.Tensor | tuple[torch.Tensor, ...]
+
+
+def _validate_ref_impl_exists() -> None:
+    """
+    Validates that all registered meta kernels have corresponding reference implementations.
+    This is called at module initialization time after both files have been imported.
+    """
+    # Import here after module initialization to ensure ref_implementations has been loaded
+    from executorch.backends.cadence.aot.ref_implementations import (
+        get_registered_ref_implementations,
+    )
+
+    ref_impls = get_registered_ref_implementations()
+    missing_impls = []
+
+    for op_name in _REGISTERED_META_KERNELS:
+        # Strip the namespace prefix if present (e.g., "cadence::" -> "")
+        op_name_clean = op_name.split("::")[-1] if "::" in op_name else op_name
+
+        if op_name_clean not in ref_impls:
+            missing_impls.append(op_name)
+
+    if missing_impls:
+        error_msg = (
+            "The following meta kernel registrations are missing reference implementations:\n"
+            + "\n".join(f" - {op}" for op in missing_impls)
+            + "\n\nPlease add reference implementations in ref_implementations.py using "
+            + "@impl_tracked(m, '<op_name>')."
+        )

+        # TODO: T241466288 Make this an error once all meta kernels have reference implementations
+        logging.warning(error_msg)
+
+
+# Wrap register_fake to track all registrations
+def register_fake(
+    op_name: str,
+) -> Callable[[Callable[..., _OUTPUTS_TYPE]], Callable[..., _OUTPUTS_TYPE]]:
+    """
+    Wrapped version of register_fake that tracks all meta kernel registrations.
+    This enables validation that all meta kernels have reference implementations.
+    """
+    global _REGISTERED_META_KERNELS
+    _REGISTERED_META_KERNELS.add(op_name)
+    return _register_fake_original(op_name)
+
+
 lib.define(
     "quantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
 )
@@ -2406,7 +2462,9 @@ def idma_load_impl(
     task_num: int = 0,
     channel: int = 0,
 ) -> torch.Tensor:
-    return copy_idma_copy_impl(src, task_num, channel)
+    res = copy_idma_copy_impl(src, task_num, channel)
+    assert isinstance(res, torch.Tensor)
+    return res


 @register_fake("cadence::idma_store")
@@ -2415,7 +2473,9 @@ def idma_store_impl(
     task_num: int = 0,
     channel: int = 0,
 ) -> torch.Tensor:
-    return copy_idma_copy_impl(src, task_num, channel)
+    res = copy_idma_copy_impl(src, task_num, channel)
+    assert isinstance(res, torch.Tensor)
+    return res


 @register_fake("cadence::roi_align_box_processor")
@@ -2671,3 +2731,8 @@ def quantized_w8a32_gru_meta(
     b_h_scale: float,
 ) -> torch.Tensor:
     return inputs.new_empty((2, hidden.shape[-1]), dtype=inputs.dtype)
+
+
+# Validate that all meta kernels have reference implementations
+# This is called at module import time to catch missing implementations early
+_validate_ref_impl_exists()
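As a usage note, and purely for illustration (the op name and exact log formatting below are hypothetical): because the check runs at module import time, simply importing the registrations module surfaces any gaps, and the warning text points at the @impl_tracked decorator in ref_implementations.py as the way to close them.

import logging

logging.basicConfig(level=logging.WARNING)

# Importing the module runs _validate_ref_impl_exists() at the bottom of the file.
import executorch.backends.cadence.aot.ops_registrations  # noqa: F401

# If an op such as cadence::my_new_op (hypothetical) had no reference
# implementation, a warning along these lines would be logged:
#
#   WARNING:root:The following meta kernel registrations are missing reference implementations:
#    - cadence::my_new_op
#
#   Please add reference implementations in ref_implementations.py using
#   @impl_tracked(m, '<op_name>').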

0 commit comments
