3232 F64Type ,
3333)
3434
35- # TODO: Have a way upstream to check if a floating point type .
36- FLOAT_TYPES_ASM = {
37- "bf16" ,
38- "f16" ,
39- "f32" ,
40- "f64" ,
35+ # TODO: Use FloatType from upstream when available .
36+ FLOAT_BITWIDTHS = {
37+ "bf16" : 16 ,
38+ "f16" : 16 ,
39+ "f32" : 32 ,
40+ "f64" : 64 ,
4141 # TODO: FP8 types.
4242}
4343
@@ -87,28 +87,54 @@ def __init__(
8787
8888class _ScalarBuilder :
8989 def is_floating_point_type (self , t : IrType ) -> bool :
90- return str (t ) in FLOAT_TYPES_ASM
90+ # TODO: Use FloatType from upstream when available.
91+ return str (t ) in FLOAT_BITWIDTHS
9192
9293 def is_integer_type (self , t : IrType ) -> bool :
9394 return IntegerType .isinstance (t )
9495
9596 def is_index_type (self , t : IrType ) -> bool :
9697 return IndexType .isinstance (t )
9798
98- def promote (self , value : Value , to_type : IrType ) -> Value :
99- value_type = value .type
99+ def get_typeclass (self , t : IrType , index_same_as_integer = False ) -> str :
100+ # If this is a vector type, get the element type.
101+ if isinstance (t , VectorType ):
102+ t = t .element_type
103+ if self .is_floating_point_type (t ):
104+ return "float"
105+ if self .is_integer_type (t ):
106+ return "integer"
107+ if self .is_index_type (t ):
108+ return "integer" if index_same_as_integer else "index"
109+ raise CodegenError (f"Unknown typeclass for type `{ t } `" )
110+
111+ def get_float_bitwidth (self , t : IrType ) -> int :
112+ # If this is a vector type, get the element type.
113+ if isinstance (t , VectorType ):
114+ t = t .element_type
115+ return FLOAT_BITWIDTHS [str (t )]
116+
117+ def to_dtype (self , value : IRProxyValue , dtype : IrType ) -> IRProxyValue :
118+ value_type = value .ir_value .type
119+ # Create a vector type for dtype if value is a vector.
120+ to_type = dtype
121+ if isinstance (value_type , VectorType ):
122+ to_type = VectorType .get (value_type .shape , dtype )
123+
100124 # Short-circuit if already the right type.
101125 if value_type == to_type :
102126 return value
103127
104- attr_name = f"promote_{ value_type } _to_{ to_type } "
128+ value_typeclass = self .get_typeclass (value_type )
129+ to_typeclass = self .get_typeclass (dtype )
130+ attr_name = f"to_dtype_{ value_typeclass } _to_{ to_typeclass } "
105131 try :
106132 handler = getattr (self , attr_name )
107133 except AttributeError :
108134 raise CodegenError (
109135 f"No implemented path to implicitly promote scalar `{ value_type } ` to `{ to_type } ` (tried '{ attr_name } ')"
110136 )
111- return handler (value , to_type )
137+ return IRProxyValue ( handler (value . ir_value , to_type ) )
112138
113139 def constant_attr (self , val : int | float , element_type : IrType ) -> Attribute :
114140 if self .is_integer_type (element_type ) or self .is_index_type (element_type ):
@@ -153,7 +179,7 @@ def binary_arithmetic(
153179 f"Cannot perform binary arithmetic operation '{ op } ' between { lhs_ir_type } and { rhs_ir_type } due to element type mismatch"
154180 )
155181
156- typeclass = "float" if self .is_floating_point_type (lhs_ir_type ) else "integer"
182+ typeclass = self .get_typeclass (lhs_ir_type , True )
157183 attr_name = f"binary_{ op } _{ typeclass } "
158184 try :
159185 handler = getattr (self , attr_name )
@@ -176,9 +202,7 @@ def binary_vector_arithmetic(
176202 f"Cannot perform binary arithmetic operation '{ op } ' between { lhs_ir .type } and { rhs_ir .type } due to element type mismatch"
177203 )
178204
179- typeclass = (
180- "float" if self .is_floating_point_type (lhs_element_type ) else "integer"
181- )
205+ typeclass = self .get_typeclass (lhs_element_type , True )
182206 attr_name = f"binary_{ op } _{ typeclass } "
183207 try :
184208 handler = getattr (self , attr_name )
@@ -190,7 +214,7 @@ def binary_vector_arithmetic(
190214
191215 def unary_arithmetic (self , op : str , val : IRProxyValue ) -> IRProxyValue :
192216 val_ir_type = val .ir_value .type
193- typeclass = "float" if self .is_floating_point_type (val_ir_type ) else "integer"
217+ typeclass = self .get_typeclass (val_ir_type , True )
194218 attr_name = f"unary_{ op } _{ typeclass } "
195219 try :
196220 handler = getattr (self , attr_name )
@@ -203,9 +227,7 @@ def unary_arithmetic(self, op: str, val: IRProxyValue) -> IRProxyValue:
203227 def unary_vector_arithmetic (self , op : str , val : IRProxyValue ) -> IRProxyValue :
204228 val_ir = val .ir_value
205229 val_element_type = VectorType (val_ir .type ).element_type
206- typeclass = (
207- "float" if self .is_floating_point_type (val_element_type ) else "integer"
208- )
230+ typeclass = self .get_typeclass (val_element_type , True )
209231 attr_name = f"unary_{ op } _{ typeclass } "
210232 try :
211233 handler = getattr (self , attr_name )
@@ -217,10 +239,33 @@ def unary_vector_arithmetic(self, op: str, val: IRProxyValue) -> IRProxyValue:
217239
218240 ### Specializations
219241
220- def promote_index_to_f32 (self , value : Value , to_type : IrType ) -> Value :
221- i32_type = IntegerType .get_signless (32 )
222- i32 = arith_d .index_cast (i32_type , value )
223- return arith_d .sitofp (to_type , i32 )
242+ # Casting
243+ def to_dtype_index_to_integer (self , value : Value , to_type : IrType ) -> Value :
244+ return arith_d .index_cast (to_type , value )
245+
246+ def to_dtype_index_to_float (self , value : Value , to_type : IrType ) -> Value :
247+ # Cast index to integer, and then ask for a integer to float cast.
248+ # TODO: I don't really know how to query the machine bitwidth here,
249+ # so using 64.
250+ casted_to_int = arith_d .index_cast (IntegerType .get_signless (64 ), value )
251+ return self .to_dtype (IRProxyValue (casted_to_int ), to_type ).ir_value
252+
253+ def to_dtype_integer_to_float (self , value : Value , to_type : IrType ) -> Value :
254+ # sitofp
255+ casted_to_float = arith_d .sitofp (to_type , value )
256+ return self .to_dtype (IRProxyValue (casted_to_float ), to_type ).ir_value
257+
258+ def to_dtype_float_to_float (self , value : Value , to_type : IrType ) -> Value :
259+ # Check bitwidth to determine if we need to extend or narrow
260+ from_type = value .type
261+ from_bitwidth = self .get_float_bitwidth (from_type )
262+ to_bitwidth = self .get_float_bitwidth (to_type )
263+ if from_bitwidth < to_bitwidth :
264+ return arith_d .extf (to_type , value )
265+ elif from_bitwidth > to_bitwidth :
266+ return arith_d .truncf (to_type , value )
267+ else :
268+ raise CodegenError (f"NYI: Cast from { from_type } to { to_type } " )
224269
225270 # Binary integer/integer arithmetic.
226271 def binary_add_integer (self , lhs : IRProxyValue , rhs : IRProxyValue ) -> IRProxyValue :
0 commit comments