Skip to content

Commit 823046b

Browse files
committed
fixtures to functions for better composability
1 parent 7c1826c commit 823046b

File tree

1 file changed

+49
-52
lines changed

1 file changed

+49
-52
lines changed

unit-tests/test_report.py

Lines changed: 49 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,13 @@
1919
from libkernelbot.run_eval import CompileResult, EvalResult, FullResult, RunResult, SystemInfo
2020

2121

22-
# define fixtures that create mock results
23-
@pytest.fixture
22+
# define helpers and fixtures that create mock results
2423
def sample_system_info() -> SystemInfo:
2524
return SystemInfo(
2625
gpu="NVIDIA RTX 4090", cpu="Intel i9-12900K", platform="Linux-5.15.0", torch="2.0.1+cu118"
2726
)
2827

2928

30-
@pytest.fixture
3129
def sample_compile_result() -> CompileResult:
3230
return CompileResult(
3331
success=True,
@@ -40,7 +38,6 @@ def sample_compile_result() -> CompileResult:
4038
)
4139

4240

43-
@pytest.fixture
4441
def sample_run_result() -> RunResult:
4542
return RunResult(
4643
success=True,
@@ -65,23 +62,19 @@ def sample_run_result() -> RunResult:
6562

6663

6764
@pytest.fixture
68-
def sample_eval_result(
69-
sample_compile_result: CompileResult, sample_run_result: RunResult
70-
) -> EvalResult:
65+
def sample_eval_result() -> EvalResult:
7166
return EvalResult(
7267
start=datetime.datetime.now() - datetime.timedelta(minutes=5),
7368
end=datetime.datetime.now(),
74-
compilation=sample_compile_result,
75-
run=sample_run_result,
69+
compilation=sample_compile_result(),
70+
run=sample_run_result(),
7671
)
7772

7873

7974
@pytest.fixture
80-
def sample_full_result(
81-
sample_system_info: SystemInfo, sample_eval_result: EvalResult
82-
) -> FullResult:
75+
def sample_full_result(sample_eval_result: EvalResult) -> FullResult:
8376
return FullResult(
84-
success=True, error="", system=sample_system_info, runs={"test": sample_eval_result}
77+
success=True, error="", system=sample_system_info(), runs={"test": sample_eval_result}
8578
)
8679

8780

@@ -90,16 +83,17 @@ def sample_full_result(
9083
################################################
9184

9285

93-
def test_generate_compile_report_nvcc_not_found(sample_compile_result: CompileResult):
94-
sample_compile_result.success = False
95-
sample_compile_result.nvcc_found = False
96-
sample_compile_result.command = ""
97-
sample_compile_result.exit_code = 127
98-
sample_compile_result.stderr = "nvcc: command not found"
99-
sample_compile_result.stdout = ""
86+
def test_generate_compile_report_nvcc_not_found():
87+
compile_result = sample_compile_result()
88+
compile_result.success = False
89+
compile_result.nvcc_found = False
90+
compile_result.command = ""
91+
compile_result.exit_code = 127
92+
compile_result.stderr = "nvcc: command not found"
93+
compile_result.stdout = ""
10094

10195
reporter = RunResultReport()
102-
_generate_compile_report(reporter, sample_compile_result)
96+
_generate_compile_report(reporter, compile_result)
10397

10498
assert len(reporter.data) == 1
10599
assert hasattr(reporter.data[0], "text")
@@ -110,16 +104,17 @@ def test_generate_compile_report_nvcc_not_found(sample_compile_result: CompileRe
110104
assert "notify the server admins" in text
111105

112106

113-
def test_generate_compile_report_with_errors(sample_compile_result: CompileResult):
114-
sample_compile_result.success = False
115-
sample_compile_result.nvcc_found = True
116-
sample_compile_result.command = "nvcc -o test test.cu -arch=sm_75"
117-
sample_compile_result.exit_code = 1
118-
sample_compile_result.stderr = 'test.cu(15): error: identifier "invalid_function" is undefined'
119-
sample_compile_result.stdout = "warning: deprecated feature used"
107+
def test_generate_compile_report_with_errors():
108+
compile_result = sample_compile_result()
109+
compile_result.success = False
110+
compile_result.nvcc_found = True
111+
compile_result.command = "nvcc -o test test.cu -arch=sm_75"
112+
compile_result.exit_code = 1
113+
compile_result.stderr = 'test.cu(15): error: identifier "invalid_function" is undefined'
114+
compile_result.stdout = "warning: deprecated feature used"
120115

121116
reporter = RunResultReport()
122-
_generate_compile_report(reporter, sample_compile_result)
117+
_generate_compile_report(reporter, compile_result)
123118

124119
# Should have compilation text + stderr log + stdout log
125120
assert len(reporter.data) == 3
@@ -141,16 +136,17 @@ def test_generate_compile_report_with_errors(sample_compile_result: CompileResul
141136
assert "warning: deprecated feature used" in reporter.data[2].content
142137

143138

144-
def test_generate_compile_report_no_stdout(sample_compile_result: CompileResult):
145-
sample_compile_result.success = False
146-
sample_compile_result.nvcc_found = True
147-
sample_compile_result.command = "nvcc -o test test.cu"
148-
sample_compile_result.exit_code = 1
149-
sample_compile_result.stderr = "compilation error"
150-
sample_compile_result.stdout = ""
139+
def test_generate_compile_report_no_stdout():
140+
compile_result = sample_compile_result()
141+
compile_result.success = False
142+
compile_result.nvcc_found = True
143+
compile_result.command = "nvcc -o test test.cu"
144+
compile_result.exit_code = 1
145+
compile_result.stderr = "compilation error"
146+
compile_result.stdout = ""
151147

152148
reporter = RunResultReport()
153-
_generate_compile_report(reporter, sample_compile_result)
149+
_generate_compile_report(reporter, compile_result)
154150

155151
# Should have compilation text + stderr log (no stdout log)
156152
assert len(reporter.data) == 2
@@ -168,19 +164,20 @@ def test_generate_compile_report_no_stdout(sample_compile_result: CompileResult)
168164
################################################
169165

170166

171-
def test_short_fail_reason(sample_run_result: RunResult):
172-
sample_run_result.exit_code = consts.ExitCode.TIMEOUT_EXPIRED
173-
assert _short_fail_reason(sample_run_result) == " (timeout)"
167+
def test_short_fail_reason():
168+
run_result = sample_run_result()
169+
run_result.exit_code = consts.ExitCode.TIMEOUT_EXPIRED
170+
assert _short_fail_reason(run_result) == " (timeout)"
174171

175-
sample_run_result.exit_code = consts.ExitCode.CUDA_FAIL
176-
assert _short_fail_reason(sample_run_result) == " (cuda api error)"
172+
run_result.exit_code = consts.ExitCode.CUDA_FAIL
173+
assert _short_fail_reason(run_result) == " (cuda api error)"
177174

178175
# VALIDATE_FAIL means unit tests failed, which will be reported differently
179-
sample_run_result.exit_code = consts.ExitCode.VALIDATE_FAIL
180-
assert _short_fail_reason(sample_run_result) == ""
176+
run_result.exit_code = consts.ExitCode.VALIDATE_FAIL
177+
assert _short_fail_reason(run_result) == ""
181178

182-
sample_run_result.exit_code = 42
183-
assert _short_fail_reason(sample_run_result) == " (internal error 42)"
179+
run_result.exit_code = 42
180+
assert _short_fail_reason(run_result) == " (internal error 42)"
184181

185182

186183
def test_make_short_report_compilation_failed(sample_eval_result: EvalResult):
@@ -191,13 +188,13 @@ def test_make_short_report_compilation_failed(sample_eval_result: EvalResult):
191188
assert result == ["❌ Compilation failed"]
192189

193190

194-
def test_make_short_report_full_success(sample_compile_result: CompileResult):
191+
def test_make_short_report_full_success():
195192
runs = {}
196193
for run_type in ["test", "benchmark", "profile", "leaderboard"]:
197194
runs[run_type] = EvalResult(
198195
start=datetime.datetime.now() - datetime.timedelta(minutes=5),
199196
end=datetime.datetime.now(),
200-
compilation=sample_compile_result,
197+
compilation=sample_compile_result(),
201198
run=RunResult(
202199
success=True,
203200
passed=True,
@@ -239,8 +236,8 @@ def test_make_short_report_missing_components(sample_eval_result: EvalResult):
239236
################################################
240237

241238

242-
def test_make_test_log(sample_run_result: RunResult):
243-
log = make_test_log(sample_run_result)
239+
def test_make_test_log():
240+
log = make_test_log(sample_run_result())
244241
expected_lines = [
245242
"✅ Test addition",
246243
"> Addition works correctly",
@@ -316,8 +313,8 @@ def test_make_profile_log_no_data():
316313
assert log == "❗ Could not find any profiling data"
317314

318315

319-
def test_generate_system_info(sample_system_info: SystemInfo):
320-
info = generate_system_info(sample_system_info)
316+
def test_generate_system_info():
317+
info = generate_system_info(sample_system_info())
321318

322319
expected_parts = ["NVIDIA RTX 4090", "Intel i9-12900K", "Linux-5.15.0", "2.0.1+cu118"]
323320

0 commit comments

Comments
 (0)