
Commit eeb855f

🐛 Bugfix: batch create model failed due to missing expected chunk size
1 parent b636f42 commit eeb855f

File tree

2 files changed: +127 -7 lines changed


backend/services/model_provider_service.py

Lines changed: 4 additions & 0 deletions
@@ -89,6 +89,10 @@ async def prepare_model_dict(provider: str, model: dict, model_url: str, model_a
     model_repo, model_name = split_repo_name(model["id"])
     model_display_name = add_repo_to_name(model_repo, model_name)
 
+    # Initialize chunk size variables for all model types; only embeddings use them
+    expected_chunk_size = None
+    maximum_chunk_size = None
+
     # For embedding models, apply default values when chunk sizes are null
     if model["model_type"] in ["embedding", "multi_embedding"]:
         expected_chunk_size = model.get("expected_chunk_size", DEFAULT_EXPECTED_CHUNK_SIZE)
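
The four added lines are small but load-bearing: previously expected_chunk_size and maximum_chunk_size were only assigned inside the embedding branch, so passing them on to ModelRequest for an "llm" or "vlm" model raised UnboundLocalError and made batch model creation fail. Below is a minimal, self-contained sketch of the fixed pattern; the default constants and the returned kwargs dict are illustrative stand-ins, not the exact production code.

DEFAULT_EXPECTED_CHUNK_SIZE = 1024  # illustrative values, not the real constants
DEFAULT_MAXIMUM_CHUNK_SIZE = 2048


def chunk_size_kwargs(model: dict) -> dict:
    # The fix: bind both names for every model type up front.
    expected_chunk_size = None
    maximum_chunk_size = None

    # Only embedding-type models overwrite the None defaults.
    if model["model_type"] in ["embedding", "multi_embedding"]:
        expected_chunk_size = model.get("expected_chunk_size", DEFAULT_EXPECTED_CHUNK_SIZE)
        maximum_chunk_size = model.get("maximum_chunk_size", DEFAULT_MAXIMUM_CHUNK_SIZE)

    # Without the two initializations above, this reference raised
    # UnboundLocalError whenever the embedding branch was skipped.
    return {"expected_chunk_size": expected_chunk_size,
            "maximum_chunk_size": maximum_chunk_size}


assert chunk_size_kwargs({"model_type": "llm"}) == {
    "expected_chunk_size": None,
    "maximum_chunk_size": None,
}
assert chunk_size_kwargs({"model_type": "embedding"})["expected_chunk_size"] == 1024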

test/backend/services/test_model_provider_service.py

Lines changed: 123 additions & 7 deletions
@@ -172,19 +172,89 @@ async def test_get_models_exception():
 
 @pytest.mark.asyncio
 async def test_prepare_model_dict_llm():
-    """LLM models should not trigger embedding_dimension_check and keep base_url untouched."""
-    with mock.patch("backend.services.model_provider_service.split_repo_name", return_value=("openai", "gpt-4")), \
-            mock.patch("backend.services.model_provider_service.add_repo_to_name", return_value="openai/gpt-4"):
+    """LLM models should not call emb dim check; chunk sizes are None; base_url untouched."""
+    with mock.patch("backend.services.model_provider_service.split_repo_name", return_value=("openai", "gpt-4")) as mock_split_repo, \
+            mock.patch("backend.services.model_provider_service.add_repo_to_name", return_value="openai/gpt-4") as mock_add_repo_to_name, \
+            mock.patch("backend.services.model_provider_service.ModelRequest") as mock_model_request, \
+            mock.patch("backend.services.model_provider_service.embedding_dimension_check", new_callable=mock.AsyncMock) as mock_emb_dim_check:
+
+        mock_model_req_instance = mock.MagicMock()
+        dump_dict = {
+            "model_factory": "openai",
+            "model_name": "gpt-4",
+            "model_type": "llm",
+            "api_key": "test-key",
+            "max_tokens": sys.modules["consts.const"].DEFAULT_LLM_MAX_TOKENS,
+            "display_name": "openai/gpt-4",
+        }
+        mock_model_req_instance.model_dump.return_value = dump_dict
+        mock_model_request.return_value = mock_model_req_instance
 
-        # Current implementation passes chunk-size kwargs unconditionally,
-        # which raises UnboundLocalError for non-embedding types. Assert that.
         provider = "openai"
         model = {"id": "openai/gpt-4", "model_type": "llm", "max_tokens": sys.modules["consts.const"].DEFAULT_LLM_MAX_TOKENS}
         base_url = "https://api.openai.com/v1"
         api_key = "test-key"
 
-        with pytest.raises(UnboundLocalError):
-            await prepare_model_dict(provider, model, base_url, api_key)
+        result = await prepare_model_dict(provider, model, base_url, api_key)
+
+        mock_split_repo.assert_called_once_with("openai/gpt-4")
+        mock_add_repo_to_name.assert_called_once_with("openai", "gpt-4")
+
+        # Ensure chunk sizes are None for non-embedding types and emb check not called
+        _, kwargs = mock_model_request.call_args
+        assert kwargs["expected_chunk_size"] is None
+        assert kwargs["maximum_chunk_size"] is None
+        mock_emb_dim_check.assert_not_called()
+
+        expected = dump_dict | {
+            "model_repo": "openai",
+            "base_url": "https://api.openai.com/v1",
+            "connect_status": "not_detected",
+        }
+        assert result == expected
+
+
+@pytest.mark.asyncio
+async def test_prepare_model_dict_vlm():
+    """VLM models should behave like LLM: no emb dim check; chunk sizes None; base_url untouched."""
+    with mock.patch("backend.services.model_provider_service.split_repo_name", return_value=("openai", "gpt-4-vision")) as mock_split_repo, \
+            mock.patch("backend.services.model_provider_service.add_repo_to_name", return_value="openai/gpt-4-vision") as mock_add_repo_to_name, \
+            mock.patch("backend.services.model_provider_service.ModelRequest") as mock_model_request, \
+            mock.patch("backend.services.model_provider_service.embedding_dimension_check", new_callable=mock.AsyncMock) as mock_emb_dim_check:
+
+        mock_model_req_instance = mock.MagicMock()
+        dump_dict = {
+            "model_factory": "openai",
+            "model_name": "gpt-4-vision",
+            "model_type": "vlm",
+            "api_key": "test-key",
+            "max_tokens": sys.modules["consts.const"].DEFAULT_LLM_MAX_TOKENS,
+            "display_name": "openai/gpt-4-vision",
+        }
+        mock_model_req_instance.model_dump.return_value = dump_dict
+        mock_model_request.return_value = mock_model_req_instance
+
+        provider = "openai"
+        model = {"id": "openai/gpt-4-vision", "model_type": "vlm", "max_tokens": sys.modules["consts.const"].DEFAULT_LLM_MAX_TOKENS}
+        base_url = "https://api.openai.com/v1"
+        api_key = "test-key"
+
+        result = await prepare_model_dict(provider, model, base_url, api_key)
+
+        mock_split_repo.assert_called_once_with("openai/gpt-4-vision")
+        mock_add_repo_to_name.assert_called_once_with("openai", "gpt-4-vision")
+
+        _, kwargs = mock_model_request.call_args
+        assert kwargs["expected_chunk_size"] is None
+        assert kwargs["maximum_chunk_size"] is None
+        mock_emb_dim_check.assert_not_called()
+
+        expected = dump_dict | {
+            "model_repo": "openai",
+            "base_url": "https://api.openai.com/v1",
+            "connect_status": "not_detected",
+        }
+        assert result == expected
 
 
 @pytest.mark.asyncio
@@ -293,6 +363,52 @@ async def test_prepare_model_dict_embedding_with_explicit_chunk_sizes():
     assert result == expected
 
 
+@pytest.mark.asyncio
+async def test_prepare_model_dict_multi_embedding_defaults():
+    """multi_embedding should mirror embedding: default chunk sizes and emb base_url."""
+    with mock.patch("backend.services.model_provider_service.split_repo_name", return_value=("openai", "text-embedding-3-large")) as mock_split_repo, \
+            mock.patch("backend.services.model_provider_service.add_repo_to_name", return_value="openai/text-embedding-3-large") as mock_add_repo_to_name, \
+            mock.patch("backend.services.model_provider_service.ModelRequest") as mock_model_request, \
+            mock.patch("backend.services.model_provider_service.embedding_dimension_check", new_callable=mock.AsyncMock, return_value=1536) as mock_emb_dim_check, \
+            mock.patch("backend.services.model_provider_service.ModelConnectStatusEnum") as mock_enum:
+
+        mock_model_req_instance = mock.MagicMock()
+        dump_dict = {
+            "model_factory": "openai",
+            "model_name": "text-embedding-3-large",
+            "model_type": "multi_embedding",
+            "api_key": "test-key",
+            "max_tokens": 1024,
+            "display_name": "openai/text-embedding-3-large",
+        }
+        mock_model_req_instance.model_dump.return_value = dump_dict
+        mock_model_request.return_value = mock_model_req_instance
+        mock_enum.NOT_DETECTED.value = "not_detected"
+
+        provider = "openai"
+        model = {"id": "openai/text-embedding-3-large", "model_type": "multi_embedding", "max_tokens": 1024}
+        base_url = "https://api.openai.com/v1/"
+        api_key = "test-key"
+
+        result = await prepare_model_dict(provider, model, base_url, api_key)
+
+        mock_split_repo.assert_called_once_with("openai/text-embedding-3-large")
+        mock_add_repo_to_name.assert_called_once_with("openai", "text-embedding-3-large")
+
+        _, kwargs = mock_model_request.call_args
+        assert kwargs["expected_chunk_size"] == sys.modules["consts.const"].DEFAULT_EXPECTED_CHUNK_SIZE
+        assert kwargs["maximum_chunk_size"] == sys.modules["consts.const"].DEFAULT_MAXIMUM_CHUNK_SIZE
+        mock_emb_dim_check.assert_called_once_with(dump_dict)
+
+        expected = dump_dict | {
+            "model_repo": "openai",
+            "base_url": "https://api.openai.com/v1/embeddings",
+            "connect_status": "not_detected",
+            "max_tokens": 1536,
+        }
+        assert result == expected
+
+
 # ---------------------------------------------------------------------------
 # Test-cases for merge_existing_model_tokens
 # ---------------------------------------------------------------------------
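
To exercise just the new and updated cases locally, a pytest invocation along these lines should work, assuming pytest plus an asyncio plugin such as pytest-asyncio (needed for the @pytest.mark.asyncio markers) are installed:

pytest test/backend/services/test_model_provider_service.py -k "prepare_model_dict" -v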
