Skip to content

Commit a80b741

Browse files
committed
Fixed middlewares and qdrant tests
1 parent f340145 commit a80b741

File tree

2 files changed

+49
-21
lines changed

2 files changed

+49
-21
lines changed

tests/test_vector_mcp.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,10 @@ def test_insert_docs_success(chromadb_instance, sample_docs):
7070

7171
chromadb_instance.insert_documents(sample_docs)
7272

73-
mock_index.insert_documents.assert_called_once()
74-
call_args = mock_index.insert_documents.call_args[0][0]
75-
assert len(call_args) == 2
73+
# The implementation iterates and calls insert for each doc
74+
assert mock_index.insert.call_count == len(sample_docs)
75+
# Check first call argument
76+
call_args = mock_index.insert.call_args_list[0][0]
7677
assert call_args[0].text == "Test document 1"
7778

7879

@@ -84,4 +85,4 @@ def test_update_docs_success(chromadb_instance, sample_docs):
8485

8586
chromadb_instance.update_documents(sample_docs)
8687

87-
mock_index.insert_documents.assert_called_once()
88+
assert mock_index.insert.call_count == len(sample_docs)

vector_mcp/utils.py

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -442,44 +442,71 @@ def filter_tools_by_tag(tools: List[Any], tag: str) -> List[Any]:
442442
return [t for t in tools if tool_in_tag(t, tag)]
443443

444444

445-
def get_embedding_model(
    provider: Optional[str] = None,
    model: Optional[str] = None,
    base_url: Optional[str] = None,
    api_key: Optional[str] = None,
    ssl_verify: Optional[bool] = None,
    timeout: float = 300.0,
) -> BaseEmbedding:
    """
    Get the embedding model based on parameters or environment variables.

    Args:
        provider: The embedding provider (openai, huggingface, ollama, local).
            Defaults to the EMBEDDING_PROVIDER env var (or "openai").
        model: The specific model ID to use. Defaults to the EMBEDDING_MODEL
            env var (or "text-embedding-nomic-embed-text-v2-moe").
        base_url: Optional base URL for the API. Defaults to the LLM_BASE_URL
            env var (or "http://localhost:1234/v1").
        api_key: Optional API key. Defaults to the LLM_API_KEY env var.
        ssl_verify: Whether to verify SSL certificates. Defaults to the
            VECTOR_SSL_VERIFY env var (or True).
        timeout: Request timeout in seconds.

    Returns:
        BaseEmbedding: The LlamaIndex embedding model.

    Raises:
        ImportError: If the ollama provider is requested but the
            llama-index-embeddings-ollama package is not installed.
        ValueError: If ``provider`` is not one of the supported values.
    """
    # Resolve defaults at call time, NOT in the signature: defaults written as
    # os.environ.get(...) are evaluated once at import time, freezing the
    # environment and ignoring any variables set afterwards (e.g. in tests).
    if provider is None:
        provider = os.environ.get("EMBEDDING_PROVIDER", "openai")
    # Lowercase unconditionally so explicitly passed values (e.g. "OpenAI")
    # match the branch checks, not just the env-derived default.
    provider = provider.lower()
    if model is None:
        model = os.environ.get(
            "EMBEDDING_MODEL", "text-embedding-nomic-embed-text-v2-moe"
        )
    if base_url is None:
        base_url = os.environ.get("LLM_BASE_URL", "http://localhost:1234/v1")
    if api_key is None:
        api_key = os.environ.get("LLM_API_KEY", None)
    if ssl_verify is None:
        ssl_verify = to_boolean(string=os.environ.get("VECTOR_SSL_VERIFY", "true"))

    # Only build a custom client when SSL verification is disabled; otherwise
    # let the embedding library use its default transport.
    http_client = None
    if not ssl_verify:
        http_client = httpx.AsyncClient(verify=False, timeout=timeout)

    if provider == "openai":
        return OpenAIEmbedding(
            model_name=model,
            api_key=api_key,
            api_base=base_url,
            timeout=timeout,
            http_client=http_client,
        )

    elif provider == "huggingface":
        # Imported lazily so the dependency is only required for this provider.
        from llama_index.embeddings.huggingface import HuggingFaceEmbedding

        cache_folder = os.environ.get("HF_HOME")

        return HuggingFaceEmbedding(
            model_name=model,
            cache_folder=cache_folder,
            request_timeout=timeout,
        )

    elif provider == "ollama":
        if OllamaEmbedding is None:
            raise ImportError("llama-index-embeddings-ollama is not installed.")

        return OllamaEmbedding(
            model_name=model,
            base_url=base_url,
            timeout=timeout,
        )

    elif provider == "local":
        # Imported lazily so the dependency is only required for this provider.
        from llama_index.embeddings.huggingface import HuggingFaceEmbedding

        return HuggingFaceEmbedding(model_name=model)

    else:
        raise ValueError(f"Unsupported embedding provider: {provider}")

0 commit comments

Comments
 (0)