
# pyre-strict

+ import logging
from math import prod
- from typing import Optional, Tuple
+ from typing import Callable, Optional, Tuple

import torch
from executorch.backends.cadence.aot.utils import (

lib = Library("cadence", "DEF")

+ # Track meta kernels that have been registered
+ _REGISTERED_META_KERNELS: set[str] = set()
+
+
+ # Original register_fake function to use for registrations
+ _register_fake_original = register_fake
+
+ _OUTPUTS_TYPE = torch.Tensor | tuple[torch.Tensor, ...]
+
+
+ def _validate_ref_impl_exists() -> None:
+     """
+     Validates that all registered meta kernels have corresponding reference implementations.
+     This is called at module initialization time after both files have been imported.
+     """
+
+     # Import here after module initialization to ensure ref_implementations has been loaded
+     from executorch.backends.cadence.aot.ref_implementations import (
+         get_registered_ref_implementations,
+     )
+
+     # Ops that intentionally have no reference implementation in
+     # executorch.backends.cadence.aot.ref_implementations are listed here
+     _SKIP_OPS = {
+         "cadence::roi_align_box_processor",
+     }
+
+     # Each of these ops should either
+     # 1. be removed, or
+     # 2. have a reference implementation added to ref_implementations.py
+     _WARN_ONLY = {
+         "cadence::quantized_w8a32_linear",
+         "cadence::quantized_add",  # Only the per_tensor variant should be supported; remove this
+         "cadence::idma_store",
+         "cadence::idma_load",
+         "cadence::_softmax_f32_f32",
+         "cadence::requantize",  # Only the per_tensor variant should be supported; remove this
+         "cadence::quantized_softmax.per_tensor",
+         "cadence::quantize_per_tensor_asym8u",
+         "cadence::quantize_per_tensor_asym8s",
+         "cadence::dequantize_per_tensor_asym8u",
+         "cadence::dequantize_per_tensor_asym32s",
+         "cadence::dequantize_per_tensor_asym16u",
+         "cadence::linalg_vector_norm",
+         "cadence::quantized_conv2d_nchw",  # Only the per_tensor variant should be supported; remove this
+         "cadence::quantized_w8a32_conv",
+         "cadence::quantize_per_tensor_asym32s",
+         "cadence::quantized_relu",  # Only the per_tensor variant should be supported; remove this
+         "cadence::linalg_svd",
+         "cadence::quantized_conv2d_nhwc",  # Only the per_tensor variant should be supported; remove this
+         "cadence::idma_copy",
+         "cadence::quantize_per_tensor_asym16u",
+         "cadence::dequantize_per_tensor_asym8s",
+         "cadence::quantize_per_tensor_asym16s",
+         "cadence::dequantize_per_tensor_asym16s",
+         "cadence::quantized_softmax",
+         "cadence::idma_wait",
+         "cadence::quantized_w8a32_gru",
+         "cadence::quantized_layer_norm",  # Only the per_tensor variant should be supported; remove this
+     }
+
+     ref_impls = get_registered_ref_implementations()
+     warn_impls = []
+     error_impls = []
+     for op_name in _REGISTERED_META_KERNELS:
+         # Strip the namespace prefix if present (e.g., "cadence::" -> "")
+         op_name_clean = op_name.split("::")[-1] if "::" in op_name else op_name
+
+         if op_name_clean not in ref_impls:
+             if op_name in _WARN_ONLY:
+                 warn_impls.append(op_name)
+             elif op_name not in _SKIP_OPS:
+                 error_impls.append(op_name)
+
+     if warn_impls:
+         warn_msg = (
+             f"The following {len(warn_impls)} meta kernel registrations are missing reference implementations:\n"
+             + "\n".join(f"  - {op}" for op in warn_impls)
+             + "\n\nPlease add reference implementations in ref_implementations.py using "
+             + "@impl_tracked(m, '<op_name>')."
+         )
+         logging.warning(warn_msg)
+
+     if error_impls:
+         error_msg = (
+             f"The following {len(error_impls)} meta kernel registrations are missing reference implementations:\n"
+             + "\n".join(f"  - {op}" for op in error_impls)
+             + "\n\nPlease add reference implementations in ref_implementations.py using "
+             + "@impl_tracked(m, '<op_name>')."
+         )
+
+         raise RuntimeError(error_msg)
+
+
+ # Wrap register_fake to track all registrations
+ def register_fake(
+     op_name: str,
+ ) -> Callable[[Callable[..., _OUTPUTS_TYPE]], Callable[..., _OUTPUTS_TYPE]]:
+     """
+     Wrapped version of register_fake that tracks all meta kernel registrations.
+     This enables validation that all meta kernels have reference implementations.
+     """
+     global _REGISTERED_META_KERNELS
+     _REGISTERED_META_KERNELS.add(op_name)
+     return _register_fake_original(op_name)
+
+
lib.define(
    "quantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
)
@@ -2406,7 +2514,9 @@ def idma_load_impl(
    task_num: int = 0,
    channel: int = 0,
) -> torch.Tensor:
-     return copy_idma_copy_impl(src, task_num, channel)
+     res = copy_idma_copy_impl(src, task_num, channel)
+     assert isinstance(res, torch.Tensor)
+     return res


@register_fake("cadence::idma_store")
@@ -2415,7 +2525,9 @@ def idma_store_impl(
    task_num: int = 0,
    channel: int = 0,
) -> torch.Tensor:
-     return copy_idma_copy_impl(src, task_num, channel)
+     res = copy_idma_copy_impl(src, task_num, channel)
+     assert isinstance(res, torch.Tensor)
+     return res


@register_fake("cadence::roi_align_box_processor")
@@ -2671,3 +2783,8 @@ def quantized_w8a32_gru_meta(
    b_h_scale: float,
) -> torch.Tensor:
    return inputs.new_empty((2, hidden.shape[-1]), dtype=inputs.dtype)
+
+
+ # Validate that all meta kernels have reference implementations
+ # This is called at module import time to catch missing implementations early
+ _validate_ref_impl_exists()
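
For readers unfamiliar with the pattern, below is a minimal, self-contained sketch of the wrap-and-validate mechanism this diff introduces: a wrapped `register_fake` records every meta-kernel registration, and a validation pass at import time checks each recorded op against a registry of reference implementations. The names `_register_fake_original`, `REF_IMPLS`, and `validate_ref_impl_exists` here are simplified stand-ins for illustration only, not the actual torch.library or ExecuTorch APIs.

```python
import logging
from typing import Callable

# Registry of every op name passed to the wrapped register_fake.
_REGISTERED_META_KERNELS: set[str] = set()


# Stand-in for the real register_fake: returns a decorator that leaves
# the meta kernel unchanged.
def _register_fake_original(op_name: str) -> Callable[[Callable], Callable]:
    def decorator(fn: Callable) -> Callable:
        return fn

    return decorator


def register_fake(op_name: str) -> Callable[[Callable], Callable]:
    # Record the registration so it can be validated at import time.
    _REGISTERED_META_KERNELS.add(op_name)
    return _register_fake_original(op_name)


# Stand-in for get_registered_ref_implementations(): reference impls are
# keyed by the op name without the "cadence::" namespace prefix.
REF_IMPLS: set[str] = {"quantize_per_tensor"}


def validate_ref_impl_exists(warn_only: frozenset[str] = frozenset()) -> None:
    missing = [
        op for op in _REGISTERED_META_KERNELS
        if op.split("::")[-1] not in REF_IMPLS
    ]
    warn = [op for op in missing if op in warn_only]
    errors = [op for op in missing if op not in warn_only]
    if warn:
        logging.warning("Missing reference implementations: %s", warn)
    if errors:
        raise RuntimeError(f"Missing reference implementations: {errors}")


@register_fake("cadence::quantize_per_tensor")
def quantize_per_tensor_meta(*args: object, **kwargs: object) -> None:
    ...  # shape/dtype propagation would go here


validate_ref_impl_exists()  # passes: a reference impl exists for this op
```

In the real module, the reference registry comes from get_registered_ref_implementations() in ref_implementations.py, ops on the _WARN_ONLY list only log a warning instead of raising, and ops in _SKIP_OPS are exempt entirely.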