1- from typing import Any , ClassVar , Optional , Type , TypeVar , Union , cast
1+ from typing import Any , ClassVar , Optional , Type , TypeVar , Union
22
3- from abc import ABC , abstractmethod
3+ from abc import ABC
44from dataclasses import dataclass
5- from enum import Enum
65
76import sympy
8- import torch
9-
10- from .. import ops
117
128from . import context
139from . import dtype
10+ from .shaped_type import ShapedType , ShapedDataType
1411
1512__all__ = [
1613 "backed_sym_index_type" ,
1714 "sym" ,
1815 "BoundedRelation" ,
1916 "EqualRelation" ,
20- "Grid" ,
2117 "IndexingContext" ,
2218 "IndexRelation" ,
2319 "IndexExpr" ,
2420 "IndexSymbol" ,
25- "InputBuffer" ,
26- "KernelBuffer" ,
27- "OutputBuffer" ,
2821 "SymIndex" ,
29- "TemporaryBuffer" ,
3022]
3123
3224DataType = dtype .DataType
@@ -74,270 +66,12 @@ def __getattr__(self, n):
7466SymbolicDimable = Union [str , IndexExpr ]
7567SymbolicShapeable = tuple [SymbolicDimable ]
7668SymbolicShapeExpr = tuple [IndexExpr ]
77-
78-
79- def make_symbolic_shape (elements : SymbolicShapeable ) -> SymbolicShapeExpr :
80- return tuple (
81- index_symbol (expr ) if isinstance (expr , str ) else expr for expr in elements
82- )
83-
84-
85- ###############################################################################
86- # Grid
87- ###############################################################################
88-
89-
90- class _GridMeta (type ):
91- """Meta-class for a symbolically shaped grid."""
92-
93- def __new__ (
94- mcls ,
95- name : str ,
96- bases ,
97- dct ,
98- * ,
99- symbolic_shape : Optional [SymbolicShapeExpr ],
100- ):
101- new_class = type .__new__ (mcls , name , bases , dct )
102- new_class .symbolic_shape = symbolic_shape
103- new_class .rank = len (symbolic_shape ) if symbolic_shape is not None else None
104- new_class .__qualname__ = repr (new_class )
105- return new_class
106-
107- def __repr__ (self ):
108- if self .symbolic_shape :
109- return f"Grid[{ ', ' .join (repr (s ) for s in self .symbolic_shape )} ]"
110- else :
111- return "Grid"
112-
113-
114- class Grid (metaclass = _GridMeta , symbolic_shape = None ):
115- """Grid with bounding symbolic shape information in the type."""
116-
117- symbolic_shape : ClassVar [Optional [SymbolicShapeExpr ]]
118- # TODO: dims should also allow dynamic dimensions.
119- dims : list [int ]
120- rank : int
121-
122- def __init__ (self ):
123- # Resolve the symbolic shape to concrete values.
124- idxc = IndexingContext .current ()
125- if self .symbolic_shape :
126- dims = [idxc .get_static_value (dim ) for dim in self .symbolic_shape ]
127- if None in dims :
128- raise ValueError (f"NYI: Dynamic dims in Grid" )
129- self .dims = cast (list [int ], dims )
130- else :
131- self .dims = []
132-
133- # Shadow the type rank with the actual, which makes it concrete
134- # for the generic case.
135- self .rank = len (self .dims )
136-
137- def __class_getitem__ (
138- cls , symbolic_shape : Union [SymbolicDimable , tuple [SymbolicShapeable ]]
139- ) -> Type ["Grid" ]:
140- if not isinstance (symbolic_shape , tuple ):
141- symbolic_shape = (symbolic_shape ,)
142- return cast (Grid , _make_shaped_grid (cls , make_symbolic_shape (symbolic_shape )))
143-
144- def __repr__ (self ):
145- return f"{ repr (type (self ))} ({ ', ' .join (str (i ) for i in self .dims )} )"
146-
147- def __getitem__ (self , index : int ) -> int :
148- return self .dims [index ]
149-
150- def __len__ (self ) -> int :
151- return len (self .dims )
152-
153- def __iter__ (self ):
154- return iter (self .dims )
155-
156-
157- def _make_shaped_grid (cls : Type [Grid ], symbolic_shape : tuple [IndexExpr ]):
158- class ShapedGrid (Grid , symbolic_shape = symbolic_shape ):
159- ...
160-
161- return ShapedGrid
162-
163-
164- ###############################################################################
165- # KernelBuffer
166- ###############################################################################
167-
16869Dims = list [Union [None , IndexSymbol , int ]]
16970
170-
171- class KernelBufferUsage (Enum ):
172- NONE = 0
173- INPUT = 1
174- OUTPUT = 2
175- TEMPORARY = 3
176-
177- @staticmethod
178- def _type_name (v ) -> str :
179- if v == KernelBufferUsage .NONE :
180- return "KernelBuffer"
181- elif v == KernelBufferUsage .INPUT :
182- return "InputBuffer"
183- elif v == KernelBufferUsage .OUTPUT :
184- return "OutputBuffer"
185- elif v == KernelBufferUsage .TEMPORARY :
186- return "TemporaryBuffer"
187- else :
188- raise AssertionError (f"uncovered KernelBufferUsage enum ({ v } )" )
189-
190-
191- class _KernelBufferMeta (type ):
192- """Meta-class for kernel buffers.
193-
194- This lets us specialize with symbolic shape information.
195- """
196-
197- element_type : DataType
198- usage : KernelBufferUsage
199- symbolic_shape : Optional [SymbolicShapeExpr ]
200- rank : Optional [int ]
201-
202- def __new__ (
203- mcls ,
204- name : str ,
205- bases ,
206- dct ,
207- ):
208- element_type = dct .get ("element_type" ) or DefaultDataType
209- dct ["element_type" ] = element_type
210- usage = dct .get ("usage" ) or KernelBufferUsage .NONE
211- dct ["usage" ] = usage
212- if "usage" not in dct :
213- dct ["usage" ] = KernelBufferUsage .NONE
214- symbolic_shape = dct .get ("symbolic_shape" )
215- dct ["symbolic_shape" ] = symbolic_shape
216- dct ["rank" ] = len (symbolic_shape ) if symbolic_shape is not None else None
217- dct ["__qualname__" ] = _kernel_buffer_type_repr (
218- element_type = element_type , usage = usage , symbolic_shape = symbolic_shape
219- )
220- new_class = type .__new__ (mcls , name , bases , dct )
221- return new_class
222-
223- def new_subtype (
224- cls : Type [SubtypeT ],
225- * ,
226- element_type : Union [NotSetType , DataType ] = NotSet ,
227- symbolic_shape : Union [NotSetType , Optional [SymbolicShapeable ]] = NotSet ,
228- usage : Union [NotSetType , KernelBufferUsage ] = NotSet ,
229- ) -> Type [SubtypeT ]:
230- init_element_type = (
231- element_type if element_type is not NotSet else cls .element_type
232- )
233- init_symbolic_shape = (
234- symbolic_shape if symbolic_shape is not NotSet else cls .symbolic_shape
235- )
236- init_usage = usage if usage is not NotSet else cls .usage
237-
238- class Subtype (cls ):
239- element_type = init_element_type
240- symbolic_shape = make_symbolic_shape (init_symbolic_shape )
241- usage = init_usage
242-
243- return Subtype
244-
245- def of (cls : Type [SubtypeT ], element_type : Union [Any , DataType ]) -> Type [SubtypeT ]:
246- return cls .new_subtype (element_type = element_type )
247-
248- def __repr__ (cls ):
249- return _kernel_buffer_type_repr (
250- element_type = cls .element_type ,
251- usage = cls .usage ,
252- symbolic_shape = cls .symbolic_shape ,
253- )
254-
255-
256- def is_kernel_buffer_meta_derived (t : type ) -> bool :
257- return isinstance (t , _KernelBufferMeta )
258-
259-
260- def _kernel_buffer_type_repr (
261- * ,
262- element_type : DataType ,
263- usage : KernelBufferUsage ,
264- symbolic_shape : Optional [tuple [IndexExpr ]],
265- ) -> str :
266- root = KernelBufferUsage ._type_name (usage )
267- if symbolic_shape :
268- stem = f"{ root } [{ ', ' .join (repr (s ) for s in symbolic_shape )} ]"
269- else :
270- stem = f"{ root } "
271- if element_type != DefaultDataType :
272- stem += f".of({ element_type } )"
273- return stem
274-
275-
276- class KernelBuffer (metaclass = _KernelBufferMeta ):
277- """Represents a buffer in global memory.
278-
279- Top level kernels always operate on global memory via these
280- buffers, and the primary operations that can be performed on
281- them are loads/stores and DMAs to some form of compute
282- capable local buffer.
283-
284- When executing eagerly, these are backed by a normal torch
285- Tensor. When compiling, an appropriate duck-typed proxy
286- is used.
287- """
288-
289- usage : ClassVar [KernelBufferUsage ]
290- symbolic_shape : ClassVar [Optional [SymbolicShapeExpr ]]
291- rank : Optional [int ]
292-
293- def __init__ (self , tensor : torch .Tensor ):
294- assert isinstance (tensor , torch .Tensor ), f"Expected Tensor but got { tensor } "
295- type_rank = type (self ).rank
296- tensor_rank = len (tensor .shape )
297- if type_rank is not None and type_rank != tensor_rank :
298- raise ValueError (
299- f"Cannot create { type (self )} (tensor({ tensor .shape } )): mismatched symbolic rank"
300- )
301- self ._tensor = tensor
302- self .rank = tensor_rank
303-
304- def __class_getitem__ (
305- cls , symbolic_shape : Union [IndexExpr , SymbolicShapeExpr ]
306- ) -> Type ["KernelBuffer" ]:
307- if not isinstance (symbolic_shape , tuple ):
308- symbolic_shape = (symbolic_shape ,)
309- return cast (
310- cls , cls .new_subtype (symbolic_shape = make_symbolic_shape (symbolic_shape ))
311- )
312-
313- def __repr__ (self ):
314- return f"{ type (self )} ({ self ._tensor } )"
315-
316- def __setitem__ (self , key , item ):
317- ops .kernel_buffer_setitem (self , key , item )
318-
319- def __getitem__ (self , key ):
320- return ops .kernel_buffer_getitem (self , key )
321-
322-
323- class InputBuffer (KernelBuffer ):
324- usage = KernelBufferUsage .INPUT
325-
326-
327- class OutputBuffer (KernelBuffer ):
328- usage = KernelBufferUsage .OUTPUT
329-
330-
331- class TemporaryBuffer (KernelBuffer ):
332- usage = KernelBufferUsage .TEMPORARY
333-
334-
33571###############################################################################
33672# IndexingContext
33773###############################################################################
33874
339- ShapedType = Union [Type [KernelBuffer ], Type [Grid ]]
340-
34175
34276@dataclass (slots = True )
34377class _ShapedBinding :
@@ -377,7 +111,7 @@ def __init__(self):
377111 # Indexed by .instance
378112 self .shaped_bindings : dict [Any , _ShapedBinding ] = {}
379113 self .dyn_dims : list [IndexSymbol ] = []
380- self .frozen_subs : list [IndexSymbol , int ] = []
114+ self .frozen_subs : list [tuple [ IndexSymbol , int ] ] = []
381115 self .unbacked_symbols : list [IndexSymbol ] = []
382116
383117 def next_dyn_dim (self ) -> IndexSymbol :
@@ -390,9 +124,7 @@ def new_unbacked_symbol(self) -> IndexSymbol:
390124 self .unbacked_symbols .append (s )
391125 return s
392126
393- def bind_shaped (
394- self , instance : Any , shaped_type : ShapedType , dims : Dims
395- ) -> _ShapedBinding :
127+ def bind_shaped (self , instance : Any , shaped_type : ShapedType , dims : Dims ) -> None :
396128 if instance in self .shaped_bindings :
397129 raise ValueError (f"Argument binding { instance } is already bound" )
398130 symbolic_shape = shaped_type .symbolic_shape
@@ -406,7 +138,7 @@ def bind_shaped(
406138 )
407139 self .shaped_bindings [instance ] = binding
408140
409- def bind_constant (self , sym : IndexSymbol , value : int ):
141+ def bind_constant (self , sym : IndexSymbol , value : int ) -> None :
410142 try :
411143 self ._bind_symbol (sym , value )
412144 except ValueError :
0 commit comments