Skip to content

Commit 1a5f103

Browse files
committed
fix overlappings args in codeflash wrap
1 parent 9ac5d34 commit 1a5f103

File tree

3 files changed

+60
-54
lines changed

3 files changed

+60
-54
lines changed

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -365,15 +365,15 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
365365
targets=[ast.Name(id="test_id", ctx=ast.Store())],
366366
value=ast.JoinedStr(
367367
values=[
368-
ast.FormattedValue(value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1),
368+
ast.FormattedValue(value=ast.Name(id="codeflash_test_module_name", ctx=ast.Load()), conversion=-1),
369369
ast.Constant(value=":"),
370-
ast.FormattedValue(value=ast.Name(id="test_class_name", ctx=ast.Load()), conversion=-1),
370+
ast.FormattedValue(value=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()), conversion=-1),
371371
ast.Constant(value=":"),
372-
ast.FormattedValue(value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1),
372+
ast.FormattedValue(value=ast.Name(id="codeflash_test_name", ctx=ast.Load()), conversion=-1),
373373
ast.Constant(value=":"),
374-
ast.FormattedValue(value=ast.Name(id="line_id", ctx=ast.Load()), conversion=-1),
374+
ast.FormattedValue(value=ast.Name(id="codeflash_line_id", ctx=ast.Load()), conversion=-1),
375375
ast.Constant(value=":"),
376-
ast.FormattedValue(value=ast.Name(id="loop_index", ctx=ast.Load()), conversion=-1),
376+
ast.FormattedValue(value=ast.Name(id="codeflash_loop_index", ctx=ast.Load()), conversion=-1),
377377
]
378378
),
379379
lineno=lineno + 1,
@@ -453,7 +453,7 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
453453
targets=[ast.Name(id="invocation_id", ctx=ast.Store())],
454454
value=ast.JoinedStr(
455455
values=[
456-
ast.FormattedValue(value=ast.Name(id="line_id", ctx=ast.Load()), conversion=-1),
456+
ast.FormattedValue(value=ast.Name(id="codeflash_line_id", ctx=ast.Load()), conversion=-1),
457457
ast.Constant(value="_"),
458458
ast.FormattedValue(value=ast.Name(id="codeflash_test_index", ctx=ast.Load()), conversion=-1),
459459
]
@@ -466,25 +466,31 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
466466
targets=[ast.Name(id="test_stdout_tag", ctx=ast.Store())],
467467
value=ast.JoinedStr(
468468
values=[
469-
ast.FormattedValue(value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1),
469+
ast.FormattedValue(
470+
value=ast.Name(id="codeflash_test_module_name", ctx=ast.Load()), conversion=-1
471+
),
470472
ast.Constant(value=":"),
471473
ast.FormattedValue(
472474
value=ast.IfExp(
473-
test=ast.Name(id="test_class_name", ctx=ast.Load()),
475+
test=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()),
474476
body=ast.BinOp(
475-
left=ast.Name(id="test_class_name", ctx=ast.Load()),
477+
left=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()),
476478
op=ast.Add(),
477479
right=ast.Constant(value="."),
478480
),
479481
orelse=ast.Constant(value=""),
480482
),
481483
conversion=-1,
482484
),
483-
ast.FormattedValue(value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1),
485+
ast.FormattedValue(value=ast.Name(id="codeflash_test_name", ctx=ast.Load()), conversion=-1),
484486
ast.Constant(value=":"),
485-
ast.FormattedValue(value=ast.Name(id="function_name", ctx=ast.Load()), conversion=-1),
487+
ast.FormattedValue(
488+
value=ast.Name(id="codeflash_function_name", ctx=ast.Load()), conversion=-1
489+
),
486490
ast.Constant(value=":"),
487-
ast.FormattedValue(value=ast.Name(id="loop_index", ctx=ast.Load()), conversion=-1),
491+
ast.FormattedValue(
492+
value=ast.Name(id="codeflash_loop_index", ctx=ast.Load()), conversion=-1
493+
),
488494
ast.Constant(value=":"),
489495
ast.FormattedValue(value=ast.Name(id="invocation_id", ctx=ast.Load()), conversion=-1),
490496
]
@@ -537,7 +543,7 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
537543
ast.Assign(
538544
targets=[ast.Name(id="return_value", ctx=ast.Store())],
539545
value=ast.Call(
540-
func=ast.Name(id="wrapped", ctx=ast.Load()),
546+
func=ast.Name(id="codeflash_wrapped", ctx=ast.Load()),
541547
args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())],
542548
keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
543549
),
@@ -664,11 +670,11 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
664670
ast.Constant(value="INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"),
665671
ast.Tuple(
666672
elts=[
667-
ast.Name(id="test_module_name", ctx=ast.Load()),
668-
ast.Name(id="test_class_name", ctx=ast.Load()),
669-
ast.Name(id="test_name", ctx=ast.Load()),
670-
ast.Name(id="function_name", ctx=ast.Load()),
671-
ast.Name(id="loop_index", ctx=ast.Load()),
673+
ast.Name(id="codeflash_test_module_name", ctx=ast.Load()),
674+
ast.Name(id="codeflash_test_class_name", ctx=ast.Load()),
675+
ast.Name(id="codeflash_test_name", ctx=ast.Load()),
676+
ast.Name(id="codeflash_function_name", ctx=ast.Load()),
677+
ast.Name(id="codeflash_loop_index", ctx=ast.Load()),
672678
ast.Name(id="invocation_id", ctx=ast.Load()),
673679
ast.Name(id="codeflash_duration", ctx=ast.Load()),
674680
ast.Name(id="pickled_return_value", ctx=ast.Load()),
@@ -707,13 +713,13 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
707713
name="codeflash_wrap",
708714
args=ast.arguments(
709715
args=[
710-
ast.arg(arg="wrapped", annotation=None),
711-
ast.arg(arg="test_module_name", annotation=None),
712-
ast.arg(arg="test_class_name", annotation=None),
713-
ast.arg(arg="test_name", annotation=None),
714-
ast.arg(arg="function_name", annotation=None),
715-
ast.arg(arg="line_id", annotation=None),
716-
ast.arg(arg="loop_index", annotation=None),
716+
ast.arg(arg="codeflash_wrapped", annotation=None),
717+
ast.arg(arg="codeflash_test_module_name", annotation=None),
718+
ast.arg(arg="codeflash_test_class_name", annotation=None),
719+
ast.arg(arg="codeflash_test_name", annotation=None),
720+
ast.arg(arg="codeflash_function_name", annotation=None),
721+
ast.arg(arg="codeflash_line_id", annotation=None),
722+
ast.arg(arg="codeflash_loop_index", annotation=None),
717723
*([ast.arg(arg="codeflash_cur", annotation=None)] if mode == TestingMode.BEHAVIOR else []),
718724
*([ast.arg(arg="codeflash_con", annotation=None)] if mode == TestingMode.BEHAVIOR else []),
719725
],

tests/test_instrument_all_and_run.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,31 +15,31 @@
1515
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
1616

1717
# Used by cli instrumentation
18-
codeflash_wrap_string = """def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
19-
test_id = f'{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}'
18+
codeflash_wrap_string = """def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
19+
test_id = f'{{codeflash_test_module_name}}:{{codeflash_test_class_name}}:{{codeflash_test_name}}:{{codeflash_line_id}}:{{codeflash_loop_index}}'
2020
if not hasattr(codeflash_wrap, 'index'):
2121
codeflash_wrap.index = {{}}
2222
if test_id in codeflash_wrap.index:
2323
codeflash_wrap.index[test_id] += 1
2424
else:
2525
codeflash_wrap.index[test_id] = 0
2626
codeflash_test_index = codeflash_wrap.index[test_id]
27-
invocation_id = f'{{line_id}}_{{codeflash_test_index}}'
28-
test_stdout_tag = f"{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}"
27+
invocation_id = f'{{codeflash_line_id}}_{{codeflash_test_index}}'
28+
test_stdout_tag = f"{{codeflash_test_module_name}}:{{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}}{{codeflash_test_name}}:{{codeflash_function_name}}:{{codeflash_loop_index}}:{{invocation_id}}"
2929
print(f"!$######{{test_stdout_tag}}######$!")
3030
exception = None
3131
gc.disable()
3232
try:
3333
counter = time.perf_counter_ns()
34-
return_value = wrapped(*args, **kwargs)
34+
return_value = codeflash_wrapped(*args, **kwargs)
3535
codeflash_duration = time.perf_counter_ns() - counter
3636
except Exception as e:
3737
codeflash_duration = time.perf_counter_ns() - counter
3838
exception = e
3939
gc.enable()
4040
print(f"!######{{test_stdout_tag}}######!")
4141
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
42-
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
42+
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
4343
codeflash_con.commit()
4444
if exception:
4545
raise exception

0 commit comments

Comments
 (0)