1+ import operator
12from copy import deepcopy
23from functools import partialmethod , cached_property
34from typing import Union , Optional
@@ -116,7 +117,7 @@ class ArithValueMeta(type(Value)):
116117 """
117118
118119 def __call__ (cls , * args , ** kwargs ):
119- """Orchestrate the Python object protocol for Indexing dialect extension
120+ """Orchestrate the Python object protocol for mlir
120121 values in order to handle wrapper arbitrary Python objects.
121122
122123 Args:
@@ -132,7 +133,7 @@ def __call__(cls, *args, **kwargs):
132133 if len (args ) != 1 :
133134 raise ValueError ("Only one non-kw arg supported." )
134135 arg = args [0 ]
135- arg_copy = None
136+ fold = None
136137 if isinstance (arg , (OpView , Operation , Value )):
137138 # wrap an already created Value (or op the produces a Value)
138139 if isinstance (arg , (Operation , OpView )):
@@ -143,13 +144,15 @@ def __call__(cls, *args, **kwargs):
143144 dtype = kwargs .get ("dtype" )
144145 if dtype is not None and not isinstance (dtype , Type ):
145146 raise ValueError (f"{ dtype = } is expected to be an ir.Type." )
147+ fold = kwargs .get ("fold" )
148+ if fold is not None and not isinstance (fold , bool ):
149+ raise ValueError (f"{ fold = } is expected to be a bool." )
146150 # If we're wrapping a numpy array (effectively a tensor literal),
147151 # then we want to make sure no one else has access to that memory.
148152 # Otherwise, the array will get funneled down to DenseElementsAttr.get,
149153 # which by default (through the Python buffer protocol) does not copy;
150154 # see mlir/lib/Bindings/Python/IRAttributes.cpp#L556
151- arg_copy = deepcopy (arg )
152- return constant (arg_copy , dtype )
155+ val = constant (deepcopy (arg ), dtype )
153156 else :
154157 raise NotImplementedError (f"{ cls .__name__ } doesn't support wrapping { arg } ." )
155158
@@ -161,7 +164,7 @@ def __call__(cls, *args, **kwargs):
161164 # the Python object protocol; first an object is new'ed and then
162165 # it is init'ed. Note we pass arg_copy here in case a subclass wants to
163166 # inspect the literal.
164- cls .__init__ (cls_obj , val )
167+ cls .__init__ (cls_obj , val , fold = fold )
165168 return cls_obj
166169
167170
@@ -252,14 +255,28 @@ def _binary_op(
252255 if lhs .type != rhs .type :
253256 raise ValueError (f"{ lhs = } { rhs = } must have the same type." )
254257
255- op = op .capitalize ()
256- lhs , rhs = lhs , rhs
257- if _is_floating_point_type (lhs .dtype ):
258- op = getattr (arith_dialect , f"{ op } FOp" )
259- elif _is_integer_like_type (lhs .dtype ):
260- op = getattr (arith_dialect , f"{ op } IOp" )
258+ if lhs .fold () and lhs .fold ():
259+ klass = lhs .__class__
260+ # if both operands are constants (results of an arith.constant op)
261+ # then both have a literal value (i.e. Python value).
262+ lhs , rhs = lhs .literal_value , rhs .literal_value
263+ # if we're folding constants (self._fold = True) then we just carry out
264+ # the corresponding operation on the literal values; e.g., operator.add.
265+ # note this is the same as op = operator.__dict__[op].
266+ if predicate is not None :
267+ op = predicate
268+ op = operator .attrgetter (op )(operator )
269+ return klass (op (lhs , rhs ), fold = True )
261270 else :
262- raise NotImplementedError (f"Unsupported '{ op } ' operands: { lhs } , { rhs } " )
271+ op = op .capitalize ()
272+ lhs , rhs = lhs , rhs
273+ if _is_floating_point_type (lhs .dtype ):
274+ op = getattr (arith_dialect , f"{ op } FOp" )
275+ elif _is_integer_like_type (lhs .dtype ):
276+ op = getattr (arith_dialect , f"{ op } IOp" )
277+ else :
278+ raise NotImplementedError (f"Unsupported '{ op } ' operands: { lhs } , { rhs } " )
279+
263280 if predicate is not None :
264281 if _is_floating_point_type (lhs .dtype ):
265282 # ordered comparison - see above
@@ -289,9 +306,18 @@ class ArithValue(Value, metaclass=ArithValueMeta):
289306 Value.__init__
290307 """
291308
292- def __init__ (self , val ):
309+ def __init__ (self , val , * , fold : Optional [bool ] = None ):
310+ self ._fold = fold if fold is not None else False
293311 super ().__init__ (val )
294312
313+ def is_constant (self ) -> bool :
314+ return isinstance (self .owner , Operation ) and isinstance (
315+ self .owner .opview , arith_dialect .ConstantOp
316+ )
317+
318+ def fold (self ) -> bool :
319+ return self .is_constant () and self ._fold
320+
295321 def __str__ (self ):
296322 return f"{ self .__class__ .__name__ } ({ self .get_name ()} , { self .type } )"
297323
@@ -306,8 +332,29 @@ def __repr__(self):
306332 __radd__ = partialmethod (_binary_op , op = "add" )
307333 __rsub__ = partialmethod (_binary_op , op = "sub" )
308334 __rmul__ = partialmethod (_binary_op , op = "mul" )
309- __eq__ = partialmethod (_binary_op , op = "cmp" , predicate = "eq" )
310- __ne__ = partialmethod (_binary_op , op = "cmp" , predicate = "ne" )
335+
336+ def __eq__ (self , other ):
337+ if not isinstance (other , self .__class__ ):
338+ try :
339+ other = self .__class__ (other , dtype = self .type )
340+ except NotImplementedError as e :
341+ assert "doesn't support wrapping" in str (e )
342+ return False
343+ if self is other :
344+ return True
345+ return _binary_op (self , other , op = "cmp" , predicate = "eq" )
346+
347+ def __ne__ (self , other ):
348+ if not isinstance (other , self .__class__ ):
349+ try :
350+ other = self .__class__ (other , dtype = self .type )
351+ except NotImplementedError as e :
352+ assert "doesn't support wrapping" in str (e )
353+ return True
354+ if self is other :
355+ return False
356+ return _binary_op (self , other , op = "cmp" , predicate = "ne" )
357+
311358 __le__ = partialmethod (_binary_op , op = "cmp" , predicate = "le" )
312359 __lt__ = partialmethod (_binary_op , op = "cmp" , predicate = "lt" )
313360 __ge__ = partialmethod (_binary_op , op = "cmp" , predicate = "ge" )
@@ -342,3 +389,15 @@ def isinstance(other: Value):
342389 or _is_index_type (other .type )
343390 or _is_complex_type (other .type )
344391 )
392+
393+ @cached_property
394+ def literal_value (self ) -> Union [int , float , bool ]:
395+ if not self .is_constant ():
396+ raise ValueError ("Can't build literal from non-constant Scalar" )
397+ return self .owner .opview .literal_value
398+
399+ def __int__ (self ):
400+ return int (self .literal_value )
401+
402+ def __float__ (self ):
403+ return float (self .literal_value )
0 commit comments