Skip to content

Commit d0ed2c9

Browse files
committed
standardise and fix smoke_test_timeout
Signed-off-by: Jack Luar <[email protected]>
1 parent 4c2762d commit d0ed2c9

File tree

2 files changed

+22
-34
lines changed

2 files changed

+22
-34
lines changed

tools/AutoTuner/src/autotuner/distributed.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -515,18 +515,19 @@ def parse_arguments():
515515
args.timeout_per_trial = round(args.timeout_per_trial * 3600)
516516
if args.timeout is not None:
517517
args.timeout = round(args.timeout * 3600)
518+
args.timeout = set_timeout(args.timeout, args.timeout_per_trial)
518519

519520
# Calculate timeout based on cpu_budget
520521
if args.cpu_budget != -1:
521522
args.timeout = round(args.cpu_budget / os.cpu_count() * 3600)
522523
args.timeout_per_trial = round(
523524
args.cpu_budget / (args.jobs * args.resources_per_trial) * 3600
524525
)
525-
overall_timeout = min(args.timeout, args.timeout_per_trial)
526+
args.timeotu = set_timeout(args.timeout, args.timeout_per_trial)
526527
if args.mode == "tune":
527-
template = calculate_expected_numbers(overall_timeout, args.samples)
528+
template = calculate_expected_numbers(args.timeout, args.samples)
528529
else:
529-
template = calculate_expected_numbers(overall_timeout, 1)
530+
template = calculate_expected_numbers(args.timeout, 1)
530531
print(template)
531532
if not args.yes:
532533
print(
@@ -620,6 +621,17 @@ def set_training_class(function):
620621
return None
621622

622623

624+
def set_timeout(timeout, timeout_per_trial):
625+
"""
626+
Set timeout for experiment.
627+
"""
628+
return (
629+
min(timeout, timeout_per_trial)
630+
if (timeout and timeout_per_trial)
631+
else (timeout or timeout_per_trial)
632+
)
633+
634+
623635
@ray.remote
624636
def save_best(results):
625637
"""

tools/AutoTuner/test/smoke_test_timeout.py

Lines changed: 7 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import os
44

55
cur_dir = os.path.dirname(os.path.abspath(__file__))
6-
src_dir = os.path.join(cur_dir, "../src/autotuner")
7-
os.chdir(src_dir)
86

97

108
class BaseTimeoutSmokeTest(unittest.TestCase):
@@ -20,9 +18,8 @@ def setUp(self):
2018

2119
# 0.001 hour translates to 3.6 seconds, which will definitely cause failure.
2220
timeout_flags = ["--timeout 0.001", "--timeout_per_trial 0.001"]
23-
self.timeout_limit = 60 # 60 second upper limit (Ray needs time to shutdown)
2421
self.commands = [
25-
"python3 distributed.py"
22+
"python3 -m autotuner.distributed"
2623
f" --design {self.design}"
2724
f" --platform {self.platform}"
2825
f" --experiment {self.experiment}-{idx}"
@@ -34,49 +31,28 @@ def setUp(self):
3431
]
3532

3633
def test_timeout(self):
37-
raise NotImplementedError(
38-
"This method needs to be implemented in the derivative classes."
39-
)
34+
if not (self.platform and self.design):
35+
raise unittest.SkipTest("Platform and design have to be defined")
36+
for command in self.commands:
37+
out = subprocess.run(command, shell=True, check=False)
38+
failed = out.returncode != 0
39+
self.assertTrue(failed)
4040

4141

4242
class asap7TimeoutSmokeTest(BaseTimeoutSmokeTest):
4343
platform = "asap7"
4444
design = "gcd"
4545

46-
def test_timeout(self):
47-
for command in self.commands:
48-
out = subprocess.run(
49-
command, shell=True, check=False, timeout=self.timeout_limit
50-
)
51-
failed = out.returncode == 1
52-
self.assertTrue(failed)
53-
5446

5547
class sky130hdTimeoutSmokeTest(BaseTimeoutSmokeTest):
5648
platform = "sky130hd"
5749
design = "gcd"
5850

59-
def test_timeout(self):
60-
for command in self.commands:
61-
out = subprocess.run(
62-
command, shell=True, check=False, timeout=self.timeout_limit
63-
)
64-
failed = out.returncode == 1
65-
self.assertTrue(failed)
66-
6751

6852
class ihpsg13g2TimeoutSmokeTest(BaseTimeoutSmokeTest):
6953
platform = "ihp-sg13g2"
7054
design = "gcd"
7155

72-
def test_timeout(self):
73-
for command in self.commands:
74-
out = subprocess.run(
75-
command, shell=True, check=False, timeout=self.timeout_limit
76-
)
77-
failed = out.returncode == 1
78-
self.assertTrue(failed)
79-
8056

8157
if __name__ == "__main__":
8258
unittest.main()

0 commit comments

Comments
 (0)