     UserPromptPart,
     VideoUrl,
 )
-from ..profiles import ModelProfile
+from ..profiles import ModelProfile, ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition

@@ -121,20 +121,26 @@ def __init__(
         model_name: str,
         *,
         provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
+        profile: ModelProfileSpec | None = None,
+        settings: ModelSettings | None = None,
     ):
         """Initialize a Hugging Face model.

         Args:
             model_name: The name of the Model to use. You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
             provider: The provider to use for Hugging Face Inference Providers. Can be either the string 'huggingface' or an
                 instance of `Provider[AsyncInferenceClient]`. If not provided, the other parameters will be used.
+            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
+            settings: Model-specific settings that will be used as defaults for this model.
         """
         self._model_name = model_name
         self._provider = provider
         if isinstance(provider, str):
             provider = infer_provider(provider)
         self.client = provider.client

+        super().__init__(settings=settings, profile=profile or provider.model_profile)
+
     async def request(
         self,
         messages: list[ModelMessage],
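
The two new keyword arguments give callers a construction-time hook: `profile` overrides the provider-inferred profile, and `settings` becomes the per-model default `ModelSettings`, both forwarded to the base class via the new `super().__init__` call. A minimal usage sketch, assuming the class is exported as `HuggingFaceModel` from `pydantic_ai.models.huggingface` (the model name and settings values here are illustrative, not from this diff):

    from pydantic_ai.models.huggingface import HuggingFaceModel
    from pydantic_ai.settings import ModelSettings

    # `settings` becomes the default for every request made through this
    # instance; with `profile` omitted, the model falls back to
    # `provider.model_profile`, as in the diff above.
    model = HuggingFaceModel(
        'Qwen/Qwen2.5-72B-Instruct',
        settings=ModelSettings(temperature=0.0, max_tokens=512),
    )
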
@@ -444,11 +450,12 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:

            # Handle the text part of the response
            content = choice.delta.content
-            if content:
+            if content is not None:
                maybe_event = self._parts_manager.handle_text_delta(
                    vendor_part_id='content',
                    content=content,
                    thinking_tags=self._model_profile.thinking_tags,
+                    ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
                )
                if maybe_event is not None:  # pragma: no branch
                    yield maybe_event
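
Note the semantics of the guard change: `if content:` also skipped empty-string deltas, whereas `if content is not None:` forwards them to `handle_text_delta` and only skips a missing delta. A standalone sketch of the difference (illustrative, not the streaming code itself):

    deltas = [None, '', 'Hello', ' world']

    # Truthiness drops both None and the empty chunk ...
    assert [d for d in deltas if d] == ['Hello', ' world']
    # ... while the explicit None check keeps the empty chunk.
    assert [d for d in deltas if d is not None] == ['', 'Hello', ' world']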