11from llvmlite import ir
22import logging
3+ import ast
34
45logger = logging .getLogger (__name__ )
56
7+ COMPARISON_OPS = {
8+ ast .Eq : "==" ,
9+ ast .NotEq : "!=" ,
10+ ast .Lt : "<" ,
11+ ast .LtE : "<=" ,
12+ ast .Gt : ">" ,
13+ ast .GtE : ">=" ,
14+ ast .Is : "==" ,
15+ ast .IsNot : "!=" ,
16+ }
17+
618
719def _get_base_type_and_depth (ir_type ):
820 """Get the base type for pointer types."""
@@ -60,7 +72,7 @@ def _deref_to_depth(func, builder, val, target_depth):
6072 return cur_val
6173
6274
63- def normalize_types (func , builder , lhs , rhs ):
75+ def _normalize_types (func , builder , lhs , rhs ):
6476 """Normalize types for comparison."""
6577
6678 logger .info (f"Normalizing types: { lhs .type } vs { rhs .type } " )
@@ -83,7 +95,7 @@ def normalize_types(func, builder, lhs, rhs):
8395 rhs = _deref_to_depth (func , builder , rhs , rhs_depth - lhs_depth )
8496 elif rhs_depth < lhs_depth :
8597 lhs = _deref_to_depth (func , builder , lhs , lhs_depth - rhs_depth )
86- return normalize_types (func , builder , lhs , rhs )
98+ return _normalize_types (func , builder , lhs , rhs )
8799
88100
89101def convert_to_bool (builder , val ):
@@ -95,3 +107,22 @@ def convert_to_bool(builder, val):
95107 else :
96108 zero = ir .Constant (val .type , 0 )
97109 return builder .icmp_signed ("!=" , val , zero )
110+
111+
112+ def handle_comparator (func , builder , op , lhs , rhs ):
113+ """Handle comparison operations."""
114+
115+ if lhs .type != rhs .type :
116+ lhs , rhs = _normalize_types (func , builder , lhs , rhs )
117+
118+ if lhs is None or rhs is None :
119+ return None
120+
121+ if type (op ) not in COMPARISON_OPS :
122+ logger .error (f"Unsupported comparison operator: { type (op )} " )
123+ return None
124+
125+ predicate = COMPARISON_OPS [type (op )]
126+ result = builder .icmp_signed (predicate , lhs , rhs )
127+ logger .debug (f"Comparison result: { result } " )
128+ return result , ir .IntType (1 )
0 commit comments