@@ -190,36 +190,9 @@ def pytorch_inference_py_version(pytorch_inference_version, request):
190190 return "py3"
191191
192192
193- def _huggingface_base_fm_version (huggingface_vesion , base_fw ):
194- config = image_uris .config_for_framework ("huggingface" )
195- training_config = config .get ("training" )
196- original_version = huggingface_vesion
197- if "version_aliases" in training_config :
198- huggingface_vesion = training_config .get ("version_aliases" ).get (
199- huggingface_vesion , huggingface_vesion
200- )
201- version_config = training_config .get ("versions" ).get (huggingface_vesion )
202- for key in list (version_config .keys ()):
203- if key .startswith (base_fw ):
204- base_fw_version = key [len (base_fw ) :]
205- if len (original_version .split ("." )) == 2 :
206- base_fw_version = "." .join (base_fw_version .split ("." )[:- 1 ])
207- return base_fw_version
208-
209-
210193@pytest .fixture (scope = "module" )
211194def huggingface_pytorch_version (huggingface_training_version ):
212- return _huggingface_base_fm_version (huggingface_training_version , "pytorch" )
213-
214-
215- @pytest .fixture (scope = "module" )
216- def huggingface_pytorch_latest_version (huggingface_training_latest_version ):
217- return _huggingface_base_fm_version (huggingface_training_latest_version , "pytorch" )
218-
219-
220- @pytest .fixture (scope = "module" )
221- def huggingface_tensorflow_latest_version (huggingface_training_latest_version ):
222- return _huggingface_base_fm_version (huggingface_training_latest_version , "tensorflow" )
195+ return _huggingface_base_fm_version (huggingface_training_version , "pytorch" )[0 ]
223196
224197
225198@pytest .fixture (scope = "module" )
@@ -395,6 +368,32 @@ def _generate_all_framework_version_fixtures(metafunc):
395368 )
396369
397370
371+ def _huggingface_base_fm_version (huggingface_vesion , base_fw ):
372+ config = image_uris .config_for_framework ("huggingface" )
373+ training_config = config .get ("training" )
374+ original_version = huggingface_vesion
375+ if "version_aliases" in training_config :
376+ huggingface_vesion = training_config .get ("version_aliases" ).get (
377+ huggingface_vesion , huggingface_vesion
378+ )
379+ version_config = training_config .get ("versions" ).get (huggingface_vesion )
380+ versions = list ()
381+ for key in list (version_config .keys ()):
382+ if key .startswith (base_fw ):
383+ base_fw_version = key [len (base_fw ) :]
384+ if len (original_version .split ("." )) == 2 :
385+ base_fw_version = "." .join (base_fw_version .split ("." )[:- 1 ])
386+ versions .append (base_fw_version )
387+ return versions
388+
389+
390+ def _generate_huggingface_base_fw_latest_versions (metafunc , huggingface_version , base_fw ):
391+ versions = _huggingface_base_fm_version (huggingface_version , base_fw )
392+ fixture_name = f"huggingface_{ base_fw } _latest_version"
393+ if fixture_name in metafunc .fixturenames :
394+ metafunc .parametrize (fixture_name , versions , scope = "session" )
395+
396+
398397def _parametrize_framework_version_fixtures (metafunc , fixture_prefix , config ):
399398 fixture_name = "{}_version" .format (fixture_prefix )
400399 if fixture_name in metafunc .fixturenames :
@@ -407,6 +406,10 @@ def _parametrize_framework_version_fixtures(metafunc, fixture_prefix, config):
407406 if fixture_name in metafunc .fixturenames :
408407 metafunc .parametrize (fixture_name , (latest_version ,), scope = "session" )
409408
409+ if "huggingface" in fixture_prefix :
410+ _generate_huggingface_base_fw_latest_versions (metafunc , latest_version , "pytorch" )
411+ _generate_huggingface_base_fw_latest_versions (metafunc , latest_version , "tensorflow" )
412+
410413 fixture_name = "{}_latest_py_version" .format (fixture_prefix )
411414 if fixture_name in metafunc .fixturenames :
412415 config = config ["versions" ]
0 commit comments