1
1
import sys
2
2
import types
3
- from types import SimpleNamespace
4
3
from importlib import util as import_util
4
+ from types import SimpleNamespace
5
+ from typing import Any , Optional
5
6
6
7
import pytest
7
8
8
9
from langchain .chat_models import init_chat_model
9
10
10
-
11
- @pytest .fixture ()
11
+ git add libs / langchain / tests / unit_tests / chat_models / test_init_chat_model_hf . py
12
+ @pytest .fixture
12
13
def hf_fakes (monkeypatch : pytest .MonkeyPatch ) -> SimpleNamespace :
13
- """Install fake modules for `langchain_huggingface` and `transformers` and
14
- capture their call arguments for assertions."""
15
- pipeline_calls : list [tuple [str , dict ]] = []
16
- init_calls : list [dict ] = []
14
+ """
15
+ Install fake modules for `langchain_huggingface` and `transformers` and
16
+ capture their call arguments for assertions.
17
+
18
+ """
19
+ pipeline_calls : list [tuple [str , dict [str , Any ]]] = []
20
+ init_calls : list [dict [str , Any ]] = []
17
21
18
22
# Fake transformers.pipeline
19
- def fake_pipeline (task : str , ** kwargs ) :
23
+ def fake_pipeline (task : str , ** kwargs : Any ) -> SimpleNamespace :
20
24
pipeline_calls .append ((task , dict (kwargs )))
21
25
# A simple stand-in object for the HF pipeline
22
26
return SimpleNamespace (_kind = "dummy_hf_pipeline" )
@@ -27,7 +31,7 @@ def fake_pipeline(task: str, **kwargs):
27
31
28
32
# Fake langchain_huggingface.ChatHuggingFace that REQUIRES `llm`
29
33
class FakeChatHuggingFace :
30
- def __init__ (self , * , llm , ** kwargs ) :
34
+ def __init__ (self , * , llm : Any , ** kwargs : Any ) -> None :
31
35
init_calls .append ({"llm" : llm , "kwargs" : dict (kwargs )})
32
36
# minimal instance; tests only assert on ctor args
33
37
self ._llm = llm
@@ -49,7 +53,11 @@ def __init__(self, *, llm, **kwargs):
49
53
hf_pkg .ChatHuggingFace = FakeChatHuggingFace
50
54
51
55
monkeypatch .setitem (sys .modules , "langchain_huggingface" , hf_pkg )
52
- monkeypatch .setitem (sys .modules , "langchain_huggingface.chat_models" , hf_chat_models_pkg )
56
+ monkeypatch .setitem (
57
+ sys .modules ,
58
+ "langchain_huggingface.chat_models" ,
59
+ hf_chat_models_pkg ,
60
+ )
53
61
monkeypatch .setitem (
54
62
sys .modules ,
55
63
"langchain_huggingface.chat_models.huggingface" ,
@@ -59,7 +67,7 @@ def __init__(self, *, llm, **kwargs):
59
67
# Ensure _check_pkg sees both packages as installed
60
68
orig_find_spec = import_util .find_spec
61
69
62
- def fake_find_spec (name : str ):
70
+ def fake_find_spec (name : str ) -> Optional [ object ] :
63
71
if name in {
64
72
"transformers" ,
65
73
"langchain_huggingface" ,
@@ -74,19 +82,22 @@ def fake_find_spec(name: str):
74
82
return SimpleNamespace (pipeline_calls = pipeline_calls , init_calls = init_calls )
75
83
76
84
77
- def _last_pipeline_kwargs (hf_fakes : SimpleNamespace ) -> dict :
85
+ def _last_pipeline_kwargs (hf_fakes : SimpleNamespace ) -> dict [ str , Any ] :
78
86
assert hf_fakes .pipeline_calls , "transformers.pipeline was not called"
79
87
_ , kwargs = hf_fakes .pipeline_calls [- 1 ]
80
88
return kwargs
81
89
82
90
83
- def _last_chat_kwargs (hf_fakes : SimpleNamespace ) -> dict :
91
+ def _last_chat_kwargs (hf_fakes : SimpleNamespace ) -> dict [ str , Any ] :
84
92
assert hf_fakes .init_calls , "ChatHuggingFace was not constructed"
85
93
return hf_fakes .init_calls [- 1 ]["kwargs" ]
86
94
87
95
88
96
@pytest .mark .xfail (
89
- reason = "Pending fix for huggingface init (#28226 / #33167) — currently passes model_id to ChatHuggingFace" ,
97
+ reason = (
98
+ "Pending fix for huggingface init (#28226 / #33167) — currently passes "
99
+ "model_id to ChatHuggingFace"
100
+ ),
90
101
raises = TypeError ,
91
102
)
92
103
def test_hf_basic_wraps_pipeline (hf_fakes : SimpleNamespace ) -> None :
0 commit comments