88 HuggingFaceHandler ,
99 get_inference_handler_either_custom_or_default_handler ,
1010)
11- from huggingface_inference_toolkit .utils import (
11+ from huggingface_inference_toolkit .heavy_utils import (
1212 _is_gpu_available ,
13- _load_repository_from_hf ,
13+ load_repository_from_hf ,
1414)
1515
1616TASK = "text-classification"
@@ -29,7 +29,7 @@ def test_pt_get_device() -> None:
2929
3030 with tempfile .TemporaryDirectory () as tmpdirname :
3131 # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py
32- storage_dir = _load_repository_from_hf (MODEL , tmpdirname , framework = "pytorch" )
32+ storage_dir = load_repository_from_hf (MODEL , tmpdirname , framework = "pytorch" )
3333 h = HuggingFaceHandler (model_dir = str (storage_dir ), task = TASK )
3434 if torch .cuda .is_available ():
3535 assert h .pipeline .model .device == torch .device (type = "cuda" , index = 0 )
@@ -41,7 +41,7 @@ def test_pt_get_device() -> None:
4141def test_pt_predict_call (input_data : Dict [str , str ]) -> None :
4242 with tempfile .TemporaryDirectory () as tmpdirname :
4343 # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py
44- storage_dir = _load_repository_from_hf (MODEL , tmpdirname , framework = "pytorch" )
44+ storage_dir = load_repository_from_hf (MODEL , tmpdirname , framework = "pytorch" )
4545 h = HuggingFaceHandler (model_dir = str (storage_dir ), task = TASK )
4646
4747 prediction = h (input_data )
@@ -52,7 +52,7 @@ def test_pt_predict_call(input_data: Dict[str, str]) -> None:
5252@require_torch
5353def test_pt_custom_pipeline (input_data : Dict [str , str ]) -> None :
5454 with tempfile .TemporaryDirectory () as tmpdirname :
55- storage_dir = _load_repository_from_hf (
55+ storage_dir = load_repository_from_hf (
5656 "philschmid/custom-pipeline-text-classification" ,
5757 tmpdirname ,
5858 framework = "pytorch" ,
@@ -64,7 +64,7 @@ def test_pt_custom_pipeline(input_data: Dict[str, str]) -> None:
6464@require_torch
6565def test_pt_sentence_transformers_pipeline (input_data : Dict [str , str ]) -> None :
6666 with tempfile .TemporaryDirectory () as tmpdirname :
67- storage_dir = _load_repository_from_hf (
67+ storage_dir = load_repository_from_hf (
6868 "sentence-transformers/all-MiniLM-L6-v2" , tmpdirname , framework = "pytorch"
6969 )
7070 h = get_inference_handler_either_custom_or_default_handler (str (storage_dir ), task = "sentence-embeddings" )
@@ -76,7 +76,7 @@ def test_pt_sentence_transformers_pipeline(input_data: Dict[str, str]) -> None:
7676def test_tf_get_device ():
7777 with tempfile .TemporaryDirectory () as tmpdirname :
7878 # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py
79- storage_dir = _load_repository_from_hf (MODEL , tmpdirname , framework = "tensorflow" )
79+ storage_dir = load_repository_from_hf (MODEL , tmpdirname , framework = "tensorflow" )
8080 h = HuggingFaceHandler (model_dir = str (storage_dir ), task = TASK )
8181 if _is_gpu_available ():
8282 assert h .pipeline .device == 0
@@ -88,7 +88,7 @@ def test_tf_get_device():
8888def test_tf_predict_call (input_data : Dict [str , str ]) -> None :
8989 with tempfile .TemporaryDirectory () as tmpdirname :
9090 # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py
91- storage_dir = _load_repository_from_hf (MODEL , tmpdirname , framework = "tensorflow" )
91+ storage_dir = load_repository_from_hf (MODEL , tmpdirname , framework = "tensorflow" )
9292 handler = HuggingFaceHandler (model_dir = str (storage_dir ), task = TASK , framework = "tf" )
9393
9494 prediction = handler (input_data )
@@ -99,7 +99,7 @@ def test_tf_predict_call(input_data: Dict[str, str]) -> None:
9999@require_tf
100100def test_tf_custom_pipeline (input_data : Dict [str , str ]) -> None :
101101 with tempfile .TemporaryDirectory () as tmpdirname :
102- storage_dir = _load_repository_from_hf (
102+ storage_dir = load_repository_from_hf (
103103 "philschmid/custom-pipeline-text-classification" ,
104104 tmpdirname ,
105105 framework = "tensorflow" ,
@@ -112,7 +112,7 @@ def test_tf_custom_pipeline(input_data: Dict[str, str]) -> None:
112112def test_tf_sentence_transformers_pipeline ():
113113 # TODO should fail! because TF is not supported yet
114114 with tempfile .TemporaryDirectory () as tmpdirname :
115- storage_dir = _load_repository_from_hf (
115+ storage_dir = load_repository_from_hf (
116116 "sentence-transformers/all-MiniLM-L6-v2" , tmpdirname , framework = "tensorflow"
117117 )
118118 with pytest .raises (Exception ) as _exc_info :
0 commit comments