1919from libkernelbot .run_eval import CompileResult , EvalResult , FullResult , RunResult , SystemInfo
2020
2121
22- # define fixtures that create mock results
23- @pytest .fixture
22+ # define helpers and fixtures that create mock results
2423def sample_system_info () -> SystemInfo :
2524 return SystemInfo (
2625 gpu = "NVIDIA RTX 4090" , cpu = "Intel i9-12900K" , platform = "Linux-5.15.0" , torch = "2.0.1+cu118"
2726 )
2827
2928
30- @pytest .fixture
3129def sample_compile_result () -> CompileResult :
3230 return CompileResult (
3331 success = True ,
@@ -40,7 +38,6 @@ def sample_compile_result() -> CompileResult:
4038 )
4139
4240
43- @pytest .fixture
4441def sample_run_result () -> RunResult :
4542 return RunResult (
4643 success = True ,
@@ -65,23 +62,19 @@ def sample_run_result() -> RunResult:
6562
6663
6764@pytest .fixture
68- def sample_eval_result (
69- sample_compile_result : CompileResult , sample_run_result : RunResult
70- ) -> EvalResult :
65+ def sample_eval_result () -> EvalResult :
7166 return EvalResult (
7267 start = datetime .datetime .now () - datetime .timedelta (minutes = 5 ),
7368 end = datetime .datetime .now (),
74- compilation = sample_compile_result ,
75- run = sample_run_result ,
69+ compilation = sample_compile_result () ,
70+ run = sample_run_result () ,
7671 )
7772
7873
7974@pytest .fixture
80- def sample_full_result (
81- sample_system_info : SystemInfo , sample_eval_result : EvalResult
82- ) -> FullResult :
75+ def sample_full_result (sample_eval_result : EvalResult ) -> FullResult :
8376 return FullResult (
84- success = True , error = "" , system = sample_system_info , runs = {"test" : sample_eval_result }
77+ success = True , error = "" , system = sample_system_info () , runs = {"test" : sample_eval_result }
8578 )
8679
8780
@@ -90,16 +83,17 @@ def sample_full_result(
9083################################################
9184
9285
93- def test_generate_compile_report_nvcc_not_found (sample_compile_result : CompileResult ):
94- sample_compile_result .success = False
95- sample_compile_result .nvcc_found = False
96- sample_compile_result .command = ""
97- sample_compile_result .exit_code = 127
98- sample_compile_result .stderr = "nvcc: command not found"
99- sample_compile_result .stdout = ""
86+ def test_generate_compile_report_nvcc_not_found ():
87+ compile_result = sample_compile_result ()
88+ compile_result .success = False
89+ compile_result .nvcc_found = False
90+ compile_result .command = ""
91+ compile_result .exit_code = 127
92+ compile_result .stderr = "nvcc: command not found"
93+ compile_result .stdout = ""
10094
10195 reporter = RunResultReport ()
102- _generate_compile_report (reporter , sample_compile_result )
96+ _generate_compile_report (reporter , compile_result )
10397
10498 assert len (reporter .data ) == 1
10599 assert hasattr (reporter .data [0 ], "text" )
@@ -110,16 +104,17 @@ def test_generate_compile_report_nvcc_not_found(sample_compile_result: CompileRe
110104 assert "notify the server admins" in text
111105
112106
113- def test_generate_compile_report_with_errors (sample_compile_result : CompileResult ):
114- sample_compile_result .success = False
115- sample_compile_result .nvcc_found = True
116- sample_compile_result .command = "nvcc -o test test.cu -arch=sm_75"
117- sample_compile_result .exit_code = 1
118- sample_compile_result .stderr = 'test.cu(15): error: identifier "invalid_function" is undefined'
119- sample_compile_result .stdout = "warning: deprecated feature used"
107+ def test_generate_compile_report_with_errors ():
108+ compile_result = sample_compile_result ()
109+ compile_result .success = False
110+ compile_result .nvcc_found = True
111+ compile_result .command = "nvcc -o test test.cu -arch=sm_75"
112+ compile_result .exit_code = 1
113+ compile_result .stderr = 'test.cu(15): error: identifier "invalid_function" is undefined'
114+ compile_result .stdout = "warning: deprecated feature used"
120115
121116 reporter = RunResultReport ()
122- _generate_compile_report (reporter , sample_compile_result )
117+ _generate_compile_report (reporter , compile_result )
123118
124119 # Should have compilation text + stderr log + stdout log
125120 assert len (reporter .data ) == 3
@@ -141,16 +136,17 @@ def test_generate_compile_report_with_errors(sample_compile_result: CompileResul
141136 assert "warning: deprecated feature used" in reporter .data [2 ].content
142137
143138
144- def test_generate_compile_report_no_stdout (sample_compile_result : CompileResult ):
145- sample_compile_result .success = False
146- sample_compile_result .nvcc_found = True
147- sample_compile_result .command = "nvcc -o test test.cu"
148- sample_compile_result .exit_code = 1
149- sample_compile_result .stderr = "compilation error"
150- sample_compile_result .stdout = ""
139+ def test_generate_compile_report_no_stdout ():
140+ compile_result = sample_compile_result ()
141+ compile_result .success = False
142+ compile_result .nvcc_found = True
143+ compile_result .command = "nvcc -o test test.cu"
144+ compile_result .exit_code = 1
145+ compile_result .stderr = "compilation error"
146+ compile_result .stdout = ""
151147
152148 reporter = RunResultReport ()
153- _generate_compile_report (reporter , sample_compile_result )
149+ _generate_compile_report (reporter , compile_result )
154150
155151 # Should have compilation text + stderr log (no stdout log)
156152 assert len (reporter .data ) == 2
@@ -168,19 +164,20 @@ def test_generate_compile_report_no_stdout(sample_compile_result: CompileResult)
168164################################################
169165
170166
171- def test_short_fail_reason (sample_run_result : RunResult ):
172- sample_run_result .exit_code = consts .ExitCode .TIMEOUT_EXPIRED
173- assert _short_fail_reason (sample_run_result ) == " (timeout)"
167+ def test_short_fail_reason ():
168+ run_result = sample_run_result ()
169+ run_result .exit_code = consts .ExitCode .TIMEOUT_EXPIRED
170+ assert _short_fail_reason (run_result ) == " (timeout)"
174171
175- sample_run_result .exit_code = consts .ExitCode .CUDA_FAIL
176- assert _short_fail_reason (sample_run_result ) == " (cuda api error)"
172+ run_result .exit_code = consts .ExitCode .CUDA_FAIL
173+ assert _short_fail_reason (run_result ) == " (cuda api error)"
177174
178175 # VALIDATE_FAIL means unit tests failed, which will be reported differently
179- sample_run_result .exit_code = consts .ExitCode .VALIDATE_FAIL
180- assert _short_fail_reason (sample_run_result ) == ""
176+ run_result .exit_code = consts .ExitCode .VALIDATE_FAIL
177+ assert _short_fail_reason (run_result ) == ""
181178
182- sample_run_result .exit_code = 42
183- assert _short_fail_reason (sample_run_result ) == " (internal error 42)"
179+ run_result .exit_code = 42
180+ assert _short_fail_reason (run_result ) == " (internal error 42)"
184181
185182
186183def test_make_short_report_compilation_failed (sample_eval_result : EvalResult ):
@@ -191,13 +188,13 @@ def test_make_short_report_compilation_failed(sample_eval_result: EvalResult):
191188 assert result == ["❌ Compilation failed" ]
192189
193190
194- def test_make_short_report_full_success (sample_compile_result : CompileResult ):
191+ def test_make_short_report_full_success ():
195192 runs = {}
196193 for run_type in ["test" , "benchmark" , "profile" , "leaderboard" ]:
197194 runs [run_type ] = EvalResult (
198195 start = datetime .datetime .now () - datetime .timedelta (minutes = 5 ),
199196 end = datetime .datetime .now (),
200- compilation = sample_compile_result ,
197+ compilation = sample_compile_result () ,
201198 run = RunResult (
202199 success = True ,
203200 passed = True ,
@@ -239,8 +236,8 @@ def test_make_short_report_missing_components(sample_eval_result: EvalResult):
239236################################################
240237
241238
242- def test_make_test_log (sample_run_result : RunResult ):
243- log = make_test_log (sample_run_result )
239+ def test_make_test_log ():
240+ log = make_test_log (sample_run_result () )
244241 expected_lines = [
245242 "✅ Test addition" ,
246243 "> Addition works correctly" ,
@@ -316,8 +313,8 @@ def test_make_profile_log_no_data():
316313 assert log == "❗ Could not find any profiling data"
317314
318315
319- def test_generate_system_info (sample_system_info : SystemInfo ):
320- info = generate_system_info (sample_system_info )
316+ def test_generate_system_info ():
317+ info = generate_system_info (sample_system_info () )
321318
322319 expected_parts = ["NVIDIA RTX 4090" , "Intel i9-12900K" , "Linux-5.15.0" , "2.0.1+cu118" ]
323320
0 commit comments