Skip to content

Commit c7369e9

Browse files
committed
no timeout_decorator windows
1 parent 841f55b commit c7369e9

File tree

2 files changed

+85
-12
lines changed

2 files changed

+85
-12
lines changed

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import ast
4+
import platform
45
from pathlib import Path
56
from typing import TYPE_CHECKING
67

@@ -135,7 +136,10 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
135136
def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef:
136137
if node.name.startswith("test_"):
137138
did_update = False
138-
if self.test_framework == "unittest":
139+
if self.test_framework == "unittest" and platform.system() != "Windows":
140+
# Only add timeout decorator on non-Windows platforms
141+
# Windows doesn't support SIGALRM signal required by timeout_decorator
142+
139143
node.decorator_list.append(
140144
ast.Call(
141145
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
@@ -354,7 +358,7 @@ def inject_profiling_into_existing_test(
354358
new_imports.extend(
355359
[ast.Import(names=[ast.alias(name="sqlite3")]), ast.Import(names=[ast.alias(name="dill", asname="pickle")])]
356360
)
357-
if test_framework == "unittest":
361+
if test_framework == "unittest" and platform.system() != "Windows":
358362
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
359363
tree.body = [*new_imports, create_wrapper_function(mode), *tree.body]
360364
return True, isort.code(ast.unparse(tree), float_to_top=True)

tests/test_instrument_tests.py

Lines changed: 79 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
TestsInFile,
2525
TestType,
2626
)
27+
import platform
28+
2729
from codeflash.optimization.function_optimizer import FunctionOptimizer
2830
from codeflash.verification.verification_utils import TestConfig
2931

@@ -1451,6 +1453,7 @@ def test_sort():
14511453

14521454

14531455
def test_perfinjector_bubble_sort_unittest_results() -> None:
1456+
14541457
code = """import unittest
14551458
14561459
from code_to_optimize.bubble_sort import sorter
@@ -1471,8 +1474,74 @@ def test_sort(self):
14711474
self.assertEqual(output, list(range(50)))
14721475
"""
14731476

1474-
expected = (
1475-
"""import gc
1477+
is_windows = platform.system() == "Windows"
1478+
1479+
if is_windows:
1480+
expected = (
1481+
"""import gc
1482+
import os
1483+
import sqlite3
1484+
import time
1485+
import unittest
1486+
1487+
import dill as pickle
1488+
1489+
from code_to_optimize.bubble_sort import sorter
1490+
1491+
1492+
"""
1493+
+ codeflash_wrap_string
1494+
+ """
1495+
class TestPigLatin(unittest.TestCase):
1496+
1497+
def test_sort(self):
1498+
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
1499+
codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
1500+
codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
1501+
codeflash_cur = codeflash_con.cursor()
1502+
codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)')
1503+
input = [5, 4, 3, 2, 1, 0]
1504+
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, input)
1505+
self.assertEqual(output, [0, 1, 2, 3, 4, 5])
1506+
input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
1507+
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, input)
1508+
self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
1509+
input = list(reversed(range(50)))
1510+
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, codeflash_cur, codeflash_con, input)
1511+
self.assertEqual(output, list(range(50)))
1512+
codeflash_con.close()
1513+
"""
1514+
)
1515+
expected_perf = (
1516+
"""import gc
1517+
import os
1518+
import time
1519+
import unittest
1520+
1521+
from code_to_optimize.bubble_sort import sorter
1522+
1523+
1524+
"""
1525+
+ codeflash_wrap_perfonly_string
1526+
+ """
1527+
class TestPigLatin(unittest.TestCase):
1528+
1529+
def test_sort(self):
1530+
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
1531+
input = [5, 4, 3, 2, 1, 0]
1532+
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '1', codeflash_loop_index, input)
1533+
self.assertEqual(output, [0, 1, 2, 3, 4, 5])
1534+
input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
1535+
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '4', codeflash_loop_index, input)
1536+
self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
1537+
input = list(reversed(range(50)))
1538+
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, input)
1539+
self.assertEqual(output, list(range(50)))
1540+
"""
1541+
)
1542+
else:
1543+
expected = (
1544+
"""import gc
14761545
import os
14771546
import sqlite3
14781547
import time
@@ -1485,8 +1554,8 @@ def test_sort(self):
14851554
14861555
14871556
"""
1488-
+ codeflash_wrap_string
1489-
+ """
1557+
+ codeflash_wrap_string
1558+
+ """
14901559
class TestPigLatin(unittest.TestCase):
14911560
14921561
@timeout_decorator.timeout(15)
@@ -1507,9 +1576,9 @@ def test_sort(self):
15071576
self.assertEqual(output, list(range(50)))
15081577
codeflash_con.close()
15091578
"""
1510-
)
1511-
expected_perf = (
1512-
"""import gc
1579+
)
1580+
expected_perf = (
1581+
"""import gc
15131582
import os
15141583
import time
15151584
import unittest
@@ -1520,8 +1589,8 @@ def test_sort(self):
15201589
15211590
15221591
"""
1523-
+ codeflash_wrap_perfonly_string
1524-
+ """
1592+
+ codeflash_wrap_perfonly_string
1593+
+ """
15251594
class TestPigLatin(unittest.TestCase):
15261595
15271596
@timeout_decorator.timeout(15)
@@ -1537,7 +1606,7 @@ def test_sort(self):
15371606
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, input)
15381607
self.assertEqual(output, list(range(50)))
15391608
"""
1540-
)
1609+
)
15411610
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").resolve()
15421611
test_path = (
15431612
Path(__file__).parent.resolve()

0 commit comments

Comments
 (0)