Skip to content

Commit 0c521e3

Browse files
test(chat): add xfail regression tests for Hugging Face init_chat_model
1 parent a89c549 commit 0c521e3

File tree

1 file changed

+162
-0
lines changed

1 file changed

+162
-0
lines changed
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
import sys
2+
import types
3+
from types import SimpleNamespace
4+
from importlib import util as import_util
5+
6+
import pytest
7+
8+
from langchain.chat_models import init_chat_model
9+
10+
pytestmark = pytest.mark.requires("langchain_huggingface", "transformers")
11+
12+
13+
@pytest.fixture()
def hf_fakes(monkeypatch: pytest.MonkeyPatch) -> SimpleNamespace:
    """Install fake modules for `langchain_huggingface` and `transformers` and
    capture their call arguments for assertions.

    Returns:
        SimpleNamespace with two attributes:
        - ``pipeline_calls``: list of ``(task, kwargs)`` tuples recorded by the
          fake ``transformers.pipeline``.
        - ``init_calls``: list of ``{"llm": ..., "kwargs": ...}`` dicts recorded
          by the fake ``ChatHuggingFace`` constructor.
    """
    pipeline_calls: list[tuple[str, dict]] = []
    init_calls: list[dict] = []

    # Fake transformers.pipeline: record the call, return a cheap sentinel.
    def fake_pipeline(task: str, **kwargs):
        pipeline_calls.append((task, dict(kwargs)))
        # A simple stand-in object for the HF pipeline
        return SimpleNamespace(_kind="dummy_hf_pipeline")

    transformers_mod = types.ModuleType("transformers")
    transformers_mod.pipeline = fake_pipeline
    monkeypatch.setitem(sys.modules, "transformers", transformers_mod)

    # Fake langchain_huggingface.ChatHuggingFace that REQUIRES `llm`
    class FakeChatHuggingFace:
        def __init__(self, *, llm, **kwargs):
            init_calls.append({"llm": llm, "kwargs": dict(kwargs)})
            # minimal instance; tests only assert on ctor args
            self._llm = llm
            self._kwargs = kwargs

    # Build full package path: langchain_huggingface.chat_models.huggingface
    hf_pkg = types.ModuleType("langchain_huggingface")
    hf_pkg.__path__ = []  # mark as package

    hf_chat_models_pkg = types.ModuleType("langchain_huggingface.chat_models")
    hf_chat_models_pkg.__path__ = []  # mark as package

    hf_chat_huggingface_mod = types.ModuleType(
        "langchain_huggingface.chat_models.huggingface"
    )
    hf_chat_huggingface_mod.ChatHuggingFace = FakeChatHuggingFace

    # Optional: expose at package root for compatibility with top-level imports
    hf_pkg.ChatHuggingFace = FakeChatHuggingFace

    monkeypatch.setitem(sys.modules, "langchain_huggingface", hf_pkg)
    monkeypatch.setitem(
        sys.modules, "langchain_huggingface.chat_models", hf_chat_models_pkg
    )
    monkeypatch.setitem(
        sys.modules,
        "langchain_huggingface.chat_models.huggingface",
        hf_chat_huggingface_mod,
    )

    # Ensure _check_pkg sees both packages as installed
    orig_find_spec = import_util.find_spec

    # FIX: the real importlib.util.find_spec signature is (name, package=None).
    # The previous fake accepted only `name`, so any unrelated find_spec call
    # made with a `package` argument while the patch was active raised a
    # TypeError from the fake itself rather than behaving normally.
    def fake_find_spec(name: str, package=None):
        if name in {
            "transformers",
            "langchain_huggingface",
            "langchain_huggingface.chat_models",
            "langchain_huggingface.chat_models.huggingface",
        }:
            # Any truthy object satisfies "package is installed" checks.
            return object()
        return orig_find_spec(name, package)

    monkeypatch.setattr("importlib.util.find_spec", fake_find_spec)

    return SimpleNamespace(pipeline_calls=pipeline_calls, init_calls=init_calls)
77+
78+
79+
def _last_pipeline_kwargs(hf_fakes: SimpleNamespace) -> dict:
80+
assert hf_fakes.pipeline_calls, "transformers.pipeline was not called"
81+
_, kwargs = hf_fakes.pipeline_calls[-1]
82+
return kwargs
83+
84+
85+
def _last_chat_kwargs(hf_fakes: SimpleNamespace) -> dict:
86+
assert hf_fakes.init_calls, "ChatHuggingFace was not constructed"
87+
return hf_fakes.init_calls[-1]["kwargs"]
88+
89+
90+
@pytest.mark.xfail(
    reason="Pending fix for huggingface init (#28226 / #33167) — currently passes model_id to ChatHuggingFace",
    raises=TypeError,
)
def test_hf_basic_wraps_pipeline(hf_fakes: SimpleNamespace) -> None:
    """Inline-provider init should build a HF pipeline and wrap it in ChatHuggingFace."""
    # Provider given inline via the "provider:model" shorthand.
    chat = init_chat_model(
        "huggingface:microsoft/Phi-3-mini-4k-instruct",
        task="text-generation",
        temperature=0,
    )
    # Some wrapped object must come back; no particular type is required.
    assert chat is not None

    # Make failure modes explicit
    assert hf_fakes.pipeline_calls, "Expected transformers.pipeline to be called"
    assert hf_fakes.init_calls, "Expected ChatHuggingFace to be constructed"

    # The pipeline must receive the bare model id (task value left unchecked).
    pipeline_kwargs = _last_pipeline_kwargs(hf_fakes)
    assert pipeline_kwargs["model"] == "microsoft/Phi-3-mini-4k-instruct"

    # ChatHuggingFace must be constructed with the fake pipeline as `llm`.
    last_init = hf_fakes.init_calls[-1]
    assert "llm" in last_init
    assert last_init["llm"]._kind == "dummy_hf_pipeline"
115+
116+
117+
@pytest.mark.xfail(
    reason="Pending fix for huggingface init (#28226 / #33167)",
    raises=TypeError,
)
def test_hf_max_tokens_translated_to_max_new_tokens(
    hf_fakes: SimpleNamespace,
) -> None:
    """`max_tokens` should be renamed to HF's `max_new_tokens` for the pipeline."""
    init_chat_model(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        model_provider="huggingface",
        task="text-generation",
        max_tokens=42,
    )
    assert hf_fakes.pipeline_calls, "Expected transformers.pipeline to be called"
    assert hf_fakes.init_calls, "Expected ChatHuggingFace to be constructed"
    pipeline_kwargs = _last_pipeline_kwargs(hf_fakes)
    # The translated name arrives with the original value...
    assert pipeline_kwargs.get("max_new_tokens") == 42
    # ...and the pre-translation name must not leak into pipeline kwargs.
    assert "max_tokens" not in pipeline_kwargs
136+
137+
138+
@pytest.mark.xfail(
    reason="Pending fix for huggingface init (#28226 / #33167)",
    raises=TypeError,
)
def test_hf_timeout_and_max_retries_pass_through_to_chat_wrapper(
    hf_fakes: SimpleNamespace,
) -> None:
    """Control args (`timeout`, `max_retries`) belong on the chat wrapper only."""
    init_chat_model(
        model="microsoft/Phi-3-mini-4k-instruct",
        model_provider="huggingface",
        task="text-generation",
        temperature=0.1,
        timeout=7,
        max_retries=3,
    )
    assert hf_fakes.pipeline_calls, "Expected transformers.pipeline to be called"
    assert hf_fakes.init_calls, "Expected ChatHuggingFace to be constructed"

    # The wrapper constructor should have received both control args...
    wrapper_kwargs = _last_chat_kwargs(hf_fakes)
    assert wrapper_kwargs.get("timeout") == 7
    assert wrapper_kwargs.get("max_retries") == 3

    # ...while transformers.pipeline should have received neither.
    pipeline_kwargs = _last_pipeline_kwargs(hf_fakes)
    assert "timeout" not in pipeline_kwargs
    assert "max_retries" not in pipeline_kwargs

0 commit comments

Comments
 (0)