Skip to content

Commit 86943a9

Browse files
test(huggingface): ruff fixes; remove requires marker for extended tests
1 parent 7791e8b commit 86943a9

File tree

1 file changed

+25
-14
lines changed

1 file changed

+25
-14
lines changed

libs/langchain/tests/unit_tests/chat_models/test_init_chat_model_hf.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,26 @@
11
import sys
22
import types
3-
from types import SimpleNamespace
43
from importlib import util as import_util
4+
from types import SimpleNamespace
5+
from typing import Any, Optional
56

67
import pytest
78

89
from langchain.chat_models import init_chat_model
910

10-
11-
@pytest.fixture()
11+
git add libs/langchain/tests/unit_tests/chat_models/test_init_chat_model_hf.py
12+
@pytest.fixture
1213
def hf_fakes(monkeypatch: pytest.MonkeyPatch) -> SimpleNamespace:
13-
"""Install fake modules for `langchain_huggingface` and `transformers` and
14-
capture their call arguments for assertions."""
15-
pipeline_calls: list[tuple[str, dict]] = []
16-
init_calls: list[dict] = []
14+
"""
15+
Install fake modules for `langchain_huggingface` and `transformers` and
16+
capture their call arguments for assertions.
17+
18+
"""
19+
pipeline_calls: list[tuple[str, dict[str, Any]]] = []
20+
init_calls: list[dict[str, Any]] = []
1721

1822
# Fake transformers.pipeline
19-
def fake_pipeline(task: str, **kwargs):
23+
def fake_pipeline(task: str, **kwargs: Any) -> SimpleNamespace:
2024
pipeline_calls.append((task, dict(kwargs)))
2125
# A simple stand-in object for the HF pipeline
2226
return SimpleNamespace(_kind="dummy_hf_pipeline")
@@ -27,7 +31,7 @@ def fake_pipeline(task: str, **kwargs):
2731

2832
# Fake langchain_huggingface.ChatHuggingFace that REQUIRES `llm`
2933
class FakeChatHuggingFace:
30-
def __init__(self, *, llm, **kwargs):
34+
def __init__(self, *, llm: Any, **kwargs: Any) -> None:
3135
init_calls.append({"llm": llm, "kwargs": dict(kwargs)})
3236
# minimal instance; tests only assert on ctor args
3337
self._llm = llm
@@ -49,7 +53,11 @@ def __init__(self, *, llm, **kwargs):
4953
hf_pkg.ChatHuggingFace = FakeChatHuggingFace
5054

5155
monkeypatch.setitem(sys.modules, "langchain_huggingface", hf_pkg)
52-
monkeypatch.setitem(sys.modules, "langchain_huggingface.chat_models", hf_chat_models_pkg)
56+
monkeypatch.setitem(
57+
sys.modules,
58+
"langchain_huggingface.chat_models",
59+
hf_chat_models_pkg,
60+
)
5361
monkeypatch.setitem(
5462
sys.modules,
5563
"langchain_huggingface.chat_models.huggingface",
@@ -59,7 +67,7 @@ def __init__(self, *, llm, **kwargs):
5967
# Ensure _check_pkg sees both packages as installed
6068
orig_find_spec = import_util.find_spec
6169

62-
def fake_find_spec(name: str):
70+
def fake_find_spec(name: str) -> Optional[object]:
6371
if name in {
6472
"transformers",
6573
"langchain_huggingface",
@@ -74,19 +82,22 @@ def fake_find_spec(name: str):
7482
return SimpleNamespace(pipeline_calls=pipeline_calls, init_calls=init_calls)
7583

7684

77-
def _last_pipeline_kwargs(hf_fakes: SimpleNamespace) -> dict:
85+
def _last_pipeline_kwargs(hf_fakes: SimpleNamespace) -> dict[str, Any]:
7886
assert hf_fakes.pipeline_calls, "transformers.pipeline was not called"
7987
_, kwargs = hf_fakes.pipeline_calls[-1]
8088
return kwargs
8189

8290

83-
def _last_chat_kwargs(hf_fakes: SimpleNamespace) -> dict:
91+
def _last_chat_kwargs(hf_fakes: SimpleNamespace) -> dict[str, Any]:
8492
assert hf_fakes.init_calls, "ChatHuggingFace was not constructed"
8593
return hf_fakes.init_calls[-1]["kwargs"]
8694

8795

8896
@pytest.mark.xfail(
89-
reason="Pending fix for huggingface init (#28226 / #33167) — currently passes model_id to ChatHuggingFace",
97+
reason=(
98+
"Pending fix for huggingface init (#28226 / #33167) — currently passes "
99+
"model_id to ChatHuggingFace"
100+
),
90101
raises=TypeError,
91102
)
92103
def test_hf_basic_wraps_pipeline(hf_fakes: SimpleNamespace) -> None:

0 commit comments

Comments
 (0)