Skip to content

Commit e62557b

Browse files
committed
Seperate type_normalization from expr_pass
1 parent ee90ee9 commit e62557b

File tree

2 files changed

+88
-83
lines changed

2 files changed

+88
-83
lines changed

pythonbpf/expr/expr_pass.py

Lines changed: 2 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Dict
66

77
from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes
8+
from .type_normalization import normalize_types
89

910
logger: Logger = logging.getLogger(__name__)
1011

@@ -129,94 +130,12 @@ def _handle_ctypes_call(
129130
return val
130131

131132

132-
def _get_base_type_and_depth(ir_type):
133-
"""Get the base type for pointer types."""
134-
cur_type = ir_type
135-
depth = 0
136-
while isinstance(cur_type, ir.PointerType):
137-
depth += 1
138-
cur_type = cur_type.pointee
139-
return cur_type, depth
140-
141-
142-
def _deref_to_depth(func, builder, val, target_depth):
143-
"""Dereference a pointer to a certain depth."""
144-
145-
cur_val = val
146-
cur_type = val.type
147-
148-
for depth in range(target_depth):
149-
if not isinstance(val.type, ir.PointerType):
150-
logger.error("Cannot dereference further, non-pointer type")
151-
return None
152-
153-
# dereference with null check
154-
pointee_type = cur_type.pointee
155-
null_check_block = builder.block
156-
not_null_block = func.append_basic_block(name=f"deref_not_null_{depth}")
157-
merge_block = func.append_basic_block(name=f"deref_merge_{depth}")
158-
159-
null_ptr = ir.Constant(cur_type, None)
160-
is_not_null = builder.icmp_signed("!=", cur_val, null_ptr)
161-
logger.debug(f"Inserted null check for pointer at depth {depth}")
162-
163-
builder.cbranch(is_not_null, not_null_block, merge_block)
164-
165-
builder.position_at_end(not_null_block)
166-
dereferenced_val = builder.load(cur_val)
167-
logger.debug(f"Dereferenced to depth {depth - 1}, type: {pointee_type}")
168-
builder.branch(merge_block)
169-
170-
builder.position_at_end(merge_block)
171-
phi = builder.phi(pointee_type, name=f"deref_result_{depth}")
172-
173-
zero_value = (
174-
ir.Constant(pointee_type, 0)
175-
if isinstance(pointee_type, ir.IntType)
176-
else ir.Constant(pointee_type, None)
177-
)
178-
phi.add_incoming(zero_value, null_check_block)
179-
180-
phi.add_incoming(dereferenced_val, not_null_block)
181-
182-
# Continue with phi result
183-
cur_val = phi
184-
cur_type = pointee_type
185-
return cur_val
186-
187-
188-
def _normalize_types(func, builder, lhs, rhs):
189-
"""Normalize types for comparison."""
190-
191-
logger.info(f"Normalizing types: {lhs.type} vs {rhs.type}")
192-
if isinstance(lhs.type, ir.IntType) and isinstance(rhs.type, ir.IntType):
193-
if lhs.type.width < rhs.type.width:
194-
lhs = builder.sext(lhs, rhs.type)
195-
else:
196-
rhs = builder.sext(rhs, lhs.type)
197-
return lhs, rhs
198-
elif not isinstance(lhs.type, ir.PointerType) and not isinstance(
199-
rhs.type, ir.PointerType
200-
):
201-
logger.error(f"Type mismatch: {lhs.type} vs {rhs.type}")
202-
return None, None
203-
else:
204-
lhs_base, lhs_depth = _get_base_type_and_depth(lhs.type)
205-
rhs_base, rhs_depth = _get_base_type_and_depth(rhs.type)
206-
if lhs_base == rhs_base:
207-
if lhs_depth < rhs_depth:
208-
rhs = _deref_to_depth(func, builder, rhs, rhs_depth - lhs_depth)
209-
elif rhs_depth < lhs_depth:
210-
lhs = _deref_to_depth(func, builder, lhs, lhs_depth - rhs_depth)
211-
return _normalize_types(func, builder, lhs, rhs)
212-
213-
214133
def _handle_comparator(func, builder, op, lhs, rhs):
215134
"""Handle comparison operations."""
216135

217136
# NOTE: For now assume same types
218137
if lhs.type != rhs.type:
219-
lhs, rhs = _normalize_types(func, builder, lhs, rhs)
138+
lhs, rhs = normalize_types(func, builder, lhs, rhs)
220139

221140
if lhs is None or rhs is None:
222141
return None
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from llvmlite import ir
2+
import logging
3+
4+
logger = logging.getLogger(__name__)
5+
6+
7+
def _get_base_type_and_depth(ir_type):
8+
"""Get the base type for pointer types."""
9+
cur_type = ir_type
10+
depth = 0
11+
while isinstance(cur_type, ir.PointerType):
12+
depth += 1
13+
cur_type = cur_type.pointee
14+
return cur_type, depth
15+
16+
17+
def _deref_to_depth(func, builder, val, target_depth):
18+
"""Dereference a pointer to a certain depth."""
19+
20+
cur_val = val
21+
cur_type = val.type
22+
23+
for depth in range(target_depth):
24+
if not isinstance(val.type, ir.PointerType):
25+
logger.error("Cannot dereference further, non-pointer type")
26+
return None
27+
28+
# dereference with null check
29+
pointee_type = cur_type.pointee
30+
null_check_block = builder.block
31+
not_null_block = func.append_basic_block(name=f"deref_not_null_{depth}")
32+
merge_block = func.append_basic_block(name=f"deref_merge_{depth}")
33+
34+
null_ptr = ir.Constant(cur_type, None)
35+
is_not_null = builder.icmp_signed("!=", cur_val, null_ptr)
36+
logger.debug(f"Inserted null check for pointer at depth {depth}")
37+
38+
builder.cbranch(is_not_null, not_null_block, merge_block)
39+
40+
builder.position_at_end(not_null_block)
41+
dereferenced_val = builder.load(cur_val)
42+
logger.debug(f"Dereferenced to depth {depth - 1}, type: {pointee_type}")
43+
builder.branch(merge_block)
44+
45+
builder.position_at_end(merge_block)
46+
phi = builder.phi(pointee_type, name=f"deref_result_{depth}")
47+
48+
zero_value = (
49+
ir.Constant(pointee_type, 0)
50+
if isinstance(pointee_type, ir.IntType)
51+
else ir.Constant(pointee_type, None)
52+
)
53+
phi.add_incoming(zero_value, null_check_block)
54+
55+
phi.add_incoming(dereferenced_val, not_null_block)
56+
57+
# Continue with phi result
58+
cur_val = phi
59+
cur_type = pointee_type
60+
return cur_val
61+
62+
63+
def normalize_types(func, builder, lhs, rhs):
64+
"""Normalize types for comparison."""
65+
66+
logger.info(f"Normalizing types: {lhs.type} vs {rhs.type}")
67+
if isinstance(lhs.type, ir.IntType) and isinstance(rhs.type, ir.IntType):
68+
if lhs.type.width < rhs.type.width:
69+
lhs = builder.sext(lhs, rhs.type)
70+
else:
71+
rhs = builder.sext(rhs, lhs.type)
72+
return lhs, rhs
73+
elif not isinstance(lhs.type, ir.PointerType) and not isinstance(
74+
rhs.type, ir.PointerType
75+
):
76+
logger.error(f"Type mismatch: {lhs.type} vs {rhs.type}")
77+
return None, None
78+
else:
79+
lhs_base, lhs_depth = _get_base_type_and_depth(lhs.type)
80+
rhs_base, rhs_depth = _get_base_type_and_depth(rhs.type)
81+
if lhs_base == rhs_base:
82+
if lhs_depth < rhs_depth:
83+
rhs = _deref_to_depth(func, builder, rhs, rhs_depth - lhs_depth)
84+
elif rhs_depth < lhs_depth:
85+
lhs = _deref_to_depth(func, builder, lhs, lhs_depth - rhs_depth)
86+
return normalize_types(func, builder, lhs, rhs)

0 commit comments

Comments
 (0)