@@ -139,7 +139,59 @@ class const:
139139 pass
140140
141141
142- class constexpr :
142+ class base_value :
143+ """Base class of values that exist in the triton IR (i.e. not constexprs).
144+ """
145+ type : base_type
146+
147+ def _flatten_ir (self , handles : List [ir .value ]) -> None :
148+ """Flatten frontend value into a sequence of mlir handles, which are appended
149+ to the output list
150+ """
151+ raise NotImplementedError
152+
153+
154+ class base_type :
155+
156+ def __eq__ (self , other ):
157+ raise NotImplementedError ("Types must implement __eq__" )
158+
159+ def __ne__ (self , other ):
160+ return not (self == other )
161+
162+ def _unflatten_ir (self , handles : List [ir .value ], cursor : int ) -> Tuple [base_value , int ]:
163+ """Build a frontend value with the current dtype, wrapping a list of existing handles.
164+ cursor is the index of the first handle relevant to this value, and the function
165+ should return the updated cursor position after any handles consumed by the created value.
166+ """
167+ raise NotImplementedError
168+
169+ def mangle (self ) -> str :
170+ raise NotImplementedError (f"NYI: Type mangling for type { self .__class__ } " )
171+
172+ def _flatten_ir_types (self , builder : ir .builder , out : List [ir .type ]) -> None :
173+ raise NotImplementedError
174+
175+
176+ class constexpr_type (base_type ):
177+
178+ def __init__ (self , value ):
179+ self .value = value
180+
181+ def __repr__ (self ) -> str :
182+ return f"constexpr[{ self .value } ]"
183+
184+ def mangle (self ) -> str :
185+ return repr (self )
186+
187+ def _flatten_ir_types (self , builder : ir .builder , out : List [ir .type ]) -> None :
188+ return
189+
190+ def _unflatten_ir (self , handles : List [ir .value ], cursor : int ) -> Tuple [base_value , int ]:
191+ return constexpr (self .value ), cursor
192+
193+
194+ class constexpr (base_value ):
143195 """
144196 This class is used to store a value that is known at compile-time.
145197 """
@@ -149,11 +201,14 @@ def __init__(self, value):
149201 self .value = value .value
150202 else :
151203 self .value = value
152- self .type = constexpr
204+ self .type = constexpr_type ( value )
153205
154206 def __repr__ (self ) -> str :
155207 return f"constexpr[{ self .value } ]"
156208
209+ def _flatten_ir (self , handles : List [ir .value ]) -> None :
210+ return
211+
157212 def __index__ (self ):
158213 return self .value
159214
@@ -322,40 +377,6 @@ def check_bit_width(value, shift_value):
322377 )
323378
324379
325- class base_value :
326- """Base class of values that exist in the triton IR (i.e. not constexprs).
327- """
328- type : base_type
329-
330- def _flatten_ir (self , handles : List [ir .value ]) -> None :
331- """Flatten frontend value into a sequence of mlir handles, which are appended
332- to the output list
333- """
334- raise NotImplementedError
335-
336-
337- class base_type :
338-
339- def __eq__ (self , other ):
340- raise NotImplementedError ("Types must implement __eq__" )
341-
342- def __ne__ (self , other ):
343- return not (self == other )
344-
345- def _unflatten_ir (self , handles : List [ir .value ], cursor : int ) -> Tuple [base_value , int ]:
346- """Build a frontend value with the current dtype, wrapping a list of existing handles.
347- cursor is the index of the first handle relevant to this value, and the function
348- should return the updated cursor position after any handles consumed by the created value.
349- """
350- raise NotImplementedError
351-
352- def mangle (self ) -> str :
353- raise NotImplementedError (f"NYI: Type mangling for type { self .__class__ } " )
354-
355- def _flatten_ir_types (self , builder : ir .builder , out : List [ir .type ]) -> None :
356- raise NotImplementedError
357-
358-
359380# -----------------------
360381# dtype
361382# -----------------------
0 commit comments