@@ -6,8 +6,9 @@
 
 # pyre-strict
 
+import logging
 from math import prod
-from typing import Optional, Tuple
+from typing import Callable, Optional, Tuple
 
 import torch
 from executorch.backends.cadence.aot.utils import (
@@ -21,6 +22,61 @@
 
 lib = Library("cadence", "DEF")
 
+# Track meta kernels that have been registered
+_REGISTERED_META_KERNELS: set[str] = set()
+
+
+# Original register_fake function to use for registrations
+_register_fake_original = register_fake
+
+_OUTPUTS_TYPE = torch.Tensor | tuple[torch.Tensor, ...]
+
+
+def _validate_ref_impl_exists() -> None:
+    """
+    Validates that all registered meta kernels have corresponding reference implementations.
+    This is called at module initialization time after both files have been imported.
+    """
+    # Import here after module initialization to ensure ref_implementations has been loaded
+    from executorch.backends.cadence.aot.ref_implementations import (
+        get_registered_ref_implementations,
+    )
+
+    ref_impls = get_registered_ref_implementations()
+    missing_impls = []
+
+    for op_name in _REGISTERED_META_KERNELS:
+        # Strip the namespace prefix if present (e.g., "cadence::foo" -> "foo")
+        op_name_clean = op_name.split("::")[-1] if "::" in op_name else op_name
+
+        if op_name_clean not in ref_impls:
+            missing_impls.append(op_name)
+
+    if missing_impls:
+        error_msg = (
+            "The following meta kernel registrations are missing reference implementations:\n"
+            + "\n".join(f"  - {op}" for op in missing_impls)
+            + "\n\nPlease add reference implementations in ref_implementations.py using "
+            + "@impl_tracked(m, '<op_name>')."
+        )
+
+        # TODO: T241466288 Make this an error once all meta kernels have reference implementations
+        logging.warning(error_msg)
+
+
+# Wrap register_fake to track all registrations
+def register_fake(
+    op_name: str,
+) -> Callable[[Callable[..., _OUTPUTS_TYPE]], Callable[..., _OUTPUTS_TYPE]]:
+    """
+    Wrapped version of register_fake that tracks all meta kernel registrations.
+    This enables validation that all meta kernels have reference implementations.
+    """
+    global _REGISTERED_META_KERNELS
+    _REGISTERED_META_KERNELS.add(op_name)
+    return _register_fake_original(op_name)
+
+
 lib.define(
     "quantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
 )
@@ -2406,7 +2462,9 @@ def idma_load_impl(
     task_num: int = 0,
     channel: int = 0,
 ) -> torch.Tensor:
-    return copy_idma_copy_impl(src, task_num, channel)
+    res = copy_idma_copy_impl(src, task_num, channel)
+    assert isinstance(res, torch.Tensor)
+    return res
 
 
 @register_fake("cadence::idma_store")
@@ -2415,7 +2473,9 @@ def idma_store_impl(
     task_num: int = 0,
     channel: int = 0,
 ) -> torch.Tensor:
-    return copy_idma_copy_impl(src, task_num, channel)
+    res = copy_idma_copy_impl(src, task_num, channel)
+    assert isinstance(res, torch.Tensor)
+    return res
 
 
 @register_fake("cadence::roi_align_box_processor")
@@ -2671,3 +2731,8 @@ def quantized_w8a32_gru_meta(
     b_h_scale: float,
 ) -> torch.Tensor:
     return inputs.new_empty((2, hidden.shape[-1]), dtype=inputs.dtype)
+
+
+# Validate that all meta kernels have reference implementations
+# This is called at module import time to catch missing implementations early
+_validate_ref_impl_exists()
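
For context: the validation above relies on ref_implementations.py exposing a registry of its own registrations. The names impl_tracked and get_registered_ref_implementations come from this patch, but their bodies are not shown here, so the following is only a minimal sketch of what that counterpart might look like, assuming impl_tracked wraps torch.library.Library.impl:

# Hypothetical sketch of the tracking side of ref_implementations.py.
# Only the names impl_tracked and get_registered_ref_implementations appear
# in the patch above; everything else here is an illustrative assumption.
from typing import Callable

import torch

# Names of ops that have reference implementations, queried by the validator.
_REGISTERED_REF_IMPLEMENTATIONS: set[str] = set()


def impl_tracked(
    lib: torch.library.Library, op_name: str
) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
    """Record op_name, then register fn as its reference implementation."""
    _REGISTERED_REF_IMPLEMENTATIONS.add(op_name)

    def decorator(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
        lib.impl(op_name, fn, "CompositeExplicitAutograd")
        return fn

    return decorator


def get_registered_ref_implementations() -> set[str]:
    return _REGISTERED_REF_IMPLEMENTATIONS

Under this reading, the deferred import inside _validate_ref_impl_exists() ensures the registry is fully populated before it is compared against _REGISTERED_META_KERNELS (and, plausibly, avoids a circular import between the two modules).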