11import operator
2+ from abc import abstractmethod
23from copy import deepcopy
34from functools import partialmethod , cached_property
45from typing import Union , Optional
@@ -184,7 +185,7 @@ def _arith_CmpIPredicateAttr(predicate: str | Attribute, context: Context):
184185 }
185186 if isinstance (predicate , Attribute ):
186187 return predicate
187- assert predicate in predicates , f"predicate { predicate } not in predicates"
188+ assert predicate in predicates , f"{ predicate = } not in predicates"
188189 return IntegerAttr .get (
189190 IntegerType .get_signless (64 , context = context ), predicates [predicate ]
190191 )
@@ -219,7 +220,7 @@ def _arith_CmpFPredicateAttr(predicate: str | Attribute, context: Context):
219220 }
220221 if isinstance (predicate , Attribute ):
221222 return predicate
222- assert predicate in predicates , f"predicate { predicate } not in predicates"
223+ assert predicate in predicates , f"{ predicate = } not in predicates"
223224 return IntegerAttr .get (
224225 IntegerType .get_signless (64 , context = context ), predicates [predicate ]
225226 )
@@ -247,13 +248,14 @@ def _binary_op(
247248 if loc is None :
248249 loc = get_user_code_loc ()
249250 if not isinstance (rhs , lhs .__class__ ):
250- rhs = lhs .__class__ (rhs , dtype = lhs .type )
251+ lhs , rhs = lhs .coerce (rhs )
252+ if lhs .type != rhs .type :
253+ raise ValueError (f"{ lhs = } { rhs = } must have the same type." )
254+
255+ assert op in {"add" , "sub" , "mul" , "cmp" , "truediv" , "floordiv" , "mod" }
251256
252- assert op in {"add" , "sub" , "mul" , "cmp" }
253257 if op == "cmp" :
254258 assert predicate is not None
255- if lhs .type != rhs .type :
256- raise ValueError (f"{ lhs = } { rhs = } must have the same type." )
257259
258260 if lhs .fold () and lhs .fold ():
259261 klass = lhs .__class__
@@ -267,15 +269,30 @@ def _binary_op(
267269 op = predicate
268270 op = operator .attrgetter (op )(operator )
269271 return klass (op (lhs , rhs ), fold = True )
272+
273+ if op == "truediv" :
274+ op = "div"
275+ if op == "mod" :
276+ op = "rem"
277+
278+ op = op .capitalize ()
279+ if _is_floating_point_type (lhs .dtype ):
280+ if op == "Floordiv" :
281+ raise ValueError (f"floordiv not supported for { lhs = } " )
282+ op += "F"
283+ elif _is_integer_like_type (lhs .dtype ):
284+ # TODO(max): this needs to all be regularized
285+ if "div" in op .lower () or "rem" in op .lower ():
286+ if not lhs .dtype .is_signless :
287+ raise ValueError (f"{ op .lower ()} i not supported for { lhs = } " )
288+ if op == "Floordiv" :
289+ op = "FloorDiv"
290+ op += "S"
291+ op += "I"
270292 else :
271- op = op .capitalize ()
272- lhs , rhs = lhs , rhs
273- if _is_floating_point_type (lhs .dtype ):
274- op = getattr (arith_dialect , f"{ op } FOp" )
275- elif _is_integer_like_type (lhs .dtype ):
276- op = getattr (arith_dialect , f"{ op } IOp" )
277- else :
278- raise NotImplementedError (f"Unsupported '{ op } ' operands: { lhs } , { rhs } " )
293+ raise NotImplementedError (f"Unsupported '{ op } ' operands: { lhs } , { rhs } " )
294+
295+ op = getattr (arith_dialect , f"{ op } Op" )
279296
280297 if predicate is not None :
281298 if _is_floating_point_type (lhs .dtype ):
@@ -315,6 +332,15 @@ def is_constant(self) -> bool:
315332 self .owner .opview , arith_dialect .ConstantOp
316333 )
317334
335+ @property
336+ @abstractmethod
337+ def literal_value (self ):
338+ pass
339+
340+ @abstractmethod
341+ def coerce (self , other ) -> tuple ["ArithValue" , "ArithValue" ]:
342+ pass
343+
318344 def fold (self ) -> bool :
319345 return self .is_constant () and self ._fold
320346
@@ -329,6 +355,10 @@ def __repr__(self):
329355 __add__ = partialmethod (_binary_op , op = "add" )
330356 __sub__ = partialmethod (_binary_op , op = "sub" )
331357 __mul__ = partialmethod (_binary_op , op = "mul" )
358+ __truediv__ = partialmethod (_binary_op , op = "truediv" )
359+ __floordiv__ = partialmethod (_binary_op , op = "floordiv" )
360+ __mod__ = partialmethod (_binary_op , op = "mod" )
361+
332362 __radd__ = partialmethod (_binary_op , op = "add" )
333363 __rsub__ = partialmethod (_binary_op , op = "sub" )
334364 __rmul__ = partialmethod (_binary_op , op = "mul" )
@@ -401,3 +431,10 @@ def __int__(self):
401431
402432 def __float__ (self ):
403433 return float (self .literal_value )
434+
435+ def coerce (self , other ) -> tuple ["Scalar" , "Scalar" ]:
436+ if isinstance (other , (int , float , bool )):
437+ other = Scalar (other , dtype = self .dtype )
438+ else :
439+ raise ValueError (f"can't coerce { other = } to Scalar" )
440+ return self , other
0 commit comments