@@ -34,7 +34,6 @@ def make_sure_other_containers_are_stopped(client: DockerClient, container_name:
3434# reraise = True
3535# )
3636def wait_for_container_to_be_ready (base_url , time_between_retries = 3 , max_retries = 30 ):
37-
3837 retries = 0
3938 error = None
4039
@@ -46,9 +45,7 @@ def wait_for_container_to_be_ready(base_url, time_between_retries=3, max_retries
4645 logging .info ("Container ready!" )
4746 return True
4847 else :
49- raise ConnectionError (
50- f"Couldn'start container, Error: { response .status_code } "
51- )
48+ raise ConnectionError (f"Couldn'start container, Error: { response .status_code } " )
5249 except Exception as exception :
5350 error = exception
5451 logging .warning (f"Container at { base_url } not ready, trying again..." )
@@ -62,7 +59,6 @@ def verify_task(
6259 # container: DockerClient,
6360 task : str ,
6461 port : int = 5000 ,
65- framework : str = "pytorch" ,
6662):
6763 BASE_URL = f"http://localhost:{ port } "
6864 logging .info (f"Base URL: { BASE_URL } " )
@@ -90,10 +86,7 @@ def verify_task(
9086 headers = {"content-type" : "audio/x-audio" },
9187 ).json ()
9288 elif task == "text-to-image" :
93- prediction = requests .post (
94- f"{ BASE_URL } " , json = input , headers = {"accept" : "image/png" }
95- ).content
96-
89+ prediction = requests .post (f"{ BASE_URL } " , json = input , headers = {"accept" : "image/png" }).content
9790 else :
9891 prediction = requests .post (f"{ BASE_URL } " , json = input ).json ()
9992
@@ -119,6 +112,8 @@ def verify_task(
119112@pytest .mark .parametrize (
120113 "task" ,
121114 [
115+ # transformers
116+ # TODO: "visual-question-answering" and "zero-shot-image-classification" not supported yet due to multimodality input
122117 "text-classification" ,
123118 "zero-shot-classification" ,
124119 "token-classification" ,
@@ -136,25 +131,22 @@ def verify_task(
136131 "image-segmentation" ,
137132 "table-question-answering" ,
138133 "conversational" ,
139- # TODO currently not supported due to multimodality input
140- # "visual-question-answering",
141- # "zero-shot-image-classification",
134+ "image-text-to-text" ,
135+ # sentence-transformers
142136 "sentence-similarity" ,
143137 "sentence-embeddings" ,
144138 "sentence-ranking" ,
145139 # diffusers
146140 "text-to-image" ,
147141 ],
148142)
149- def test_pt_container_remote_model (task ) -> None :
143+ def test_pt_container_remote_model (task : str ) -> None :
150144 container_name = f"integration-test-{ task } "
151145 container_image = f"starlette-transformers:{ DEVICE } "
152146 framework = "pytorch"
153147 model = task2model [task ][framework ]
154148 port = random .randint (5000 , 6000 )
155- device_request = (
156- [docker .types .DeviceRequest (count = - 1 , capabilities = [["gpu" ]])] if IS_GPU else []
157- )
149+ device_request = [docker .types .DeviceRequest (count = - 1 , capabilities = [["gpu" ]])] if IS_GPU else []
158150
159151 make_sure_other_containers_are_stopped (client , container_name )
160152 container = client .containers .run (
@@ -177,6 +169,8 @@ def test_pt_container_remote_model(task) -> None:
177169@pytest .mark .parametrize (
178170 "task" ,
179171 [
172+ # transformers
173+ # TODO: "visual-question-answering" and "zero-shot-image-classification" not supported yet due to multimodality input
180174 "text-classification" ,
181175 "zero-shot-classification" ,
182176 "token-classification" ,
@@ -194,29 +188,26 @@ def test_pt_container_remote_model(task) -> None:
194188 "image-segmentation" ,
195189 "table-question-answering" ,
196190 "conversational" ,
197- # TODO currently not supported due to multimodality input
198- # "visual-question-answering",
199- # "zero-shot-image-classification",
191+ "image-text-to-text" ,
192+ # sentence-transformers
200193 "sentence-similarity" ,
201194 "sentence-embeddings" ,
202195 "sentence-ranking" ,
203196 # diffusers
204197 "text-to-image" ,
205198 ],
206199)
207- def test_pt_container_local_model (task ) -> None :
200+ def test_pt_container_local_model (task : str ) -> None :
208201 container_name = f"integration-test-{ task } "
209202 container_image = f"starlette-transformers:{ DEVICE } "
210203 framework = "pytorch"
211204 model = task2model [task ][framework ]
212205 port = random .randint (5000 , 6000 )
213- device_request = (
214- [docker .types .DeviceRequest (count = - 1 , capabilities = [["gpu" ]])] if IS_GPU else []
215- )
206+ device_request = [docker .types .DeviceRequest (count = - 1 , capabilities = [["gpu" ]])] if IS_GPU else []
216207 make_sure_other_containers_are_stopped (client , container_name )
217208 with tempfile .TemporaryDirectory () as tmpdirname :
218209 # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py
219- _storage_dir = _load_repository_from_hf (model , tmpdirname , framework = "pytorch" )
210+ _load_repository_from_hf (model , tmpdirname , framework = "pytorch" )
220211 container = client .containers .run (
221212 container_image ,
222213 name = container_name ,
@@ -241,9 +232,7 @@ def test_pt_container_local_model(task) -> None:
241232def test_pt_container_custom_handler (repository_id ) -> None :
242233 container_name = "integration-test-custom"
243234 container_image = f"starlette-transformers:{ DEVICE } "
244- device_request = (
245- [docker .types .DeviceRequest (count = - 1 , capabilities = [["gpu" ]])] if IS_GPU else []
246- )
235+ device_request = [docker .types .DeviceRequest (count = - 1 , capabilities = [["gpu" ]])] if IS_GPU else []
247236 port = random .randint (5000 , 6000 )
248237
249238 make_sure_other_containers_are_stopped (client , container_name )
@@ -277,12 +266,10 @@ def test_pt_container_custom_handler(repository_id) -> None:
277266 "repository_id" ,
278267 ["philschmid/custom-pipeline-text-classification" ],
279268)
280- def test_pt_container_legacy_custom_pipeline (repository_id ) -> None :
269+ def test_pt_container_legacy_custom_pipeline (repository_id : str ) -> None :
281270 container_name = "integration-test-custom"
282271 container_image = f"starlette-transformers:{ DEVICE } "
283- device_request = (
284- [docker .types .DeviceRequest (count = - 1 , capabilities = [["gpu" ]])] if IS_GPU else []
285- )
272+ device_request = [docker .types .DeviceRequest (count = - 1 , capabilities = [["gpu" ]])] if IS_GPU else []
286273 port = random .randint (5000 , 6000 )
287274
288275 make_sure_other_containers_are_stopped (client , container_name )
@@ -345,9 +332,7 @@ def test_tf_container_remote_model(task) -> None:
345332 container_image = f"starlette-transformers:{ DEVICE } "
346333 framework = "tensorflow"
347334 model = task2model [task ][framework ]
348- device_request = (
349- [docker .types .DeviceRequest (count = - 1 , capabilities = [["gpu" ]])] if IS_GPU else []
350- )
335+ device_request = [docker .types .DeviceRequest (count = - 1 , capabilities = [["gpu" ]])] if IS_GPU else []
351336 if model is None :
352337 pytest .skip ("no supported TF model" )
353338 port = random .randint (5000 , 6000 )
@@ -401,9 +386,7 @@ def test_tf_container_local_model(task) -> None:
401386 container_image = f"starlette-transformers:{ DEVICE } "
402387 framework = "tensorflow"
403388 model = task2model [task ][framework ]
404- device_request = (
405- [docker .types .DeviceRequest (count = - 1 , capabilities = [["gpu" ]])] if IS_GPU else []
406- )
389+ device_request = [docker .types .DeviceRequest (count = - 1 , capabilities = [["gpu" ]])] if IS_GPU else []
407390 if model is None :
408391 pytest .skip ("no supported TF model" )
409392 port = random .randint (5000 , 6000 )
0 commit comments