|
5 | 5 | from typing import Dict |
6 | 6 |
|
7 | 7 | from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes |
| 8 | +from .type_normalization import normalize_types |
8 | 9 |
|
9 | 10 | logger: Logger = logging.getLogger(__name__) |
10 | 11 |
|
@@ -129,94 +130,12 @@ def _handle_ctypes_call( |
129 | 130 | return val |
130 | 131 |
|
131 | 132 |
|
132 | | -def _get_base_type_and_depth(ir_type): |
133 | | - """Get the base type for pointer types.""" |
134 | | - cur_type = ir_type |
135 | | - depth = 0 |
136 | | - while isinstance(cur_type, ir.PointerType): |
137 | | - depth += 1 |
138 | | - cur_type = cur_type.pointee |
139 | | - return cur_type, depth |
140 | | - |
141 | | - |
142 | | -def _deref_to_depth(func, builder, val, target_depth): |
143 | | - """Dereference a pointer to a certain depth.""" |
144 | | - |
145 | | - cur_val = val |
146 | | - cur_type = val.type |
147 | | - |
148 | | - for depth in range(target_depth): |
149 | | - if not isinstance(val.type, ir.PointerType): |
150 | | - logger.error("Cannot dereference further, non-pointer type") |
151 | | - return None |
152 | | - |
153 | | - # dereference with null check |
154 | | - pointee_type = cur_type.pointee |
155 | | - null_check_block = builder.block |
156 | | - not_null_block = func.append_basic_block(name=f"deref_not_null_{depth}") |
157 | | - merge_block = func.append_basic_block(name=f"deref_merge_{depth}") |
158 | | - |
159 | | - null_ptr = ir.Constant(cur_type, None) |
160 | | - is_not_null = builder.icmp_signed("!=", cur_val, null_ptr) |
161 | | - logger.debug(f"Inserted null check for pointer at depth {depth}") |
162 | | - |
163 | | - builder.cbranch(is_not_null, not_null_block, merge_block) |
164 | | - |
165 | | - builder.position_at_end(not_null_block) |
166 | | - dereferenced_val = builder.load(cur_val) |
167 | | - logger.debug(f"Dereferenced to depth {depth - 1}, type: {pointee_type}") |
168 | | - builder.branch(merge_block) |
169 | | - |
170 | | - builder.position_at_end(merge_block) |
171 | | - phi = builder.phi(pointee_type, name=f"deref_result_{depth}") |
172 | | - |
173 | | - zero_value = ( |
174 | | - ir.Constant(pointee_type, 0) |
175 | | - if isinstance(pointee_type, ir.IntType) |
176 | | - else ir.Constant(pointee_type, None) |
177 | | - ) |
178 | | - phi.add_incoming(zero_value, null_check_block) |
179 | | - |
180 | | - phi.add_incoming(dereferenced_val, not_null_block) |
181 | | - |
182 | | - # Continue with phi result |
183 | | - cur_val = phi |
184 | | - cur_type = pointee_type |
185 | | - return cur_val |
186 | | - |
187 | | - |
188 | | -def _normalize_types(func, builder, lhs, rhs): |
189 | | - """Normalize types for comparison.""" |
190 | | - |
191 | | - logger.info(f"Normalizing types: {lhs.type} vs {rhs.type}") |
192 | | - if isinstance(lhs.type, ir.IntType) and isinstance(rhs.type, ir.IntType): |
193 | | - if lhs.type.width < rhs.type.width: |
194 | | - lhs = builder.sext(lhs, rhs.type) |
195 | | - else: |
196 | | - rhs = builder.sext(rhs, lhs.type) |
197 | | - return lhs, rhs |
198 | | - elif not isinstance(lhs.type, ir.PointerType) and not isinstance( |
199 | | - rhs.type, ir.PointerType |
200 | | - ): |
201 | | - logger.error(f"Type mismatch: {lhs.type} vs {rhs.type}") |
202 | | - return None, None |
203 | | - else: |
204 | | - lhs_base, lhs_depth = _get_base_type_and_depth(lhs.type) |
205 | | - rhs_base, rhs_depth = _get_base_type_and_depth(rhs.type) |
206 | | - if lhs_base == rhs_base: |
207 | | - if lhs_depth < rhs_depth: |
208 | | - rhs = _deref_to_depth(func, builder, rhs, rhs_depth - lhs_depth) |
209 | | - elif rhs_depth < lhs_depth: |
210 | | - lhs = _deref_to_depth(func, builder, lhs, lhs_depth - rhs_depth) |
211 | | - return _normalize_types(func, builder, lhs, rhs) |
212 | | - |
213 | | - |
214 | 133 | def _handle_comparator(func, builder, op, lhs, rhs): |
215 | 134 | """Handle comparison operations.""" |
216 | 135 |
|
217 | 136 | # NOTE: For now assume same types |
218 | 137 | if lhs.type != rhs.type: |
219 | | - lhs, rhs = _normalize_types(func, builder, lhs, rhs) |
| 138 | + lhs, rhs = normalize_types(func, builder, lhs, rhs) |
220 | 139 |
|
221 | 140 | if lhs is None or rhs is None: |
222 | 141 | return None |
|
0 commit comments