66
77import numpy as np
88from mlir .dialects import arith as arith_dialect
9+ from mlir .dialects import complex as complex_dialect
910from mlir .dialects ._arith_ops_ext import _is_integer_like_type
1011from mlir .dialects ._ods_common import get_op_result_or_value
1112from mlir .dialects .linalg .opdsl .lang .emitter import (
1920 Context ,
2021 DenseElementsAttr ,
2122 IndexType ,
23+ InsertionPoint ,
2224 IntegerAttr ,
2325 IntegerType ,
2426 Location ,
2830 Type ,
2931 Value ,
3032 register_attribute_builder ,
33+ ComplexType ,
34+ BF16Type ,
35+ F16Type ,
36+ F32Type ,
37+ F64Type ,
38+ FloatAttr ,
3139)
3240
33- from mlir_utils .util import get_result_or_results , maybe_cast , get_user_code_loc
41+ from mlir_utils .util import (
42+ get_result_or_results ,
43+ maybe_cast ,
44+ get_user_code_loc ,
45+ register_value_caster ,
46+ )
3447
3548try :
3649 from mlir_utils .dialects .arith import *
@@ -46,7 +59,8 @@ def constant(
4659 index : Optional [bool ] = None ,
4760 * ,
4861 loc : Location = None ,
49- ) -> arith_dialect .ConstantOp :
62+ ip : InsertionPoint = None ,
63+ ) -> Value :
5064 """Instantiate arith.constant with value `value`.
5165
5266 Args:
@@ -67,21 +81,62 @@ def constant(
6781 type = IndexType .get ()
6882 if type is None :
6983 type = infer_mlir_type (value )
70- elif RankedTensorType .isinstance (type ) and isinstance (value , (int , float , bool )):
84+
85+ assert type is not None
86+
87+ if _is_complex_type (type ):
88+ value = complex (value )
89+ return maybe_cast (
90+ get_result_or_results (
91+ complex_dialect .ConstantOp (
92+ type ,
93+ list (
94+ map (
95+ lambda x : FloatAttr .get (type .element_type , x ),
96+ [value .real , value .imag ],
97+ )
98+ ),
99+ loc = loc ,
100+ ip = ip ,
101+ )
102+ )
103+ )
104+
105+ if _is_floating_point_type (type ) and not isinstance (value , np .ndarray ):
106+ value = float (value )
107+
108+ if RankedTensorType .isinstance (type ) and isinstance (value , (int , float , bool )):
71109 ranked_tensor_type = RankedTensorType (type )
72- value = np .ones (
110+ value = np .full (
73111 ranked_tensor_type .shape ,
112+ value ,
74113 dtype = mlir_type_to_np_dtype (ranked_tensor_type .element_type ),
75114 )
76- assert type is not None
77115
78116 if isinstance (value , np .ndarray ):
79117 value = DenseElementsAttr .get (
80118 value ,
81119 type = type ,
82120 )
121+
83122 return maybe_cast (
84- get_result_or_results (arith_dialect .ConstantOp (type , value , loc = loc ))
123+ get_result_or_results (arith_dialect .ConstantOp (type , value , loc = loc , ip = ip ))
124+ )
125+
126+
127+ def index_cast (
128+ value : Value ,
129+ * ,
130+ to : Type = None ,
131+ loc : Location = None ,
132+ ip : InsertionPoint = None ,
133+ ) -> Value :
134+ if loc is None :
135+ loc = get_user_code_loc ()
136+ if to is None :
137+ to = IndexType .get ()
138+ return maybe_cast (
139+ get_result_or_results (arith_dialect .IndexCastOp (to , value , loc = loc , ip = ip ))
85140 )
86141
87142
@@ -231,6 +286,7 @@ def _binary_op(
231286 rhs : "ArithValue" ,
232287 op : str ,
233288 predicate : str = None ,
289+ signedness : str = None ,
234290 * ,
235291 loc : Location = None ,
236292) -> "ArithValue" :
@@ -247,12 +303,15 @@ def _binary_op(
247303 """
248304 if loc is None :
249305 loc = get_user_code_loc ()
250- if not isinstance (rhs , lhs .__class__ ):
306+ if (
307+ isinstance (rhs , Value )
308+ and lhs .type != rhs .type
309+ or isinstance (rhs , (float , int , bool , np .ndarray ))
310+ ):
251311 lhs , rhs = lhs .coerce (rhs )
252- if lhs .type != rhs .type :
253- raise ValueError (f"{ lhs = } { rhs = } must have the same type." )
312+ assert lhs .type == rhs .type , f"{ lhs = } { rhs = } must have the same type."
254313
255- assert op in {"add" , "sub" , "mul" , "cmp" , "truediv" , "floordiv" , "mod" }
314+ assert op in {"add" , "and" , "or" , " sub" , "mul" , "cmp" , "truediv" , "floordiv" , "mod" }
256315
257316 if op == "cmp" :
258317 assert predicate is not None
@@ -301,15 +360,20 @@ def _binary_op(
301360 elif _is_integer_like_type (lhs .dtype ):
302361 # eq, ne signs don't matter
303362 if predicate not in {"eq" , "ne" }:
304- if lhs . dtype . is_signed :
305- predicate = "s" + predicate
363+ if signedness is not None :
364+ predicate = signedness + predicate
306365 else :
307- predicate = "u" + predicate
366+ if lhs .dtype .is_signed :
367+ predicate = "s" + predicate
368+ else :
369+ predicate = "u" + predicate
308370 return lhs .__class__ (op (predicate , lhs , rhs , loc = loc ), dtype = lhs .dtype )
309371 else :
310372 return lhs .__class__ (op (lhs , rhs , loc = loc ), dtype = lhs .dtype )
311373
312374
375+ # TODO(max): these could be generic in the dtype
376+ # TODO(max): hit .verify() before constructing (maybe)
313377class ArithValue (Value , metaclass = ArithValueMeta ):
314378 """Class for functionality shared by Value subclasses that support
315379 arithmetic operations.
@@ -363,6 +427,9 @@ def __repr__(self):
363427 __rsub__ = partialmethod (_binary_op , op = "sub" )
364428 __rmul__ = partialmethod (_binary_op , op = "mul" )
365429
430+ __and__ = partialmethod (_binary_op , op = "and" )
431+ __or__ = partialmethod (_binary_op , op = "or" )
432+
366433 def __eq__ (self , other ):
367434 if not isinstance (other , self .__class__ ):
368435 try :
@@ -435,6 +502,14 @@ def __float__(self):
435502 def coerce (self , other ) -> tuple ["Scalar" , "Scalar" ]:
436503 if isinstance (other , (int , float , bool )):
437504 other = Scalar (other , dtype = self .dtype )
505+ elif isinstance (other , Scalar ) and _is_index_type (self .type ):
506+ other = index_cast (other )
507+ elif isinstance (other , Scalar ) and _is_index_type (other .type ):
508+ other = index_cast (other , to = self .type )
438509 else :
439- raise ValueError (f"can't coerce { other = } to Scalar " )
510+ raise ValueError (f"can't coerce { other = } to { self = } " )
440511 return self , other
512+
513+
514+ for t in [BF16Type , F16Type , F32Type , F64Type , IndexType , IntegerType , ComplexType ]:
515+ register_value_caster (t .static_typeid )(Scalar )
0 commit comments