4747 _is_float_type ,
4848 _is_index_type ,
4949 _is_integer_like_type ,
50+ _is_signed_or_signless_type ,
5051)
5152
5253# TK infrastructure imports.
5859 broadcast ,
5960 cast ,
6061 conditional ,
62+ eq ,
6163 exp2 ,
6264 extract ,
6365 extract_slice ,
@@ -444,8 +446,8 @@ def handle_div(lhs: Value, rhs: Value) -> OpResult:
444446 element_type = get_type_or_element_type (lhs .type )
445447 if _is_float_type (element_type ):
446448 result = arith_d .divf (lhs , rhs )
447- elif _is_integer_like_type (element_type ) and (
448- element_type . is_signed or element_type . is_signless
449+ elif _is_integer_like_type (element_type ) and _is_signed_or_signless_type (
450+ element_type
449451 ):
450452 result = arith_d .divsi (lhs , rhs )
451453 else :
@@ -456,12 +458,28 @@ def handle_div(lhs: Value, rhs: Value) -> OpResult:
456458@handle_binary_op (operator .and_ )
457459def handle_and (lhs : Value , rhs : Value ) -> OpResult :
458460 element_type = get_type_or_element_type (lhs .type )
459- if _is_integer_like_type (element_type ) and (
460- element_type . is_signed or element_type . is_signless
461+ if _is_integer_like_type (element_type ) and _is_signed_or_signless_type (
462+ element_type
461463 ):
462464 result = arith_d .andi (lhs , rhs )
463465 else :
464- raise ValidationError (f"Found unhandled operand type for div: { element_type } " )
466+ raise ValidationError (
467+ f"Found unhandled operand type for bitwise and: { element_type } "
468+ )
469+ return result
470+
471+
472+ @handle_binary_op (operator .or_ )
473+ def handle_or (lhs : Value , rhs : Value ) -> OpResult :
474+ element_type = get_type_or_element_type (lhs .type )
475+ if _is_integer_like_type (element_type ) and _is_signed_or_signless_type (
476+ element_type
477+ ):
478+ result = arith_d .ori (lhs , rhs )
479+ else :
480+ raise ValidationError (
481+ f"Found unhandled operand type for bitwise or: { element_type } "
482+ )
465483 return result
466484
467485
@@ -470,8 +488,8 @@ def handle_gt(lhs: Value, rhs: Value) -> OpResult:
470488 element_type = get_type_or_element_type (lhs .type )
471489 if _is_float_type (element_type ):
472490 result = arith_d .cmpi (arith_d .CmpFPredicate .OGT , lhs , rhs )
473- elif _is_integer_like_type (element_type ) and (
474- element_type . is_signed or element_type . is_signless
491+ elif _is_integer_like_type (element_type ) and _is_signed_or_signless_type (
492+ element_type
475493 ):
476494 result = arith_d .cmpi (arith_d .CmpIPredicate .sgt , lhs , rhs )
477495 else :
@@ -484,8 +502,8 @@ def handle_ge(lhs: Value, rhs: Value) -> OpResult:
484502 element_type = get_type_or_element_type (lhs .type )
485503 if _is_float_type (element_type ):
486504 result = arith_d .cmpi (arith_d .CmpFPredicate .OGE , lhs , rhs )
487- elif _is_integer_like_type (element_type ) and (
488- element_type . is_signed or element_type . is_signless
505+ elif _is_integer_like_type (element_type ) and _is_signed_or_signless_type (
506+ element_type
489507 ):
490508 result = arith_d .cmpi (arith_d .CmpIPredicate .sge , lhs , rhs )
491509 else :
@@ -498,8 +516,8 @@ def handle_lt(lhs: Value, rhs: Value) -> OpResult:
498516 element_type = get_type_or_element_type (lhs .type )
499517 if _is_float_type (element_type ):
500518 result = arith_d .cmpi (arith_d .CmpFPredicate .OLT , lhs , rhs )
501- elif _is_integer_like_type (element_type ) and (
502- element_type . is_signed or element_type . is_signless
519+ elif _is_integer_like_type (element_type ) and _is_signed_or_signless_type (
520+ element_type
503521 ):
504522 result = arith_d .cmpi (arith_d .CmpIPredicate .slt , lhs , rhs )
505523 else :
@@ -512,22 +530,36 @@ def handle_le(lhs: Value, rhs: Value) -> OpResult:
512530 element_type = get_type_or_element_type (lhs .type )
513531 if _is_float_type (element_type ):
514532 result = arith_d .cmpi (arith_d .CmpFPredicate .OLE , lhs , rhs )
515- elif _is_integer_like_type (element_type ) and (
516- element_type . is_signed or element_type . is_signless
533+ elif _is_integer_like_type (element_type ) and _is_signed_or_signless_type (
534+ element_type
517535 ):
518536 result = arith_d .cmpi (arith_d .CmpIPredicate .sle , lhs , rhs )
519537 else :
520538 raise ValidationError (f"Found unhandled operand type for le: { element_type } " )
521539 return result
522540
523541
542+ @handle_binary_op ([operator .eq , eq ])
543+ def handle_le (lhs : Value , rhs : Value ) -> OpResult :
544+ element_type = get_type_or_element_type (lhs .type )
545+ if _is_float_type (element_type ):
546+ result = arith_d .cmpf (arith_d .CmpFPredicate .OEQ , lhs , rhs )
547+ elif _is_integer_like_type (element_type ) and _is_signed_or_signless_type (
548+ element_type
549+ ):
550+ result = arith_d .cmpi (arith_d .CmpIPredicate .eq , lhs , rhs )
551+ else :
552+ raise ValidationError (f"Found unhandled operand type for eq: { element_type } " )
553+ return result
554+
555+
524556@handle_binary_op (maximum )
525557def handle_maximum (lhs : Value , rhs : Value ) -> OpResult :
526558 element_type = get_type_or_element_type (lhs .type )
527559 if _is_float_type (element_type ):
528560 result = arith_d .maximumf (lhs , rhs )
529- elif _is_integer_like_type (element_type ) and (
530- element_type . is_signed or element_type . is_signless
561+ elif _is_integer_like_type (element_type ) and _is_signed_or_signless_type (
562+ element_type
531563 ):
532564 result = arith_d .maxsi (lhs , rhs )
533565 else :
@@ -542,8 +574,8 @@ def handle_minimum(lhs: Value, rhs: Value) -> OpResult:
542574 element_type = get_type_or_element_type (lhs .type )
543575 if _is_float_type (element_type ):
544576 result = arith_d .minimumf (lhs , rhs )
545- elif _is_integer_like_type (element_type ) and (
546- element_type . is_signed or element_type . is_signless
577+ elif _is_integer_like_type (element_type ) and _is_signed_or_signless_type (
578+ element_type
547579 ):
548580 result = arith_d .minsi (lhs , rhs )
549581 else :
0 commit comments