Skip to content

Commit 98fb7d9

Browse files
committed
✨config ModelEngine Service
1 parent 72e5051 commit 98fb7d9

File tree

6 files changed

+69
-27
lines changed

6 files changed

+69
-27
lines changed

backend/services/model_management_service.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -48,8 +48,6 @@ async def create_model_for_tenant(user_id: str, tenant_id: str, model_data: Dict
4848
model_data['ssl_verify'] = True
4949
if "open/router" in model_base_url:
5050
model_data['ssl_verify'] = False
51-
52-
5351
# Split model_name into repo and name
5452
model_repo, model_name = split_repo_name(
5553
model_data["model_name"]) if model_data.get("model_name") else ("", "")

backend/services/model_provider_service.py

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -85,11 +85,9 @@ async def get_models(self, provider_config: Dict) -> List[Dict]:
8585
List of models with canonical fields
8686
"""
8787
try:
88-
# Allow overriding host and api key via provider_config (from frontend).
89-
# Fall back to environment-configured values.
9088
model_type: str = provider_config.get("model_type", "")
91-
host = provider_config.get("base_url") or MODEL_ENGINE_HOST
92-
api_key = provider_config.get("api_key") or MODEL_ENGINE_APIKEY
89+
host = provider_config.get("base_url")
90+
api_key = provider_config.get("api_key")
9391

9492
if not host or not api_key:
9593
logger.warning("ModelEngine host or api key not configured")
@@ -135,7 +133,6 @@ async def get_models(self, provider_config: Dict) -> List[Dict]:
135133
"model_type": internal_type,
136134
"model_tag": me_type,
137135
"max_tokens": DEFAULT_LLM_MAX_TOKENS if internal_type in ("llm", "vlm") else 0,
138-
# ModelEngine models will get base_url and api_key from provider_config (or env)
139136
"base_url": host,
140137
"api_key": api_key,
141138
})

test/backend/services/test_model_health_service.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,3 @@
1-
from consts.exceptions import TimeoutException
2-
import asyncio
31
import os
42
import sys
53
from unittest import mock
@@ -792,5 +790,3 @@ async def test_embedding_dimension_check_wrapper_value_error():
792790
mock_logger.error.assert_called_once_with(
793791
"Error checking embedding dimension: Unsupported model type"
794792
)
795-
796-

test/backend/services/test_model_management_service.py

Lines changed: 31 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -307,6 +307,7 @@ async def test_create_model_for_tenant_success_llm():
307307
"base_url": "http://localhost:8000",
308308
"model_type": "llm",
309309
}
310+
model_data['ssl_verify'] = False
310311

311312
await svc.create_model_for_tenant(user_id, tenant_id, model_data)
312313

@@ -316,6 +317,32 @@ async def test_create_model_for_tenant_success_llm():
316317
assert mock_create.call_count == 1
317318

318319

320+
@pytest.mark.asyncio
321+
async def test_create_model_for_tenant_open_router_disables_ssl():
322+
"""When base_url contains 'open/router' ssl_verify should be set to False."""
323+
svc = import_svc()
324+
325+
with mock.patch.object(svc, "get_model_by_display_name", return_value=None), \
326+
mock.patch.object(svc, "create_model_record") as mock_create, \
327+
mock.patch.object(svc, "split_repo_name", return_value=("modelengine", "m")):
328+
329+
user_id = "u1"
330+
tenant_id = "t1"
331+
model_data = {
332+
"model_name": "modelengine/m",
333+
"display_name": None,
334+
"base_url": "https://api.example.com/open/router/v1",
335+
"model_type": "llm",
336+
}
337+
338+
await svc.create_model_for_tenant(user_id, tenant_id, model_data)
339+
340+
# Ensure a single record created and ssl_verify was disabled
341+
assert mock_create.call_count == 1
342+
create_args = mock_create.call_args[0][0]
343+
assert create_args["ssl_verify"] is False
344+
345+
319346
@pytest.mark.asyncio
320347
async def test_create_model_for_tenant_conflict_raises():
321348
svc = import_svc()
@@ -459,7 +486,7 @@ async def test_create_model_for_tenant_multi_embedding_sets_default_chunk_batch(
459486
mock_dim.assert_awaited_once()
460487
# Should create two records: multi_embedding and its embedding variant
461488
assert mock_create.call_count == 2
462-
489+
463490
# Verify chunk_batch was set to 10 for both records
464491
create_calls = mock_create.call_args_list
465492
# First call is for multi_embedding
@@ -519,7 +546,7 @@ async def test_batch_create_models_for_tenant_other_provider():
519546
if not hasattr(svc.ProviderEnum, 'MODELENGINE'):
520547
modelengine_item = _EnumItem("modelengine")
521548
svc.ProviderEnum.MODELENGINE = modelengine_item
522-
549+
523550
with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=[]), \
524551
mock.patch.object(svc, "delete_model_record"), \
525552
mock.patch.object(svc, "split_repo_name", return_value=("openai", "gpt-4")), \
@@ -529,7 +556,7 @@ async def test_batch_create_models_for_tenant_other_provider():
529556
mock.patch.object(svc, "create_model_record", return_value=True):
530557

531558
await svc.batch_create_models_for_tenant("u1", "t1", batch_payload)
532-
559+
533560
# Verify prepare_model_dict was called with empty model_url for non-Silicon/ModelEngine provider
534561
call_args = svc.prepare_model_dict.call_args
535562
assert call_args[1]["model_url"] == "" # Should be empty for other providers
@@ -618,7 +645,7 @@ def get_by_display(display_name, tenant_id):
618645
update_calls = [call for call in mock_update.call_args_list if call[0][0] == "id1"]
619646
if update_calls:
620647
assert update_calls[0][0][1] == {"max_tokens": 8192}
621-
648+
622649
# Should NOT update model2 (max_tokens same) or model3 (new max_tokens is None)
623650
# Verify model2 and model3 were not updated
624651
model2_calls = [call for call in mock_update.call_args_list if call[0][0] == "id2"]

test/backend/services/test_model_provider_service.py

Lines changed: 6 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -776,11 +776,9 @@ async def test_modelengine_get_models_llm_success():
776776
"""ModelEngine provider should return LLM models with correct type mapping."""
777777
from backend.services.model_provider_service import ModelEngineProvider
778778

779-
provider_config = {"model_type": "llm"}
779+
provider_config = {"model_type": "llm", "base_url": "https://model-engine.com", "api_key": "test-key"}
780780

781-
with mock.patch("backend.services.model_provider_service.MODEL_ENGINE_HOST", "https://model-engine.com"), \
782-
mock.patch("backend.services.model_provider_service.MODEL_ENGINE_APIKEY", "test-key"), \
783-
mock.patch("backend.services.model_provider_service.aiohttp.ClientSession") as mock_session_class, \
781+
with mock.patch("backend.services.model_provider_service.aiohttp.ClientSession") as mock_session_class, \
784782
mock.patch("backend.services.model_provider_service.aiohttp.ClientTimeout"), \
785783
mock.patch("backend.services.model_provider_service.aiohttp.TCPConnector"):
786784

@@ -825,11 +823,9 @@ async def test_modelengine_get_models_embedding_success():
825823
"""ModelEngine provider should return embedding models with correct type mapping."""
826824
from backend.services.model_provider_service import ModelEngineProvider
827825

828-
provider_config = {"model_type": "embedding"}
826+
provider_config = {"model_type": "embedding", "base_url": "https://model-engine.com", "api_key": "test-key"}
829827

830-
with mock.patch("backend.services.model_provider_service.MODEL_ENGINE_HOST", "https://model-engine.com"), \
831-
mock.patch("backend.services.model_provider_service.MODEL_ENGINE_APIKEY", "test-key"), \
832-
mock.patch("backend.services.model_provider_service.aiohttp.ClientSession") as mock_session_class, \
828+
with mock.patch("backend.services.model_provider_service.aiohttp.ClientSession") as mock_session_class, \
833829
mock.patch("backend.services.model_provider_service.aiohttp.ClientTimeout"), \
834830
mock.patch("backend.services.model_provider_service.aiohttp.TCPConnector"):
835831

@@ -871,11 +867,9 @@ async def test_modelengine_get_models_all_types():
871867
"""ModelEngine provider should return all models when no type filter specified."""
872868
from backend.services.model_provider_service import ModelEngineProvider
873869

874-
provider_config = {} # No model_type filter
870+
provider_config = {"base_url": "https://model-engine.com", "api_key": "test-key"} # No model_type filter
875871

876-
with mock.patch("backend.services.model_provider_service.MODEL_ENGINE_HOST", "https://model-engine.com"), \
877-
mock.patch("backend.services.model_provider_service.MODEL_ENGINE_APIKEY", "test-key"), \
878-
mock.patch("backend.services.model_provider_service.aiohttp.ClientSession") as mock_session_class, \
872+
with mock.patch("backend.services.model_provider_service.aiohttp.ClientSession") as mock_session_class, \
879873
mock.patch("backend.services.model_provider_service.aiohttp.ClientTimeout"), \
880874
mock.patch("backend.services.model_provider_service.aiohttp.TCPConnector"):
881875

test/backend/test_model_consts.py

Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,30 @@
1+
import pytest
2+
from pydantic import ValidationError
3+
4+
from backend.consts import model as model_consts
5+
6+
7+
def test_model_connect_status_enum_defaults_and_get_value():
8+
assert model_consts.ModelConnectStatusEnum.get_default() == "not_detected"
9+
assert model_consts.ModelConnectStatusEnum.get_value("") == "not_detected"
10+
assert model_consts.ModelConnectStatusEnum.get_value(None) == "not_detected"
11+
assert model_consts.ModelConnectStatusEnum.get_value("available") == "available"
12+
13+
14+
def test_model_request_and_validation():
15+
# Basic construction
16+
mr = model_consts.ModelRequest(model_name="mymodel", model_type="llm")
17+
assert mr.model_name == "mymodel"
18+
assert mr.model_type == "llm"
19+
20+
# Chunk create request requires non-empty content
21+
with pytest.raises(ValidationError):
22+
model_consts.ChunkCreateRequest(content="")
23+
24+
# Valid chunk create
25+
req = model_consts.ChunkCreateRequest(content="a", title="t", filename="f")
26+
assert req.content == "a"
27+
assert req.title == "t"
28+
assert req.filename == "f"
29+
30+

0 commit comments

Comments
 (0)