@@ -139,7 +139,59 @@ class const:
139
139
pass
140
140
141
141
142
- class constexpr :
142
+ class base_value :
143
+ """Base class of values that exist in the triton IR (i.e. not constexprs).
144
+ """
145
+ type : base_type
146
+
147
+ def _flatten_ir (self , handles : List [ir .value ]) -> None :
148
+ """Flatten frontend value into a sequence of mlir handles, which are appended
149
+ to the output list
150
+ """
151
+ raise NotImplementedError
152
+
153
+
154
+ class base_type :
155
+
156
+ def __eq__ (self , other ):
157
+ raise NotImplementedError ("Types must implement __eq__" )
158
+
159
+ def __ne__ (self , other ):
160
+ return not (self == other )
161
+
162
+ def _unflatten_ir (self , handles : List [ir .value ], cursor : int ) -> Tuple [base_value , int ]:
163
+ """Build a frontend value with the current dtype, wrapping a list of existing handles.
164
+ cursor is the index of the first handle relevant to this value, and the function
165
+ should return the updated cursor position after any handles consumed by the created value.
166
+ """
167
+ raise NotImplementedError
168
+
169
+ def mangle (self ) -> str :
170
+ raise NotImplementedError (f"NYI: Type mangling for type { self .__class__ } " )
171
+
172
+ def _flatten_ir_types (self , builder : ir .builder , out : List [ir .type ]) -> None :
173
+ raise NotImplementedError
174
+
175
+
176
+ class constexpr_type (base_type ):
177
+
178
+ def __init__ (self , value ):
179
+ self .value = value
180
+
181
+ def __repr__ (self ) -> str :
182
+ return f"constexpr[{ self .value } ]"
183
+
184
+ def mangle (self ) -> str :
185
+ return repr (self )
186
+
187
+ def _flatten_ir_types (self , builder : ir .builder , out : List [ir .type ]) -> None :
188
+ return
189
+
190
+ def _unflatten_ir (self , handles : List [ir .value ], cursor : int ) -> Tuple [base_value , int ]:
191
+ return constexpr (self .value ), cursor
192
+
193
+
194
+ class constexpr (base_value ):
143
195
"""
144
196
This class is used to store a value that is known at compile-time.
145
197
"""
@@ -149,11 +201,14 @@ def __init__(self, value):
149
201
self .value = value .value
150
202
else :
151
203
self .value = value
152
- self .type = constexpr
204
+ self .type = constexpr_type ( value )
153
205
154
206
def __repr__ (self ) -> str :
155
207
return f"constexpr[{ self .value } ]"
156
208
209
+ def _flatten_ir (self , handles : List [ir .value ]) -> None :
210
+ return
211
+
157
212
def __index__ (self ):
158
213
return self .value
159
214
@@ -322,40 +377,6 @@ def check_bit_width(value, shift_value):
322
377
)
323
378
324
379
325
- class base_value :
326
- """Base class of values that exist in the triton IR (i.e. not constexprs).
327
- """
328
- type : base_type
329
-
330
- def _flatten_ir (self , handles : List [ir .value ]) -> None :
331
- """Flatten frontend value into a sequence of mlir handles, which are appended
332
- to the output list
333
- """
334
- raise NotImplementedError
335
-
336
-
337
- class base_type :
338
-
339
- def __eq__ (self , other ):
340
- raise NotImplementedError ("Types must implement __eq__" )
341
-
342
- def __ne__ (self , other ):
343
- return not (self == other )
344
-
345
- def _unflatten_ir (self , handles : List [ir .value ], cursor : int ) -> Tuple [base_value , int ]:
346
- """Build a frontend value with the current dtype, wrapping a list of existing handles.
347
- cursor is the index of the first handle relevant to this value, and the function
348
- should return the updated cursor position after any handles consumed by the created value.
349
- """
350
- raise NotImplementedError
351
-
352
- def mangle (self ) -> str :
353
- raise NotImplementedError (f"NYI: Type mangling for type { self .__class__ } " )
354
-
355
- def _flatten_ir_types (self , builder : ir .builder , out : List [ir .type ]) -> None :
356
- raise NotImplementedError
357
-
358
-
359
380
# -----------------------
360
381
# dtype
361
382
# -----------------------
0 commit comments