@@ -190,7 +190,7 @@ def pytorch_inference_py_version(pytorch_inference_version, request):
190190 return "py3"
191191
192192
193- def _huggingface_pytorch_version (huggingface_vesion ):
193+ def _huggingface_base_fm_version (huggingface_vesion , base_fw ):
194194 config = image_uris .config_for_framework ("huggingface" )
195195 training_config = config .get ("training" )
196196 original_version = huggingface_vesion
@@ -200,21 +200,26 @@ def _huggingface_pytorch_version(huggingface_vesion):
200200 )
201201 version_config = training_config .get ("versions" ).get (huggingface_vesion )
202202 for key in list (version_config .keys ()):
203- if key .startswith ("pytorch" ):
204- pt_version = key [7 :]
203+ if key .startswith (base_fw ):
204+ base_fw_version = key [len ( base_fw ) :]
205205 if len (original_version .split ("." )) == 2 :
206- pt_version = "." .join (pt_version .split ("." )[:- 1 ])
207- return pt_version
206+ base_fw_version = "." .join (base_fw_version .split ("." )[:- 1 ])
207+ return base_fw_version
208208
209209
210210@pytest .fixture (scope = "module" )
211211def huggingface_pytorch_version (huggingface_training_version ):
212- return _huggingface_pytorch_version (huggingface_training_version )
212+ return _huggingface_base_fm_version (huggingface_training_version , "pytorch" )
213213
214214
215215@pytest .fixture (scope = "module" )
216216def huggingface_pytorch_latest_version (huggingface_training_latest_version ):
217- return _huggingface_pytorch_version (huggingface_training_latest_version )
217+ return _huggingface_base_fm_version (huggingface_training_latest_version , "pytorch" )
218+
219+
220+ @pytest .fixture (scope = "module" )
221+ def huggingface_tensorflow_latest_version (huggingface_training_latest_version ):
222+ return _huggingface_base_fm_version (huggingface_training_latest_version , "tensorflow" )
218223
219224
220225@pytest .fixture (scope = "module" )
0 commit comments