@@ -318,13 +318,16 @@ def test_empty_model_size(max_depth):
318318
319319
320320@pytest .mark .parametrize (
321- "accelerator" ,
321+ ( "accelerator" , "precision" ) ,
322322 [
323- pytest .param ("gpu" , marks = RunIf (min_cuda_gpus = 1 )),
324- pytest .param ("mps" , marks = RunIf (mps = True )),
323+ pytest .param ("gpu" , "16-true" , marks = RunIf (min_cuda_gpus = 1 )),
324+ pytest .param ("gpu" , "32-true" , marks = RunIf (min_cuda_gpus = 1 )),
325+ pytest .param ("gpu" , "64-true" , marks = RunIf (min_cuda_gpus = 1 )),
326+ pytest .param ("mps" , "16-true" , marks = RunIf (mps = True )),
327+ pytest .param ("mps" , "32-true" , marks = RunIf (mps = True )),
328+ # Note: "64-true" with "mps" is skipped because MPS does not support float64
325329 ],
326330)
327- @pytest .mark .parametrize ("precision" , ["16-true" , "32-true" , "64-true" ])
328331def test_model_size_precision (tmp_path , accelerator , precision ):
329332 """Test model size for different precision types."""
330333 model = PreCalculatedModel (precision = int (precision .split ("-" )[0 ]))
0 commit comments