53
53
CompiledConfig = Callable [..., _R ]
54
54
55
55
56
+ @dataclasses .dataclass (frozen = True )
57
+ class BoundKernelInMemoryCacheKey :
58
+ specialization_key : tuple [Hashable , ...]
59
+ extra_results : tuple [Hashable , ...]
60
+
61
+
56
62
class Kernel (Generic [_R ]):
57
63
def __init__ (
58
64
self ,
@@ -80,7 +86,7 @@ def __init__(
80
86
Config (** c ) if isinstance (c , dict ) else c # pyright: ignore[reportArgumentType]
81
87
for c in configs or []
82
88
]
83
- self ._bound_kernels : dict [Hashable , BoundKernel ] = {}
89
+ self ._bound_kernels : dict [BoundKernelInMemoryCacheKey , BoundKernel ] = {}
84
90
self ._specialize_extra : dict [
85
91
Hashable , list [Callable [[Sequence [object ]], Hashable ]]
86
92
] = {}
@@ -105,6 +111,25 @@ def __init__(
105
111
else :
106
112
self ._annotations .append (ann )
107
113
114
+ def _get_bound_kernel_cache_key (
115
+ self , args : tuple [object , ...], signature : tuple [Hashable , ...]
116
+ ) -> BoundKernelInMemoryCacheKey | None :
117
+ extra_fns = self ._specialize_extra .get (signature )
118
+ if extra_fns is not None :
119
+ extra_results : tuple [Hashable , ...] = tuple ([s (args ) for s in extra_fns ])
120
+ return BoundKernelInMemoryCacheKey (signature , extra_results )
121
+ return None
122
+
123
+ def _create_bound_kernel_cache_key (
124
+ self ,
125
+ bound_kernel : BoundKernel ,
126
+ args : tuple [object , ...],
127
+ signature : tuple [Hashable , ...],
128
+ ) -> BoundKernelInMemoryCacheKey :
129
+ self ._specialize_extra [signature ] = extra_fns = bound_kernel ._specialize_extra ()
130
+ extra_results : tuple [Hashable , ...] = tuple ([s (args ) for s in extra_fns ])
131
+ return BoundKernelInMemoryCacheKey (signature , extra_results )
132
+
108
133
def bind (self , args : tuple [object , ...]) -> BoundKernel [_R ]:
109
134
"""
110
135
Bind the given arguments to the Kernel and return a BoundKernel object.
@@ -119,28 +144,22 @@ def bind(self, args: tuple[object, ...]) -> BoundKernel[_R]:
119
144
assert isinstance (args , list ), "args must be a tuple or list"
120
145
args = tuple (args )
121
146
signature = self .specialization_key (args )
122
- extra_fns = self ._specialize_extra .get (signature )
123
- if extra_fns is not None :
124
- extra_results : list [Hashable ] = [s (args ) for s in extra_fns ]
125
- signature_extra = (* signature , * extra_results )
126
- bound_kernel = self ._bound_kernels .get (signature_extra )
127
- else :
128
- signature_extra = None
129
- bound_kernel = None
147
+ cache_key = self ._get_bound_kernel_cache_key (args , signature )
148
+ bound_kernel = (
149
+ None if cache_key is None else self ._bound_kernels .get (cache_key , None )
150
+ )
130
151
if bound_kernel is None :
131
152
normalized_args : tuple [object , ...] = self .normalize_args (* args )
132
153
if len (normalized_args ) != len (args ):
133
154
# we had default args that needed to be applied
134
155
bound_kernel = self .bind (normalized_args )
135
156
else :
136
157
bound_kernel = BoundKernel (self , args )
137
- if signature_extra is None :
138
- self . _specialize_extra [ signature ] = extra_fns = (
139
- bound_kernel . _specialize_extra ()
158
+ if cache_key is None :
159
+ cache_key = self . _create_bound_kernel_cache_key (
160
+ bound_kernel , args , signature
140
161
)
141
- extra_results = [s (args ) for s in extra_fns ]
142
- signature_extra = (* signature , * extra_results )
143
- self ._bound_kernels [signature_extra ] = bound_kernel
162
+ self ._bound_kernels [cache_key ] = bound_kernel
144
163
return bound_kernel
145
164
146
165
def specialization_key (self , args : Sequence [object ]) -> tuple [Hashable , ...]:
@@ -608,16 +627,18 @@ def kernel(
608
627
609
628
610
629
def _tensor_key (fn : Kernel , obj : torch .Tensor ) -> Hashable :
630
+ # NOTE: If a machine has two different gpu types on the same machine,
631
+ # obj.device.type will incorrectly hit
611
632
if fn .settings .static_shapes :
612
633
return (
613
634
obj .dtype ,
614
- obj .device ,
635
+ obj .device . type ,
615
636
(* obj .size (),),
616
637
(* obj .stride (),),
617
638
)
618
639
return (
619
640
obj .dtype ,
620
- obj .device ,
641
+ obj .device . type ,
621
642
# 0, 1, or >=2 specialization
622
643
tuple ([min (s , 2 ) for s in obj .size ()]),
623
644
)
0 commit comments