8686 "huggingface_training_compiler" ,
8787)
8888
89+ PYTORCH_RENEWED_GPU = "ml.g4dn.xlarge"
90+
8991
9092def pytest_addoption (parser ):
9193 parser .addoption ("--sagemaker-client-config" , action = "store" , default = None )
@@ -221,22 +223,26 @@ def mxnet_eia_latest_py_version():
221223
222224@pytest .fixture (scope = "module" , params = ["py2" , "py3" ])
223225def pytorch_training_py_version (pytorch_training_version , request ):
224- if Version (pytorch_training_version ) < Version ("1.5.0 " ):
225- return request . param
226+ if Version (pytorch_training_version ) >= Version ("1.13 " ):
227+ return "py39"
226228 elif Version (pytorch_training_version ) >= Version ("1.9" ):
227229 return "py38"
228- else :
230+ elif Version ( pytorch_training_version ) >= Version ( "1.5.0" ) :
229231 return "py3"
232+ else :
233+ return request .param
230234
231235
232236@pytest .fixture (scope = "module" , params = ["py2" , "py3" ])
233237def pytorch_inference_py_version (pytorch_inference_version , request ):
234- if Version (pytorch_inference_version ) < Version ("1.4.0 " ):
235- return request . param
238+ if Version (pytorch_inference_version ) >= Version ("1.13 " ):
239+ return "py39"
236240 elif Version (pytorch_inference_version ) >= Version ("1.9" ):
237241 return "py38"
238- else :
242+ elif Version ( pytorch_inference_version ) >= Version ( "1.4.0" ) :
239243 return "py3"
244+ else :
245+ return request .param
240246
241247
242248@pytest .fixture (scope = "module" )
@@ -252,9 +258,13 @@ def huggingface_pytorch_training_py_version(huggingface_pytorch_training_version
252258
253259
254260@pytest .fixture (scope = "module" )
255- def huggingface_training_compiler_pytorch_version (huggingface_training_compiler_version ):
261+ def huggingface_training_compiler_pytorch_version (
262+ huggingface_training_compiler_version ,
263+ ):
256264 versions = _huggingface_base_fm_version (
257- huggingface_training_compiler_version , "pytorch" , "huggingface_training_compiler"
265+ huggingface_training_compiler_version ,
266+ "pytorch" ,
267+ "huggingface_training_compiler" ,
258268 )
259269 if not versions :
260270 pytest .skip (
@@ -265,9 +275,13 @@ def huggingface_training_compiler_pytorch_version(huggingface_training_compiler_
265275
266276
267277@pytest .fixture (scope = "module" )
268- def huggingface_training_compiler_tensorflow_version (huggingface_training_compiler_version ):
278+ def huggingface_training_compiler_tensorflow_version (
279+ huggingface_training_compiler_version ,
280+ ):
269281 versions = _huggingface_base_fm_version (
270- huggingface_training_compiler_version , "tensorflow" , "huggingface_training_compiler"
282+ huggingface_training_compiler_version ,
283+ "tensorflow" ,
284+ "huggingface_training_compiler" ,
271285 )
272286 if not versions :
273287 pytest .skip (
@@ -289,19 +303,25 @@ def huggingface_training_compiler_tensorflow_py_version(
289303
290304
291305@pytest .fixture (scope = "module" )
292- def huggingface_training_compiler_pytorch_py_version (huggingface_training_compiler_pytorch_version ):
306+ def huggingface_training_compiler_pytorch_py_version (
307+ huggingface_training_compiler_pytorch_version ,
308+ ):
293309 return "py38"
294310
295311
296312@pytest .fixture (scope = "module" )
297- def huggingface_pytorch_latest_training_py_version (huggingface_training_pytorch_latest_version ):
313+ def huggingface_pytorch_latest_training_py_version (
314+ huggingface_training_pytorch_latest_version ,
315+ ):
298316 return (
299317 "py38" if Version (huggingface_training_pytorch_latest_version ) >= Version ("1.9" ) else "py36"
300318 )
301319
302320
303321@pytest .fixture (scope = "module" )
304- def huggingface_pytorch_latest_inference_py_version (huggingface_inference_pytorch_latest_version ):
322+ def huggingface_pytorch_latest_inference_py_version (
323+ huggingface_inference_pytorch_latest_version ,
324+ ):
305325 return (
306326 "py38"
307327 if Version (huggingface_inference_pytorch_latest_version ) >= Version ("1.9" )
@@ -477,7 +497,8 @@ def pytorch_ddp_py_version():
477497
478498
479499@pytest .fixture (
480- scope = "module" , params = ["1.10" , "1.10.0" , "1.10.2" , "1.11" , "1.11.0" , "1.12" , "1.12.0" ]
500+ scope = "module" ,
501+ params = ["1.10" , "1.10.0" , "1.10.2" , "1.11" , "1.11.0" , "1.12" , "1.12.0" ],
481502)
482503def pytorch_ddp_framework_version (request ):
483504 return request .param
@@ -511,6 +532,23 @@ def gpu_instance_type(sagemaker_session, request):
511532 return "ml.p3.2xlarge"
512533
513534
535+ @pytest .fixture ()
536+ def gpu_pytorch_instance_type (sagemaker_session , request ):
537+ if "pytorch_inference_version" in request .fixturenames :
538+ fw_version = request .getfixturevalue ("pytorch_inference_version" )
539+ else :
540+ fw_version = request .param
541+
542+ region = sagemaker_session .boto_session .region_name
543+ if region in NO_P3_REGIONS :
544+ if Version (fw_version ) >= Version ("1.13" ):
545+ return PYTORCH_RENEWED_GPU
546+ else :
547+ return "ml.p2.xlarge"
548+ else :
549+ return "ml.p3.2xlarge"
550+
551+
514552@pytest .fixture (scope = "session" )
515553def gpu_instance_type_list (sagemaker_session , request ):
516554 region = sagemaker_session .boto_session .region_name
0 commit comments