@@ -319,13 +319,16 @@ def test_empty_model_size(max_depth):
319319
320320
321321@pytest .mark .parametrize (
322- "accelerator" ,
322+ ( "accelerator" , "precision" ) ,
323323 [
324- pytest .param ("gpu" , marks = RunIf (min_cuda_gpus = 1 )),
325- pytest .param ("mps" , marks = RunIf (mps = True )),
324+ pytest .param ("gpu" , "16-true" , marks = RunIf (min_cuda_gpus = 1 )),
325+ pytest .param ("gpu" , "32-true" , marks = RunIf (min_cuda_gpus = 1 )),
326+ pytest .param ("gpu" , "64-true" , marks = RunIf (min_cuda_gpus = 1 )),
327+ pytest .param ("mps" , "16-true" , marks = RunIf (mps = True )),
328+ pytest .param ("mps" , "32-true" , marks = RunIf (mps = True )),
329+ # Note: "64-true" with "mps" is skipped because MPS does not support float64
326330 ],
327331)
328- @pytest .mark .parametrize ("precision" , ["16-true" , "32-true" , "64-true" ])
329332def test_model_size_precision (tmp_path , accelerator , precision ):
330333 """Test model size for different precision types."""
331334 model = PreCalculatedModel (precision = int (precision .split ("-" )[0 ]))
0 commit comments