Skip to content

Commit d38d73d

Browse files
committed
Move handle_comparator to type_normalization
1 parent 0a65717 commit d38d73d

File tree

2 files changed

+35
-35
lines changed

2 files changed

+35
-35
lines changed

pythonbpf/expr/expr_pass.py

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import Dict
66

77
from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes
8-
from .type_normalization import normalize_types, convert_to_bool
8+
from .type_normalization import convert_to_bool, handle_comparator
99

1010
logger: Logger = logging.getLogger(__name__)
1111

@@ -130,37 +130,6 @@ def _handle_ctypes_call(
130130
return val
131131

132132

133-
def _handle_comparator(func, builder, op, lhs, rhs):
134-
"""Handle comparison operations."""
135-
136-
# NOTE: For now assume same types
137-
if lhs.type != rhs.type:
138-
lhs, rhs = normalize_types(func, builder, lhs, rhs)
139-
140-
if lhs is None or rhs is None:
141-
return None
142-
143-
comparison_ops = {
144-
ast.Eq: "==",
145-
ast.NotEq: "!=",
146-
ast.Lt: "<",
147-
ast.LtE: "<=",
148-
ast.Gt: ">",
149-
ast.GtE: ">=",
150-
ast.Is: "==",
151-
ast.IsNot: "!=",
152-
}
153-
154-
if type(op) not in comparison_ops:
155-
logger.error(f"Unsupported comparison operator: {type(op)}")
156-
return None
157-
158-
predicate = comparison_ops[type(op)]
159-
result = builder.icmp_signed(predicate, lhs, rhs)
160-
logger.debug(f"Comparison result: {result}")
161-
return result, ir.IntType(1)
162-
163-
164133
def _handle_compare(
165134
func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab=None
166135
):
@@ -194,7 +163,7 @@ def _handle_compare(
194163

195164
lhs, _ = lhs
196165
rhs, _ = rhs
197-
return _handle_comparator(func, builder, cond.ops[0], lhs, rhs)
166+
return handle_comparator(func, builder, cond.ops[0], lhs, rhs)
198167

199168

200169
def _handle_unary_op(

pythonbpf/expr/type_normalization.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,20 @@
11
from llvmlite import ir
22
import logging
3+
import ast
34

45
logger = logging.getLogger(__name__)
56

7+
COMPARISON_OPS = {
8+
ast.Eq: "==",
9+
ast.NotEq: "!=",
10+
ast.Lt: "<",
11+
ast.LtE: "<=",
12+
ast.Gt: ">",
13+
ast.GtE: ">=",
14+
ast.Is: "==",
15+
ast.IsNot: "!=",
16+
}
17+
618

719
def _get_base_type_and_depth(ir_type):
820
"""Get the base type for pointer types."""
@@ -60,7 +72,7 @@ def _deref_to_depth(func, builder, val, target_depth):
6072
return cur_val
6173

6274

63-
def normalize_types(func, builder, lhs, rhs):
75+
def _normalize_types(func, builder, lhs, rhs):
6476
"""Normalize types for comparison."""
6577

6678
logger.info(f"Normalizing types: {lhs.type} vs {rhs.type}")
@@ -83,7 +95,7 @@ def normalize_types(func, builder, lhs, rhs):
8395
rhs = _deref_to_depth(func, builder, rhs, rhs_depth - lhs_depth)
8496
elif rhs_depth < lhs_depth:
8597
lhs = _deref_to_depth(func, builder, lhs, lhs_depth - rhs_depth)
86-
return normalize_types(func, builder, lhs, rhs)
98+
return _normalize_types(func, builder, lhs, rhs)
8799

88100

89101
def convert_to_bool(builder, val):
@@ -95,3 +107,22 @@ def convert_to_bool(builder, val):
95107
else:
96108
zero = ir.Constant(val.type, 0)
97109
return builder.icmp_signed("!=", val, zero)
110+
111+
112+
def handle_comparator(func, builder, op, lhs, rhs):
113+
"""Handle comparison operations."""
114+
115+
if lhs.type != rhs.type:
116+
lhs, rhs = _normalize_types(func, builder, lhs, rhs)
117+
118+
if lhs is None or rhs is None:
119+
return None
120+
121+
if type(op) not in COMPARISON_OPS:
122+
logger.error(f"Unsupported comparison operator: {type(op)}")
123+
return None
124+
125+
predicate = COMPARISON_OPS[type(op)]
126+
result = builder.icmp_signed(predicate, lhs, rhs)
127+
logger.debug(f"Comparison result: {result}")
128+
return result, ir.IntType(1)

0 commit comments

Comments
 (0)