|
9 | 9 | from pythonbpf.binary_ops import handle_binary_op |
10 | 10 | from pythonbpf.expr_pass import eval_expr, handle_expr |
11 | 11 |
|
| 12 | + |
12 | 13 | logger = logging.getLogger(__name__) |
13 | 14 |
|
14 | 15 |
|
@@ -350,6 +351,65 @@ def handle_if( |
350 | 351 | builder.position_at_end(merge_block) |
351 | 352 |
|
352 | 353 |
|
| 354 | +def handle_return( |
| 355 | + func, module, builder, stmt, map_sym_tab, local_sym_tab, struct_sym_tab, ret_type |
| 356 | +): |
| 357 | + if stmt.value is None: |
| 358 | + builder.ret(ir.Constant(ir.IntType(64), 0)) |
| 359 | + return True |
| 360 | + elif ( |
| 361 | + isinstance(stmt.value, ast.Call) |
| 362 | + and isinstance(stmt.value.func, ast.Name) |
| 363 | + and len(stmt.value.args) == 1 |
| 364 | + ): |
| 365 | + if isinstance(stmt.value.args[0], ast.Constant) and isinstance( |
| 366 | + stmt.value.args[0].value, int |
| 367 | + ): |
| 368 | + call_type = stmt.value.func.id |
| 369 | + if ctypes_to_ir(call_type) != ret_type: |
| 370 | + raise ValueError( |
| 371 | + "Return type mismatch: expected" |
| 372 | + f"{ctypes_to_ir(call_type)}, got {call_type}" |
| 373 | + ) |
| 374 | + else: |
| 375 | + builder.ret(ir.Constant(ret_type, stmt.value.args[0].value)) |
| 376 | + return True |
| 377 | + elif isinstance(stmt.value.args[0], ast.BinOp): |
| 378 | + # TODO: Should be routed through eval_expr |
| 379 | + val = handle_binary_op(stmt.value.args[0], builder, None, local_sym_tab) |
| 380 | + if val is None: |
| 381 | + raise ValueError("Failed to evaluate return expression") |
| 382 | + if val[1] != ret_type: |
| 383 | + raise ValueError( |
| 384 | + f"Return type mismatch: expected {ret_type}, got {val[1]}" |
| 385 | + ) |
| 386 | + builder.ret(val[0]) |
| 387 | + return True |
| 388 | + elif isinstance(stmt.value.args[0], ast.Name): |
| 389 | + if stmt.value.args[0].id in local_sym_tab: |
| 390 | + var = local_sym_tab[stmt.value.args[0].id].var |
| 391 | + val = builder.load(var) |
| 392 | + if val.type != ret_type: |
| 393 | + raise ValueError( |
| 394 | + f"Return type mismatch: expected {ret_type}, got {val.type}" |
| 395 | + ) |
| 396 | + builder.ret(val) |
| 397 | + return True |
| 398 | + else: |
| 399 | + raise ValueError("Failed to evaluate return expression") |
| 400 | + elif isinstance(stmt.value, ast.Name): |
| 401 | + if stmt.value.id == "XDP_PASS": |
| 402 | + builder.ret(ir.Constant(ret_type, 2)) |
| 403 | + return True |
| 404 | + elif stmt.value.id == "XDP_DROP": |
| 405 | + builder.ret(ir.Constant(ret_type, 1)) |
| 406 | + return True |
| 407 | + else: |
| 408 | + raise ValueError("Failed to evaluate return expression") |
| 409 | + else: |
| 410 | + raise ValueError("Unsupported return value") |
| 411 | + |
| 412 | + |
353 | 413 | def process_stmt( |
354 | 414 | func, |
355 | 415 | module, |
@@ -383,60 +443,16 @@ def process_stmt( |
383 | 443 | func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab |
384 | 444 | ) |
385 | 445 | elif isinstance(stmt, ast.Return): |
386 | | - if stmt.value is None: |
387 | | - builder.ret(ir.Constant(ir.IntType(64), 0)) |
388 | | - did_return = True |
389 | | - elif ( |
390 | | - isinstance(stmt.value, ast.Call) |
391 | | - and isinstance(stmt.value.func, ast.Name) |
392 | | - and len(stmt.value.args) == 1 |
393 | | - ): |
394 | | - if isinstance(stmt.value.args[0], ast.Constant) and isinstance( |
395 | | - stmt.value.args[0].value, int |
396 | | - ): |
397 | | - call_type = stmt.value.func.id |
398 | | - if ctypes_to_ir(call_type) != ret_type: |
399 | | - raise ValueError( |
400 | | - "Return type mismatch: expected" |
401 | | - f"{ctypes_to_ir(call_type)}, got {call_type}" |
402 | | - ) |
403 | | - else: |
404 | | - builder.ret(ir.Constant(ret_type, stmt.value.args[0].value)) |
405 | | - did_return = True |
406 | | - elif isinstance(stmt.value.args[0], ast.BinOp): |
407 | | - # TODO: Should be routed through eval_expr |
408 | | - val = handle_binary_op(stmt.value.args[0], builder, None, local_sym_tab) |
409 | | - if val is None: |
410 | | - raise ValueError("Failed to evaluate return expression") |
411 | | - if val[1] != ret_type: |
412 | | - raise ValueError( |
413 | | - f"Return type mismatch: expected {ret_type}, got {val[1]}" |
414 | | - ) |
415 | | - builder.ret(val[0]) |
416 | | - did_return = True |
417 | | - elif isinstance(stmt.value.args[0], ast.Name): |
418 | | - if stmt.value.args[0].id in local_sym_tab: |
419 | | - var = local_sym_tab[stmt.value.args[0].id].var |
420 | | - val = builder.load(var) |
421 | | - if val.type != ret_type: |
422 | | - raise ValueError( |
423 | | - f"Return type mismatch: expected {ret_type}, got {val.type}" |
424 | | - ) |
425 | | - builder.ret(val) |
426 | | - did_return = True |
427 | | - else: |
428 | | - raise ValueError("Failed to evaluate return expression") |
429 | | - elif isinstance(stmt.value, ast.Name): |
430 | | - if stmt.value.id == "XDP_PASS": |
431 | | - builder.ret(ir.Constant(ret_type, 2)) |
432 | | - did_return = True |
433 | | - elif stmt.value.id == "XDP_DROP": |
434 | | - builder.ret(ir.Constant(ret_type, 1)) |
435 | | - did_return = True |
436 | | - else: |
437 | | - raise ValueError("Failed to evaluate return expression") |
438 | | - else: |
439 | | - raise ValueError("Unsupported return value") |
| 446 | + did_return = handle_return( |
| 447 | + func, |
| 448 | + module, |
| 449 | + builder, |
| 450 | + stmt, |
| 451 | + map_sym_tab, |
| 452 | + local_sym_tab, |
| 453 | + structs_sym_tab, |
| 454 | + ret_type, |
| 455 | + ) |
440 | 456 | return did_return |
441 | 457 |
|
442 | 458 |
|
|
0 commit comments