Commit 99c86a3

Andrew Grebenisan authored and facebook-github-bot committed
Cadence: Warning if reference kernels not implemented for registered ops (#15130)
Summary: I ran into a problem where some new ops were checked into ops_registrations.py without an associated reference implementation. With this change, existing meta kernels that lack a reference implementation only trigger a warning, while any newly registered op will error out at import time if no reference implementation is provided (a condensed sketch of the pattern is shown below).

Differential Revision: D84650725
1 parent 2c706f1 commit 99c86a3
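
For reference, here is a condensed, standalone sketch of the track-and-validate pattern this commit introduces: registration decorators record op names, and an import-time check compares the recorded sets, warning for grandfathered ops and raising for anything new. The names below (_REFERENCE_IMPLS, the no-op decorators, the cadence::legacy_op and cadence::brand_new_op placeholders) are illustrative stand-ins, not the actual Cadence helpers.

# Condensed sketch of the pattern (illustrative names only; not the actual Cadence helpers).
import logging

_REGISTERED_META_KERNELS: set[str] = set()  # ops that registered a meta (fake) kernel
_REFERENCE_IMPLS: set[str] = set()          # ops that registered a reference kernel
_WARN_ONLY = {"cadence::legacy_op"}         # hypothetical grandfathered op: warn, don't fail


def register_fake(op_name: str):
    """Record the op name, then hand back a decorator (real registration omitted here)."""
    _REGISTERED_META_KERNELS.add(op_name)

    def decorator(fn):
        return fn

    return decorator


def impl_tracked(op_name: str):
    """Record that a reference implementation exists for op_name."""
    _REFERENCE_IMPLS.add(op_name)

    def decorator(fn):
        return fn

    return decorator


def validate_ref_impl_exists() -> None:
    missing = sorted(_REGISTERED_META_KERNELS - _REFERENCE_IMPLS)
    warn = [op for op in missing if op in _WARN_ONLY]
    error = [op for op in missing if op not in _WARN_ONLY]
    if warn:
        logging.warning("Missing reference implementations (warn only): %s", warn)
    if error:
        raise RuntimeError(f"Missing reference implementations: {error}")


@register_fake("cadence::legacy_op")  # grandfathered: only warns
def legacy_op_meta():
    pass


@register_fake("cadence::brand_new_op")  # new op with no reference kernel: raises
def brand_new_op_meta():
    pass


validate_ref_impl_exists()  # warns about legacy_op, raises RuntimeError for brand_new_op

In the commit itself, the wrapped register_fake delegates to the original register_fake it shadows, and _validate_ref_impl_exists() is called at the bottom of ops_registrations.py, so a missing reference implementation surfaces as soon as the module is imported.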

File tree

3 files changed (+189 / -55 lines)

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 2 deletions
@@ -117,6 +117,7 @@ runtime.python_library(
     ],
     deps = [
         "fbcode//caffe2:torch",
+        "fbcode//executorch/backends/cadence/aot:ref_implementations",
         "fbcode//executorch/backends/cadence/aot:utils",
     ],
 )
@@ -425,7 +426,6 @@ python_unittest(
         "//executorch/exir:pass_base",
         "//executorch/exir/dialects:lib",
         "//executorch/exir/passes:lib",
-        ":ref_implementations",
     ],
 )

@@ -628,7 +628,6 @@ python_unittest(
     deps = [
         ":typing_stubs",
         "//executorch/backends/cadence/aot:ops_registrations",
-        "//executorch/backends/cadence/aot:ref_implementations",
         "//caffe2:torch",
     ]
 )

backends/cadence/aot/ops_registrations.py

Lines changed: 120 additions & 3 deletions
@@ -6,8 +6,9 @@
 
 # pyre-strict
 
+import logging
 from math import prod
-from typing import Optional, Tuple
+from typing import Callable, Optional, Tuple
 
 import torch
 from executorch.backends.cadence.aot.utils import (
@@ -21,6 +22,113 @@
 
 lib = Library("cadence", "DEF")
 
+# Track meta kernels that have been registered
+_REGISTERED_META_KERNELS: set[str] = set()
+
+
+# Original register_fake function to use for registrations
+_register_fake_original = register_fake
+
+_OUTPUTS_TYPE = torch.Tensor | tuple[torch.Tensor, ...]
+
+
+def _validate_ref_impl_exists() -> None:
+    """
+    Validates that all registered meta kernels have corresponding reference implementations.
+    This is called at module initialization time after both files have been imported.
+    """
+
+    # Import here after module initialization to ensure ref_implementations has been loaded
+    from executorch.backends.cadence.aot.ref_implementations import (
+        get_registered_ref_implementations,
+    )
+
+    # If reference implementation should not be in
+    # executorch.backends.cadence.aot.ref_implementations, add here
+    _SKIP_OPS = {
+        "cadence::roi_align_box_processor",
+    }
+
+    # All of these should either
+    # 1. be removed
+    # 2. have a reference implementation added to ref_implementations.py
+    _WARN_ONLY = {
+        "cadence::quantized_w8a32_linear",
+        "cadence::quantized_add",  # We should only support per_tensor variant, should remove
+        "cadence::idma_store",
+        "cadence::idma_load",
+        "cadence::_softmax_f32_f32",
+        "cadence::requantize",  # We should only support per_tensor variant, should remove
+        "cadence::quantized_softmax.per_tensor",
+        "cadence::quantize_per_tensor_asym8u",
+        "cadence::quantize_per_tensor_asym8s",
+        "cadence::dequantize_per_tensor_asym8u",
+        "cadence::dequantize_per_tensor_asym32s",
+        "cadence::dequantize_per_tensor_asym16u",
+        "cadence::linalg_vector_norm",
+        "cadence::quantized_conv2d_nchw",  # We should only support per_tensor variant, should remove
+        "cadence::quantized_w8a32_conv",
+        "cadence::quantize_per_tensor_asym32s",
+        "cadence::quantized_relu",  # We should only support per_tensor variant, should remove
+        "cadence::linalg_svd",
+        "cadence::quantized_conv2d_nhwc",  # We should only support per_tensor variant, should remove
+        "cadence::idma_copy",
+        "cadence::quantize_per_tensor_asym16u",
+        "cadence::dequantize_per_tensor_asym8s",
+        "cadence::quantize_per_tensor_asym16s",
+        "cadence::dequantize_per_tensor_asym16s",
+        "cadence::quantized_softmax",
+        "cadence::idma_wait",
+        "cadence::quantized_w8a32_gru",
+        "cadence::quantized_layer_norm",  # We should only support per_tensor variant, should remove
+    }
+
+    ref_impls = get_registered_ref_implementations()
+    warn_impls = []
+    error_impls = []
+    for op_name in _REGISTERED_META_KERNELS:
+        # Strip the namespace prefix if present (e.g., "cadence::" -> "")
+        op_name_clean = op_name.split("::")[-1] if "::" in op_name else op_name
+
+        if op_name_clean not in ref_impls:
+            if op_name in _WARN_ONLY:
+                warn_impls.append(op_name)
+            elif op_name not in _SKIP_OPS:
+                error_impls.append(op_name)
+
+    if warn_impls:
+        warn_msg = (
+            f"The following {len(warn_impls)} meta kernel registrations are missing reference implementations:\n"
+            + "\n".join(f" - {op}" for op in warn_impls)
+            + "\n\nPlease add reference implementations in ref_implementations.py using "
+            + "@impl_tracked(m, '<op_name>')."
+        )
+        logging.warning(warn_msg)
+
+    if error_impls:
+        error_msg = (
+            f"The following {len(error_impls)} meta kernel registrations are missing reference implementations:\n"
+            + "\n".join(f" - {op}" for op in error_impls)
+            + "\n\nPlease add reference implementations in ref_implementations.py using "
+            + "@impl_tracked(m, '<op_name>')."
+        )
+
+        raise RuntimeError(error_msg)
+
+
+# Wrap register_fake to track all registrations
+def register_fake(
+    op_name: str,
+) -> Callable[[Callable[..., _OUTPUTS_TYPE]], Callable[..., _OUTPUTS_TYPE]]:
+    """
+    Wrapped version of register_fake that tracks all meta kernel registrations.
+    This enables validation that all meta kernels have reference implementations.
+    """
+    global _REGISTERED_META_KERNELS
+    _REGISTERED_META_KERNELS.add(op_name)
+    return _register_fake_original(op_name)
+
+
 lib.define(
     "quantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
 )
@@ -2406,7 +2514,9 @@ def idma_load_impl(
     task_num: int = 0,
     channel: int = 0,
 ) -> torch.Tensor:
-    return copy_idma_copy_impl(src, task_num, channel)
+    res = copy_idma_copy_impl(src, task_num, channel)
+    assert isinstance(res, torch.Tensor)
+    return res
 
 
 @register_fake("cadence::idma_store")
@@ -2415,7 +2525,9 @@ def idma_store_impl(
     task_num: int = 0,
     channel: int = 0,
 ) -> torch.Tensor:
-    return copy_idma_copy_impl(src, task_num, channel)
+    res = copy_idma_copy_impl(src, task_num, channel)
+    assert isinstance(res, torch.Tensor)
+    return res
 
 
 @register_fake("cadence::roi_align_box_processor")
@@ -2671,3 +2783,8 @@ def quantized_w8a32_gru_meta(
     b_h_scale: float,
 ) -> torch.Tensor:
     return inputs.new_empty((2, hidden.shape[-1]), dtype=inputs.dtype)
+
+
+# Validate that all meta kernels have reference implementations
+# This is called at module import time to catch missing implementations early
+_validate_ref_impl_exists()
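
In practice, the check means every op whose meta kernel is registered in ops_registrations.py needs a matching entry in ref_implementations.py. A hedged illustration of that pairing follows; cadence::my_new_op is made up, and the impl_tracked(m, '<op_name>') form and the m object are assumptions taken only from the error message above, not from the actual ref_implementations.py.

# Illustration only: "cadence::my_new_op" is hypothetical, and impl_tracked / m are
# assumed to follow the "@impl_tracked(m, '<op_name>')" form quoted in the error text.

# In ops_registrations.py: the meta kernel goes through the wrapped register_fake,
# so the op name lands in _REGISTERED_META_KERNELS.
@register_fake("cadence::my_new_op")
def my_new_op_meta(x: torch.Tensor) -> torch.Tensor:
    # Meta kernels only propagate shape/dtype; no real computation happens here.
    return x.new_empty(x.shape, dtype=x.dtype)


# In ref_implementations.py: the matching reference kernel, registered with the tracked
# decorator so get_registered_ref_implementations() reports it (namespace prefix omitted,
# since the validator strips "cadence::" before the lookup).
@impl_tracked(m, "my_new_op")
def my_new_op(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference behavior, purely for illustration.
    return x.clone()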
