|
1 | 1 | import base64 |
2 | 2 | import logging |
3 | 3 | from typing import Dict |
4 | | -from unittest.mock import patch |
| 4 | +from unittest.mock import MagicMock, patch |
5 | 5 |
|
6 | 6 | import pytest |
7 | 7 | from pytest import LogCaptureFixture |
|
33 | 33 | from huggingface_hub.inference._providers.hf_inference import ( |
34 | 34 | HFInferenceBinaryInputTask, |
35 | 35 | HFInferenceConversational, |
| 36 | + HFInferenceFeatureExtractionTask, |
36 | 37 | HFInferenceTask, |
37 | 38 | ) |
38 | 39 | from huggingface_hub.inference._providers.hyperbolic import HyperbolicTextGenerationTask, HyperbolicTextToImageTask |
@@ -654,6 +655,15 @@ def test_prepare_payload_as_dict_conversational(self, mapped_model, parameters, |
654 | 655 | assert payload["model"] == expected_model |
655 | 656 | assert payload["messages"] == messages |
656 | 657 |
|
| 658 | + def test_prepare_payload_feature_extraction(self): |
| 659 | + helper = HFInferenceFeatureExtractionTask() |
| 660 | + payload = helper._prepare_payload_as_dict( |
| 661 | + inputs="This is a test sentence.", |
| 662 | + parameters={"truncate": True}, |
| 663 | + provider_mapping_info=MagicMock(), |
| 664 | + ) |
| 665 | + assert payload == {"inputs": "This is a test sentence.", "truncate": True} # not under "parameters" |
| 666 | + |
657 | 667 | @pytest.mark.parametrize( |
658 | 668 | "pipeline_tag,tags,task,should_raise", |
659 | 669 | [ |
|
0 commit comments