Skip to content

Commit 22c62ca

Browse files
committed
Fix some bugs and tests
1 parent 8b2f948 commit 22c62ca

File tree

4 files changed

+230
-186
lines changed

4 files changed

+230
-186
lines changed

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 46 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -463,51 +463,51 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
463463
*(
464464
[
465465
ast.Assign(
466-
targets=[
467-
ast.Name(id='test_stdout_tag', ctx=ast.Store())],
466+
targets=[ast.Name(id="test_stdout_tag", ctx=ast.Store())],
468467
value=ast.JoinedStr(
469468
values=[
470-
ast.FormattedValue(
471-
value=ast.Name(id='test_module_name', ctx=ast.Load()),
472-
conversion=-1),
473-
ast.Constant(value=':'),
469+
ast.FormattedValue(value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1),
470+
ast.Constant(value=":"),
474471
ast.FormattedValue(
475472
value=ast.IfExp(
476-
test=ast.Name(id='test_class_name', ctx=ast.Load()),
473+
test=ast.Name(id="test_class_name", ctx=ast.Load()),
477474
body=ast.BinOp(
478-
left=ast.Name(id='test_class_name', ctx=ast.Load()),
475+
left=ast.Name(id="test_class_name", ctx=ast.Load()),
479476
op=ast.Add(),
480-
right=ast.Constant(value='.')),
481-
orelse=ast.Constant(value='')),
482-
conversion=-1),
483-
ast.FormattedValue(
484-
value=ast.Name(id='test_name', ctx=ast.Load()),
485-
conversion=-1),
486-
ast.Constant(value=':'),
487-
ast.FormattedValue(
488-
value=ast.Name(id='function_name', ctx=ast.Load()),
489-
conversion=-1),
490-
ast.Constant(value=':'),
491-
ast.FormattedValue(
492-
value=ast.Name(id='loop_index', ctx=ast.Load()),
493-
conversion=-1),
494-
ast.Constant(value=':'),
495-
ast.FormattedValue(
496-
value=ast.Name(id='invocation_id', ctx=ast.Load()),
497-
conversion=-1)]),
498-
lineno=lineno + 9,),
477+
right=ast.Constant(value="."),
478+
),
479+
orelse=ast.Constant(value=""),
480+
),
481+
conversion=-1,
482+
),
483+
ast.FormattedValue(value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1),
484+
ast.Constant(value=":"),
485+
ast.FormattedValue(value=ast.Name(id="function_name", ctx=ast.Load()), conversion=-1),
486+
ast.Constant(value=":"),
487+
ast.FormattedValue(value=ast.Name(id="loop_index", ctx=ast.Load()), conversion=-1),
488+
ast.Constant(value=":"),
489+
ast.FormattedValue(value=ast.Name(id="invocation_id", ctx=ast.Load()), conversion=-1),
490+
]
491+
),
492+
lineno=lineno + 9,
493+
),
499494
ast.Expr(
500495
value=ast.Call(
501-
func=ast.Name(id='print', ctx=ast.Load()),
496+
func=ast.Name(id="print", ctx=ast.Load()),
502497
args=[
503498
ast.JoinedStr(
504499
values=[
505-
ast.Constant(value='!$######'),
500+
ast.Constant(value="!$######"),
506501
ast.FormattedValue(
507-
value=ast.Name(id='test_stdout_tag', ctx=ast.Load()),
508-
conversion=-1),
509-
ast.Constant(value='######$!')])],
510-
keywords=[])),
502+
value=ast.Name(id="test_stdout_tag", ctx=ast.Load()), conversion=-1
503+
),
504+
ast.Constant(value="######$!"),
505+
]
506+
)
507+
],
508+
keywords=[],
509+
)
510+
),
511511
# ast.Expr(
512512
# value=ast.Call(
513513
# func=ast.Name(id="print", ctx=ast.Load()),
@@ -646,66 +646,28 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
646646
),
647647
ast.Expr(
648648
value=ast.Call(
649-
func=ast.Name(id='print', ctx=ast.Load()),
649+
func=ast.Name(id="print", ctx=ast.Load()),
650650
args=[
651651
ast.JoinedStr(
652652
values=[
653-
ast.Constant(value='!######'),
654-
ast.FormattedValue(
655-
value=ast.Name(id='test_stdout_tag', ctx=ast.Load()),
656-
conversion=-1),
657-
ast.Constant(value='######!')])],
658-
keywords=[])),
659-
*(
660-
[
661-
ast.Expr(
662-
value=ast.Call(
663-
func=ast.Name(id="print", ctx=ast.Load()),
664-
args=[
665-
ast.JoinedStr(
666-
values=[
667-
ast.Constant(value="!######"),
668-
ast.FormattedValue(
669-
value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1
670-
),
671-
ast.Constant(value=":"),
672-
ast.FormattedValue(
673-
value=ast.IfExp(
674-
test=ast.Name(id="test_class_name", ctx=ast.Load()),
675-
body=ast.BinOp(
676-
left=ast.Name(id="test_class_name", ctx=ast.Load()),
677-
op=ast.Add(),
678-
right=ast.Constant(value="."),
679-
),
680-
orelse=ast.Constant(value=""),
681-
),
682-
conversion=-1,
683-
),
684-
ast.FormattedValue(value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1),
685-
ast.Constant(value=":"),
686-
ast.FormattedValue(
687-
value=ast.Name(id="function_name", ctx=ast.Load()), conversion=-1
688-
),
689-
ast.Constant(value=":"),
690-
ast.FormattedValue(value=ast.Name(id="loop_index", ctx=ast.Load()), conversion=-1),
691-
ast.Constant(value=":"),
692-
ast.FormattedValue(
693-
value=ast.Name(id="invocation_id", ctx=ast.Load()), conversion=-1
694-
),
653+
ast.Constant(value="!######"),
654+
ast.FormattedValue(value=ast.Name(id="test_stdout_tag", ctx=ast.Load()), conversion=-1),
655+
*(
656+
[
695657
ast.Constant(value=":"),
696658
ast.FormattedValue(
697659
value=ast.Name(id="codeflash_duration", ctx=ast.Load()), conversion=-1
698660
),
699-
ast.Constant(value="######!"),
700661
]
701-
)
702-
],
703-
keywords=[],
662+
if mode == TestingMode.PERFORMANCE
663+
else []
664+
),
665+
ast.Constant(value="######!"),
666+
]
704667
)
705-
)
706-
]
707-
if mode == TestingMode.PERFORMANCE
708-
else []
668+
],
669+
keywords=[],
670+
)
709671
),
710672
*(
711673
[

0 commit comments

Comments
 (0)