diff --git a/wren-ai-service/src/__main__.py b/wren-ai-service/src/__main__.py index de141c3fde..1323d8c5b3 100644 --- a/wren-ai-service/src/__main__.py +++ b/wren-ai-service/src/__main__.py @@ -1,7 +1,7 @@ from contextlib import asynccontextmanager import uvicorn -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import ORJSONResponse, RedirectResponse @@ -9,11 +9,16 @@ from src.config import settings from src.globals import ( + create_pipe_components, create_service_container, create_service_metadata, ) from src.providers import generate_components +from src.providers.document_store.qdrant import QdrantProvider +from src.providers.embedder.litellm import LitellmEmbedderProvider +from src.providers.llm.litellm import LitellmLLMProvider from src.utils import ( + Configs, init_langfuse, setup_custom_logger, ) @@ -28,8 +33,13 @@ @asynccontextmanager async def lifespan(app: FastAPI): # startup events - pipe_components = generate_components(settings.components) + pipe_components, instantiated_providers = generate_components(settings.components) + app.state.pipe_components = pipe_components + app.state.instantiated_providers = instantiated_providers app.state.service_container = create_service_container(pipe_components, settings) + app.state.pipe_service_components = create_pipe_components( + app.state.service_container + ) app.state.service_metadata = create_service_metadata(pipe_components) init_langfuse(settings) @@ -86,6 +96,176 @@ def health(): return {"status": "ok"} +@app.get("/configs") +def get_configs(): + _configs = { + "env_vars": {}, + "providers": { + "llm": [], + "embedder": [], + }, + "pipelines": {}, + } + + _llm_model_alias_mapping = {} + _embedder_model_alias_mapping = {} + + _llm_configs = [] + for _, model_config in app.state.instantiated_providers["llm"].items(): + _llm_config = { + "model": 
model_config._model, + "alias": model_config._alias, + "context_window_size": model_config._context_window_size, + "timeout": model_config._timeout, + "kwargs": model_config._model_kwargs, + } + if model_config._api_base: + _llm_config["api_base"] = model_config._api_base + if model_config._api_version: + _llm_config["api_version"] = model_config._api_version + _llm_configs.append(_llm_config) + _llm_model_alias_mapping[model_config._model] = model_config._alias + _configs["providers"]["llm"] = _llm_configs + + _embedder_configs = [] + # we only support one embedding model now + for _, model_config in app.state.instantiated_providers["embedder"].items(): + _embedder_config = { + "model": model_config._model, + "alias": model_config._alias, + "dimension": app.state.instantiated_providers["document_store"][ + "qdrant" + ]._embedding_model_dim, + "timeout": model_config._timeout, + "kwargs": model_config._model_kwargs, + } + if model_config._api_base: + _embedder_config["api_base"] = model_config._api_base + if model_config._api_version: + _embedder_config["api_version"] = model_config._api_version + _embedder_configs.append(_embedder_config) + _embedder_model_alias_mapping[model_config._model] = model_config._alias + break + _configs["providers"]["embedder"] = _embedder_configs + + for pipe_name, pipe_component in app.state.pipe_service_components.items(): + llm_model = pipe_component.get("llm", None) + embedding_model = pipe_component.get("embedder", None) + description = pipe_component.get("description", "") + if llm_model or embedding_model: + _configs["pipelines"][pipe_name] = { + "has_db_data_in_llm_prompt": pipe_component.get( + "has_db_data_in_llm_prompt", False + ), + "description": description, + } + if llm_model: + if llm_model_alias := _llm_model_alias_mapping.get(llm_model): + _configs["pipelines"][pipe_name]["llm"] = llm_model_alias + else: + _configs["pipelines"][pipe_name]["llm"] = llm_model + if embedding_model: + if embedding_model_alias := 
_embedder_model_alias_mapping.get( + embedding_model + ): + _configs["pipelines"][pipe_name]["embedder"] = embedding_model_alias + else: + _configs["pipelines"][pipe_name]["embedder"] = embedding_model + + return _configs + + +@app.post("/configs") +def update_configs(configs_request: Configs): + try: + # override current instantiated_providers + app.state.instantiated_providers["embedder"] = { + f"litellm_embedder.{embedder_provider.alias}": LitellmEmbedderProvider( + **embedder_provider.__dict__ + ) + for embedder_provider in configs_request.providers.embedder + } + app.state.instantiated_providers["llm"] = { + f"litellm_llm.{llm_provider.alias}": LitellmLLMProvider( + **llm_provider.__dict__ + ) + for llm_provider in configs_request.providers.llm + } + app.state.instantiated_providers["document_store"]["qdrant"] = QdrantProvider( + location=app.state.instantiated_providers["document_store"][ + "qdrant" + ]._location, + api_key=app.state.instantiated_providers["document_store"][ + "qdrant" + ]._api_key, + timeout=app.state.instantiated_providers["document_store"][ + "qdrant" + ]._timeout, + embedding_model_dim=configs_request.providers.embedder[0].dimension, + recreate_index=True, + ) + _embedder_providers = app.state.instantiated_providers["embedder"] + _llm_providers = app.state.instantiated_providers["llm"] + _document_store_provider = app.state.instantiated_providers["document_store"][ + "qdrant" + ] + + # override current pipe_components + for ( + pipe_name, + pipe_service_components, + ) in app.state.pipe_service_components.items(): + if pipe_name in configs_request.pipelines: + pipe_config = configs_request.pipelines[pipe_name] + pipe_service_components.update(pipe_config) + + # updating pipelines + for ( + pipeline_name, + pipe_service_components, + ) in app.state.pipe_service_components.items(): + for service in pipe_service_components.get("services", []): + if pipe_config := configs_request.pipelines.get(pipeline_name): + 
service._pipelines[pipeline_name].update_components( + llm_provider=( + _llm_providers[f"litellm_llm.{pipe_config.llm}"] + if pipe_config.llm + else None + ), + embedder_provider=( + _embedder_providers[ + f"litellm_embedder.{pipe_config.embedder}" + ] + if pipe_config.embedder + else None + ), + document_store_provider=( + _document_store_provider + if service._pipelines[ + pipeline_name + ]._document_store_provider + else None + ), + ) + else: + if service._pipelines[pipeline_name]._document_store_provider: + service._pipelines[pipeline_name].update_components( + llm_provider=service._pipelines[ + pipeline_name + ]._llm_provider, + embedder_provider=service._pipelines[ + pipeline_name + ]._embedder_provider, + document_store_provider=_document_store_provider, + ) + + # TODO: updating service_metadata + for pipeline_name, _ in app.state.pipe_components.items(): + pass + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error updating configs: {e}") + + if __name__ == "__main__": uvicorn.run( "src.__main__:app", diff --git a/wren-ai-service/src/core/pipeline.py b/wren-ai-service/src/core/pipeline.py index 02a477845c..9946ca5d53 100644 --- a/wren-ai-service/src/core/pipeline.py +++ b/wren-ai-service/src/core/pipeline.py @@ -1,7 +1,7 @@ from abc import ABCMeta, abstractmethod from collections.abc import Mapping from dataclasses import dataclass -from typing import Any, Dict +from typing import Any, Dict, Optional from hamilton.async_driver import AsyncDriver from hamilton.driver import Driver @@ -14,14 +14,42 @@ class BasicPipeline(metaclass=ABCMeta): def __init__(self, pipe: Pipeline | AsyncDriver | Driver): self._pipe = pipe + self._description = "" + self._llm_provider = None + self._embedder_provider = None + self._document_store_provider = None + self._components = {} @abstractmethod def run(self, *args, **kwargs) -> Dict[str, Any]: ... + def _update_components(self) -> dict: + ... 
+ + def update_components( + self, + llm_provider: Optional[LLMProvider] = None, + embedder_provider: Optional[EmbedderProvider] = None, + document_store_provider: Optional[DocumentStoreProvider] = None, + update_components: bool = True, + ): + if llm_provider: + self._llm_provider = llm_provider + if embedder_provider: + self._embedder_provider = embedder_provider + if document_store_provider: + self._document_store_provider = document_store_provider + if update_components: + self._components = self._update_components() + + def __str__(self): + return f"BasicPipeline(llm_provider={self._llm_provider}, embedder_provider={self._embedder_provider}, document_store_provider={self._document_store_provider})" + @dataclass class PipelineComponent(Mapping): + description: str = None llm_provider: LLMProvider = None embedder_provider: EmbedderProvider = None document_store_provider: DocumentStoreProvider = None @@ -35,3 +63,6 @@ def __iter__(self): def __len__(self): return len(self.__dict__) + + def __str__(self): + return f"PipelineComponent(description={self.description}, llm_provider={self.llm_provider}, embedder_provider={self.embedder_provider}, document_store_provider={self.document_store_provider}, engine={self.engine})" diff --git a/wren-ai-service/src/core/provider.py b/wren-ai-service/src/core/provider.py index 4246dd4908..9b89d89a59 100644 --- a/wren-ai-service/src/core/provider.py +++ b/wren-ai-service/src/core/provider.py @@ -8,13 +8,20 @@ class LLMProvider(metaclass=ABCMeta): def get_generator(self, *args, **kwargs): ... 
- def get_model(self): + @property + def alias(self): + return self._alias + + @property + def model(self): return self._model - def get_model_kwargs(self): + @property + def model_kwargs(self): return self._model_kwargs - def get_context_window_size(self): + @property + def context_window_size(self): return self._context_window_size @@ -27,8 +34,17 @@ def get_text_embedder(self, *args, **kwargs): def get_document_embedder(self, *args, **kwargs): ... - def get_model(self): - return self._embedding_model + @property + def alias(self): + return self._alias + + @property + def model(self): + return self._model + + @property + def model_kwargs(self): + return self._model_kwargs class DocumentStoreProvider(metaclass=ABCMeta): diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py index d6c40e05b4..fab06456fa 100644 --- a/wren-ai-service/src/globals.py +++ b/wren-ai-service/src/globals.py @@ -7,7 +7,7 @@ from src.core.pipeline import PipelineComponent from src.core.provider import EmbedderProvider, LLMProvider from src.pipelines import generation, indexing, retrieval -from src.utils import fetch_wren_ai_docs +from src.utils import fetch_wren_ai_docs, has_db_data_in_llm_prompt from src.web.v1 import services logger = logging.getLogger("wren-ai-service") @@ -104,8 +104,8 @@ def create_service_container( "table_description": indexing.TableDescription( **pipe_components["table_description_indexing"], ), - "sql_pairs": _sql_pair_indexing_pipeline, - "instructions": _instructions_indexing_pipeline, + "sql_pairs_indexing": _sql_pair_indexing_pipeline, + "instructions_indexing": _instructions_indexing_pipeline, "project_meta": indexing.ProjectMeta( **pipe_components["project_meta_indexing"], ), @@ -231,7 +231,7 @@ def create_service_container( ), sql_pairs_service=services.SqlPairsService( pipelines={ - "sql_pairs": _sql_pair_indexing_pipeline, + "sql_pairs_indexing": _sql_pair_indexing_pipeline, }, **query_cache, ), @@ -262,6 +262,30 @@ def 
create_service_container( ) +def create_pipe_components(service_container: ServiceContainer): + _pipe_components = {} + for _, service in service_container.__dict__.items(): + for pipe_name, pipe in service._pipelines.items(): + if pipe_name not in _pipe_components: + _pipe_components[pipe_name] = {} + if hasattr(pipe, "_llm_provider") and pipe._llm_provider is not None: + _pipe_components[pipe_name]["llm"] = pipe._llm_provider.alias + if ( + hasattr(pipe, "_embedder_provider") + and pipe._embedder_provider is not None + ): + _pipe_components[pipe_name]["embedder"] = pipe._embedder_provider.alias + if "services" not in _pipe_components[pipe_name]: + _pipe_components[pipe_name]["services"] = set() + _pipe_components[pipe_name]["services"].add(service) + _pipe_components[pipe_name][ + "has_db_data_in_llm_prompt" + ] = has_db_data_in_llm_prompt(pipe_name) + _pipe_components[pipe_name]["description"] = pipe._description or "" + + return _pipe_components + + # Create a dependency that will be used to access the ServiceContainer def get_service_container(): from src.__main__ import app @@ -289,8 +313,8 @@ def _convert_pipe_metadata( ) -> dict: llm_metadata = ( { - "llm_model": llm_provider.get_model(), - "llm_model_kwargs": llm_provider.get_model_kwargs(), + "llm_model": llm_provider.model, + "llm_model_kwargs": llm_provider.model_kwargs, } if llm_provider else {} @@ -298,7 +322,7 @@ def _convert_pipe_metadata( embedding_metadata = ( { - "embedding_model": embedder_provider.get_model(), + "embedding_model": embedder_provider.model, } if embedder_provider else {} diff --git a/wren-ai-service/src/pipelines/generation/chart_adjustment.py b/wren-ai-service/src/pipelines/generation/chart_adjustment.py index 2e4e7ba8a3..e0abace0d6 100644 --- a/wren-ai-service/src/pipelines/generation/chart_adjustment.py +++ b/wren-ai-service/src/pipelines/generation/chart_adjustment.py @@ -150,31 +150,38 @@ class ChartAdjustment(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + 
description: str = "", **kwargs, ): - self._components = { + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + self._llm_provider = llm_provider + self._description = description + self._components = self._update_components() + + with open("src/pipelines/generation/utils/vega-lite-schema-v5.json", "r") as f: + _vega_schema = orjson.loads(f.read()) + + self._configs = { + "vega_schema": _vega_schema, + } + + def _update_components(self): + return { "prompt_builder": PromptBuilder( template=chart_adjustment_user_prompt_template ), - "generator": llm_provider.get_generator( + "generator": self._llm_provider.get_generator( system_prompt=chart_adjustment_system_prompt, generation_kwargs=CHART_ADJUSTMENT_MODEL_KWARGS, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.model, "chart_data_preprocessor": ChartDataPreprocessor(), "post_processor": ChartGenerationPostProcessor(), } - with open("src/pipelines/generation/utils/vega-lite-schema-v5.json", "r") as f: - _vega_schema = orjson.loads(f.read()) - - self._configs = { - "vega_schema": _vega_schema, - } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="Chart Adjustment") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/chart_generation.py b/wren-ai-service/src/pipelines/generation/chart_generation.py index 6daca5ec17..c091179cb1 100644 --- a/wren-ai-service/src/pipelines/generation/chart_generation.py +++ b/wren-ai-service/src/pipelines/generation/chart_generation.py @@ -123,32 +123,38 @@ class ChartGeneration(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **kwargs, ): - self._components = { + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + self._llm_provider = llm_provider + self._description = description + self._components = self._update_components() 
+ + with open("src/pipelines/generation/utils/vega-lite-schema-v5.json", "r") as f: + _vega_schema = orjson.loads(f.read()) + + self._configs = { + "vega_schema": _vega_schema, + } + + def _update_components(self): + return { "prompt_builder": PromptBuilder( template=chart_generation_user_prompt_template ), - "generator": llm_provider.get_generator( + "generator": self._llm_provider.get_generator( system_prompt=chart_generation_system_prompt, generation_kwargs=CHART_GENERATION_MODEL_KWARGS, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.model, "chart_data_preprocessor": ChartDataPreprocessor(), "post_processor": ChartGenerationPostProcessor(), } - with open("src/pipelines/generation/utils/vega-lite-schema-v5.json", "r") as f: - _vega_schema = orjson.loads(f.read()) - - self._configs = { - "vega_schema": _vega_schema, - } - - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="Chart Generation") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/data_assistance.py b/wren-ai-service/src/pipelines/generation/data_assistance.py index 51b91197f9..451c1e96c8 100644 --- a/wren-ai-service/src/pipelines/generation/data_assistance.py +++ b/wren-ai-service/src/pipelines/generation/data_assistance.py @@ -93,24 +93,30 @@ class DataAssistance(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **kwargs, ): + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + self._user_queues = {} - self._components = { - "generator": llm_provider.get_generator( + self._llm_provider = llm_provider + self._description = description + self._components = self._update_components() + + def _update_components(self): + return { + "generator": self._llm_provider.get_generator( system_prompt=data_assistance_system_prompt, streaming_callback=self._streaming_callback, ), - "generator_name": 
llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=data_assistance_user_prompt_template ), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - def _streaming_callback(self, chunk, query_id): if query_id not in self._user_queues: self._user_queues[ diff --git a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py index 5f889ff00c..7cd076bb8a 100644 --- a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py @@ -154,28 +154,51 @@ def __init__( llm_provider: LLMProvider, document_store_provider: DocumentStoreProvider, engine: Engine, + description: str = "", **kwargs, ): - self._retriever = document_store_provider.get_retriever( - document_store_provider.get_store("project_meta") + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + self._document_store_provider = document_store_provider + self._retriever = self._document_store_provider.get_retriever( + self._document_store_provider.get_store("project_meta") ) + self._llm_provider = llm_provider + self._engine = engine + self._description = description + self._components = self._update_components() - self._components = { - "generator": llm_provider.get_generator( + def update_components( + self, + llm_provider: LLMProvider, + document_store_provider: DocumentStoreProvider, + **_, + ): + super().update_components( + llm_provider=llm_provider, + document_store_provider=document_store_provider, + update_components=False, + ) + self._retriever = self._document_store_provider.get_retriever( + self._document_store_provider.get_store("project_meta") + ) + self._components = self._update_components() + + def _update_components(self): + return { + "generator": self._llm_provider.get_generator( 
system_prompt=sql_generation_system_prompt, generation_kwargs=SQL_GENERATION_MODEL_KWARGS, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=text_to_sql_with_followup_user_prompt_template ), - "post_processor": SQLGenPostProcessor(engine=engine), + "post_processor": SQLGenPostProcessor(engine=self._engine), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="Follow-Up SQL Generation") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py b/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py index 42b28c5b8f..baf2cc6e39 100644 --- a/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py +++ b/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py @@ -115,24 +115,30 @@ class FollowUpSQLGenerationReasoning(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **kwargs, ): + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + self._user_queues = {} - self._components = { - "generator": llm_provider.get_generator( + self._llm_provider = llm_provider + self._description = description + self._components = self._update_components() + + def _update_components(self): + return { + "generator": self._llm_provider.get_generator( system_prompt=sql_generation_reasoning_system_prompt, streaming_callback=self._streaming_callback, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=sql_generation_reasoning_user_prompt_template ), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - def _streaming_callback(self, chunk, query_id): if query_id not in self._user_queues: self._user_queues[query_id] 
= asyncio.Queue() diff --git a/wren-ai-service/src/pipelines/generation/intent_classification.py b/wren-ai-service/src/pipelines/generation/intent_classification.py index 4d6cd313cd..176b1e71c9 100644 --- a/wren-ai-service/src/pipelines/generation/intent_classification.py +++ b/wren-ai-service/src/pipelines/generation/intent_classification.py @@ -344,36 +344,48 @@ def __init__( wren_ai_docs: list[dict], table_retrieval_size: Optional[int] = 50, table_column_retrieval_size: Optional[int] = 100, + description: str = "", **kwargs, ): - self._components = { - "embedder": embedder_provider.get_text_embedder(), - "table_retriever": document_store_provider.get_retriever( - document_store_provider.get_store(dataset_name="table_descriptions"), - top_k=table_retrieval_size, + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + self._llm_provider = llm_provider + self._table_retrieval_size = table_retrieval_size + self._table_column_retrieval_size = table_column_retrieval_size + self._document_store_provider = document_store_provider + self._embedder_provider = embedder_provider + self._description = description + self._components = self._update_components() + + self._configs = { + "wren_ai_docs": wren_ai_docs, + } + + def _update_components(self): + return { + "embedder": self._embedder_provider.get_text_embedder(), + "table_retriever": self._document_store_provider.get_retriever( + self._document_store_provider.get_store( + dataset_name="table_descriptions" + ), + top_k=self._table_retrieval_size, ), - "dbschema_retriever": document_store_provider.get_retriever( - document_store_provider.get_store(), - top_k=table_column_retrieval_size, + "dbschema_retriever": self._document_store_provider.get_retriever( + self._document_store_provider.get_store(), + top_k=self._table_column_retrieval_size, ), - "generator": llm_provider.get_generator( + "generator": self._llm_provider.get_generator( 
system_prompt=intent_classification_system_prompt, generation_kwargs=INTENT_CLASSIFICAION_MODEL_KWARGS, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=intent_classification_user_prompt_template ), } - self._configs = { - "wren_ai_docs": wren_ai_docs, - } - - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="Intent Classification") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/misleading_assistance.py b/wren-ai-service/src/pipelines/generation/misleading_assistance.py index a35738ecf5..ca59b7075d 100644 --- a/wren-ai-service/src/pipelines/generation/misleading_assistance.py +++ b/wren-ai-service/src/pipelines/generation/misleading_assistance.py @@ -93,24 +93,30 @@ class MisleadingAssistance(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **kwargs, ): + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + self._user_queues = {} - self._components = { - "generator": llm_provider.get_generator( + self._llm_provider = llm_provider + self._description = description + self._components = self._update_components() + + def _update_components(self): + return { + "generator": self._llm_provider.get_generator( system_prompt=misleading_assistance_system_prompt, streaming_callback=self._streaming_callback, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=misleading_assistance_user_prompt_template ), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - def _streaming_callback(self, chunk, query_id): if query_id not in self._user_queues: self._user_queues[ diff --git a/wren-ai-service/src/pipelines/generation/question_recommendation.py 
b/wren-ai-service/src/pipelines/generation/question_recommendation.py index a6e7c17b02..1ec08039c5 100644 --- a/wren-ai-service/src/pipelines/generation/question_recommendation.py +++ b/wren-ai-service/src/pipelines/generation/question_recommendation.py @@ -235,23 +235,29 @@ class QuestionRecommendation(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **_, ): - self._components = { + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + self._llm_provider = llm_provider + self._description = description + self._components = self._update_components() + + self._final = "normalized" + + def _update_components(self): + return { "prompt_builder": PromptBuilder(template=user_prompt_template), - "generator": llm_provider.get_generator( + "generator": self._llm_provider.get_generator( system_prompt=system_prompt, generation_kwargs=QUESTION_RECOMMENDATION_MODEL_KWARGS, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.model, } - self._final = "normalized" - - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="Question Recommendation") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/relationship_recommendation.py b/wren-ai-service/src/pipelines/generation/relationship_recommendation.py index e0d0bed675..e80dbc94c8 100644 --- a/wren-ai-service/src/pipelines/generation/relationship_recommendation.py +++ b/wren-ai-service/src/pipelines/generation/relationship_recommendation.py @@ -202,23 +202,29 @@ class RelationshipRecommendation(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **_, ): - self._components = { + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + self._llm_provider = llm_provider + self._description = description + self._components = self._update_components() + + 
self._final = "validated" + + def _update_components(self): + return { "prompt_builder": PromptBuilder(template=user_prompt_template), - "generator": llm_provider.get_generator( + "generator": self._llm_provider.get_generator( system_prompt=system_prompt, generation_kwargs=RELATIONSHIP_RECOMMENDATION_MODEL_KWARGS, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.model, } - self._final = "validated" - - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="Relationship Recommendation") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/semantics_description.py b/wren-ai-service/src/pipelines/generation/semantics_description.py index acc5fc8594..7d072c5e9e 100644 --- a/wren-ai-service/src/pipelines/generation/semantics_description.py +++ b/wren-ai-service/src/pipelines/generation/semantics_description.py @@ -217,20 +217,30 @@ class SemanticResult(BaseModel): class SemanticsDescription(BasicPipeline): - def __init__(self, llm_provider: LLMProvider, **_): - self._components = { + def __init__( + self, + llm_provider: LLMProvider, + description: str = "", + **_, + ): + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + self._llm_provider = llm_provider + self._description = description + self._components = self._update_components() + self._final = "output" + + def _update_components(self): + return { "prompt_builder": PromptBuilder(template=user_prompt_template), - "generator": llm_provider.get_generator( + "generator": self._llm_provider.get_generator( system_prompt=system_prompt, generation_kwargs=SEMANTICS_DESCRIPTION_MODEL_KWARGS, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.model, } - self._final = "output" - - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) @observe(name="Semantics Description 
Generation") async def run( diff --git a/wren-ai-service/src/pipelines/generation/sql_answer.py b/wren-ai-service/src/pipelines/generation/sql_answer.py index 81289081b5..58d833696e 100644 --- a/wren-ai-service/src/pipelines/generation/sql_answer.py +++ b/wren-ai-service/src/pipelines/generation/sql_answer.py @@ -93,24 +93,30 @@ class SQLAnswer(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **kwargs, ): + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + self._user_queues = {} - self._components = { + self._llm_provider = llm_provider + self._description = description + self._components = self._update_components() + + def _update_components(self): + return { "prompt_builder": PromptBuilder( template=sql_to_answer_user_prompt_template ), - "generator": llm_provider.get_generator( + "generator": self._llm_provider.get_generator( system_prompt=sql_to_answer_system_prompt, streaming_callback=self._streaming_callback, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.model, } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - def _streaming_callback(self, chunk, query_id): if query_id not in self._user_queues: self._user_queues[ diff --git a/wren-ai-service/src/pipelines/generation/sql_correction.py b/wren-ai-service/src/pipelines/generation/sql_correction.py index 7a9bcc092b..5affc3192d 100644 --- a/wren-ai-service/src/pipelines/generation/sql_correction.py +++ b/wren-ai-service/src/pipelines/generation/sql_correction.py @@ -131,28 +131,48 @@ def __init__( llm_provider: LLMProvider, document_store_provider: DocumentStoreProvider, engine: Engine, + description: str = "", **kwargs, ): - self._retriever = document_store_provider.get_retriever( - document_store_provider.get_store("project_meta") + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + 
self._llm_provider = llm_provider + self._engine = engine + self._description = description + self._components = self._update_components() + self._document_store_provider = document_store_provider + self._retriever = self._document_store_provider.get_retriever( + self._document_store_provider.get_store("project_meta") ) - self._components = { - "generator": llm_provider.get_generator( + def update_components( + self, + llm_provider: LLMProvider, + document_store_provider: DocumentStoreProvider, + **_, + ): + super().update_components( + llm_provider=llm_provider, document_store_provider=document_store_provider + ) + self._retriever = self._document_store_provider.get_retriever( + self._document_store_provider.get_store("project_meta") + ) + + def _update_components(self): + return { + "generator": self._llm_provider.get_generator( system_prompt=sql_correction_system_prompt, generation_kwargs=SQL_GENERATION_MODEL_KWARGS, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=sql_correction_user_prompt_template ), - "post_processor": SQLGenPostProcessor(engine=engine), + "post_processor": SQLGenPostProcessor(engine=self._engine), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="SQL Correction") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/sql_diagnosis.py b/wren-ai-service/src/pipelines/generation/sql_diagnosis.py index 3f22b9d512..4d79911ee0 100644 --- a/wren-ai-service/src/pipelines/generation/sql_diagnosis.py +++ b/wren-ai-service/src/pipelines/generation/sql_diagnosis.py @@ -117,23 +117,29 @@ class SQLDiagnosis(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **kwargs, ): - self._components = { - "generator": llm_provider.get_generator( + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + 
self._llm_provider = llm_provider + self._description = description + self._components = self._update_components() + + def _update_components(self): + return { + "generator": self._llm_provider.get_generator( system_prompt=sql_diagnosis_system_prompt, generation_kwargs=SQL_DIAGNOSIS_MODEL_KWARGS, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=sql_diagnosis_user_prompt_template ), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="SQL Diagnosis") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/sql_generation.py b/wren-ai-service/src/pipelines/generation/sql_generation.py index 27d3bc5eab..9073388866 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation.py @@ -151,28 +151,48 @@ def __init__( llm_provider: LLMProvider, document_store_provider: DocumentStoreProvider, engine: Engine, + description: str = "", **kwargs, ): - self._retriever = document_store_provider.get_retriever( - document_store_provider.get_store("project_meta") + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + self._document_store_provider = document_store_provider + self._retriever = self._document_store_provider.get_retriever( + self._document_store_provider.get_store("project_meta") ) + self._llm_provider = llm_provider + self._engine = engine + self._description = description + self._components = self._update_components() - self._components = { - "generator": llm_provider.get_generator( + def update_components( + self, + llm_provider: LLMProvider, + document_store_provider: DocumentStoreProvider, + **_, + ): + super().update_components( + llm_provider=llm_provider, document_store_provider=document_store_provider + ) + self._retriever = self._document_store_provider.get_retriever( + 
self._document_store_provider.get_store("project_meta") + ) + + def _update_components(self): + return { + "generator": self._llm_provider.get_generator( system_prompt=sql_generation_system_prompt, generation_kwargs=SQL_GENERATION_MODEL_KWARGS, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=sql_generation_user_prompt_template ), - "post_processor": SQLGenPostProcessor(engine=engine), + "post_processor": SQLGenPostProcessor(engine=self._engine), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="SQL Generation") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py b/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py index 00b731cb2c..ed4c57aeb5 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py @@ -100,24 +100,30 @@ class SQLGenerationReasoning(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **kwargs, ): + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + self._user_queues = {} - self._components = { - "generator": llm_provider.get_generator( + self._llm_provider = llm_provider + self._description = description + self._components = self._update_components() + + def _update_components(self): + return { + "generator": self._llm_provider.get_generator( system_prompt=sql_generation_reasoning_system_prompt, streaming_callback=self._streaming_callback, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=sql_generation_reasoning_user_prompt_template ), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - def 
_streaming_callback(self, chunk, query_id): if query_id not in self._user_queues: self._user_queues[query_id] = asyncio.Queue() diff --git a/wren-ai-service/src/pipelines/generation/sql_question.py b/wren-ai-service/src/pipelines/generation/sql_question.py index 81f3d3e62b..1b4f59d901 100644 --- a/wren-ai-service/src/pipelines/generation/sql_question.py +++ b/wren-ai-service/src/pipelines/generation/sql_question.py @@ -96,21 +96,27 @@ class SQLQuestion(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **kwargs, ): - self._components = { - "generator": llm_provider.get_generator( + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + self._llm_provider = llm_provider + self._description = description + self._components = self._update_components() + + def _update_components(self): + return { + "generator": self._llm_provider.get_generator( system_prompt=sql_question_system_prompt, generation_kwargs=SQL_QUESTION_MODEL_KWARGS, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder(template=sql_question_user_prompt_template), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="Sql Question Generation") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/sql_regeneration.py b/wren-ai-service/src/pipelines/generation/sql_regeneration.py index 0dbf2fd808..7de247fcd6 100644 --- a/wren-ai-service/src/pipelines/generation/sql_regeneration.py +++ b/wren-ai-service/src/pipelines/generation/sql_regeneration.py @@ -155,24 +155,31 @@ def __init__( self, llm_provider: LLMProvider, engine: Engine, + description: str = "", **kwargs, ): - self._components = { - "generator": llm_provider.get_generator( + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + self._llm_provider = llm_provider + 
self._engine = engine + self._description = description + self._components = self._update_components() + + def _update_components(self): + return { + "generator": self._llm_provider.get_generator( system_prompt=sql_regeneration_system_prompt, generation_kwargs=SQL_GENERATION_MODEL_KWARGS, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=sql_regeneration_user_prompt_template ), - "post_processor": SQLGenPostProcessor(engine=engine), + "post_processor": SQLGenPostProcessor(engine=self._engine), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="SQL Regeneration") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/sql_tables_extraction.py b/wren-ai-service/src/pipelines/generation/sql_tables_extraction.py index 7060e750b3..e8e6b07eb4 100644 --- a/wren-ai-service/src/pipelines/generation/sql_tables_extraction.py +++ b/wren-ai-service/src/pipelines/generation/sql_tables_extraction.py @@ -100,23 +100,29 @@ class SQLTablesExtraction(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **kwargs, ): - self._components = { - "generator": llm_provider.get_generator( + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + self._llm_provider = llm_provider + self._description = description + self._components = self._update_components() + + def _update_components(self): + return { + "generator": self._llm_provider.get_generator( system_prompt=sql_tables_extraction_system_prompt, generation_kwargs=SQL_TABLES_EXTRACTION_MODEL_KWARGS, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=sql_tables_extraction_user_prompt_template ), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - 
@observe(name="Sql Tables Extraction") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/user_guide_assistance.py b/wren-ai-service/src/pipelines/generation/user_guide_assistance.py index be437f883b..fd72ee3c85 100644 --- a/wren-ai-service/src/pipelines/generation/user_guide_assistance.py +++ b/wren-ai-service/src/pipelines/generation/user_guide_assistance.py @@ -83,26 +83,33 @@ def __init__( self, llm_provider: LLMProvider, wren_ai_docs: list[dict], + description: str = "", **kwargs, ): + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + self._user_queues = {} - self._components = { - "generator": llm_provider.get_generator( + self._llm_provider = llm_provider + self._description = description + self._components = self._update_components() + + self._configs = { + "wren_ai_docs": wren_ai_docs, + } + + def _update_components(self): + return { + "generator": self._llm_provider.get_generator( system_prompt=user_guide_assistance_system_prompt, streaming_callback=self._streaming_callback, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=user_guide_assistance_user_prompt_template ), } - self._configs = { - "wren_ai_docs": wren_ai_docs, - } - - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) def _streaming_callback(self, chunk, query_id): if query_id not in self._user_queues: diff --git a/wren-ai-service/src/pipelines/indexing/db_schema.py b/wren-ai-service/src/pipelines/indexing/db_schema.py index 394d087b46..f4a9218805 100644 --- a/wren-ai-service/src/pipelines/indexing/db_schema.py +++ b/wren-ai-service/src/pipelines/indexing/db_schema.py @@ -342,29 +342,51 @@ def __init__( embedder_provider: EmbedderProvider, document_store_provider: DocumentStoreProvider, column_batch_size: int = 50, + description: str = "", **kwargs, ) -> None: - dbschema_store = 
document_store_provider.get_store() + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + self._embedder_provider = embedder_provider + self._document_store_provider = document_store_provider + self._dbschema_store = self._document_store_provider.get_store() + self._description = description + self._components = self._update_components() - self._components = { - "cleaner": DocumentCleaner([dbschema_store]), - "validator": MDLValidator(), - "embedder": embedder_provider.get_document_embedder(), - "chunker": DDLChunker(), - "writer": AsyncDocumentWriter( - document_store=dbschema_store, - policy=DuplicatePolicy.OVERWRITE, - ), - } self._configs = { "column_batch_size": column_batch_size, } self._final = "write" helper.load_helpers() - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + + def update_components( + self, + embedder_provider: EmbedderProvider, + document_store_provider: DocumentStoreProvider, + **_, + ): + super().update_components( + embedder_provider=embedder_provider, + document_store_provider=document_store_provider, + update_components=False, ) + self._dbschema_store = self._document_store_provider.get_store() + self._components = self._update_components() + + def _update_components(self): + return { + "cleaner": DocumentCleaner([self._dbschema_store]), + "validator": MDLValidator(), + "embedder": self._embedder_provider.get_document_embedder(), + "chunker": DDLChunker(), + "writer": AsyncDocumentWriter( + document_store=self._dbschema_store, + policy=DuplicatePolicy.OVERWRITE, + ), + } @observe(name="DB Schema Indexing") async def run( diff --git a/wren-ai-service/src/pipelines/indexing/historical_question.py b/wren-ai-service/src/pipelines/indexing/historical_question.py index ff30d91f0f..a5888a3c76 100644 --- a/wren-ai-service/src/pipelines/indexing/historical_question.py +++ b/wren-ai-service/src/pipelines/indexing/historical_question.py @@ -137,27 
+137,52 @@ def __init__( self, embedder_provider: EmbedderProvider, document_store_provider: DocumentStoreProvider, + description: str = "", **kwargs, ) -> None: + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + # keep the store name as it is for now, might change in the future - store = document_store_provider.get_store(dataset_name="view_questions") + self._embedder_provider = embedder_provider + self._document_store_provider = document_store_provider + self._store = self._document_store_provider.get_store( + dataset_name="view_questions" + ) + self._description = description + self._components = self._update_components() - self._components = { - "cleaner": DocumentCleaner([store]), + self._configs = {} + self._final = "write" + + def update_components( + self, + embedder_provider: EmbedderProvider, + document_store_provider: DocumentStoreProvider, + **_, + ): + super().update_components( + embedder_provider=embedder_provider, + document_store_provider=document_store_provider, + update_components=False, + ) + self._store = self._document_store_provider.get_store( + dataset_name="view_questions" + ) + self._components = self._update_components() + + def _update_components(self): + return { + "cleaner": DocumentCleaner([self._store]), "validator": MDLValidator(), - "embedder": embedder_provider.get_document_embedder(), + "embedder": self._embedder_provider.get_document_embedder(), "chunker": ViewChunker(), "writer": AsyncDocumentWriter( - document_store=store, + document_store=self._store, policy=DuplicatePolicy.OVERWRITE, ), } - self._configs = {} - self._final = "write" - - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) @observe(name="Historical Question Indexing") async def run( diff --git a/wren-ai-service/src/pipelines/indexing/instructions.py b/wren-ai-service/src/pipelines/indexing/instructions.py index b23f3cf2ab..337efce174 100644 --- 
a/wren-ai-service/src/pipelines/indexing/instructions.py +++ b/wren-ai-service/src/pipelines/indexing/instructions.py @@ -129,24 +129,48 @@ def __init__( self, embedder_provider: EmbedderProvider, document_store_provider: DocumentStoreProvider, + description: str = "", **kwargs, ) -> None: - store = document_store_provider.get_store(dataset_name="instructions") + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + self._embedder_provider = embedder_provider + self._document_store_provider = document_store_provider + self._store = self._document_store_provider.get_store( + dataset_name="instructions" + ) + self._description = description + self._components = self._update_components() + + def update_components( + self, + embedder_provider: EmbedderProvider, + document_store_provider: DocumentStoreProvider, + **_, + ): + super().update_components( + embedder_provider=embedder_provider, + document_store_provider=document_store_provider, + update_components=False, + ) + self._store = self._document_store_provider.get_store( + dataset_name="instructions" + ) + self._components = self._update_components() - self._components = { - "cleaner": InstructionsCleaner(store), - "embedder": embedder_provider.get_document_embedder(), + def _update_components(self): + return { + "cleaner": InstructionsCleaner(self._store), + "embedder": self._embedder_provider.get_document_embedder(), "document_converter": InstructionsConverter(), "writer": AsyncDocumentWriter( - document_store=store, + document_store=self._store, policy=DuplicatePolicy.OVERWRITE, ), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="Instructions Indexing") async def run( self, diff --git a/wren-ai-service/src/pipelines/indexing/project_meta.py b/wren-ai-service/src/pipelines/indexing/project_meta.py index 2ee4a08074..0ec516df15 100644 --- a/wren-ai-service/src/pipelines/indexing/project_meta.py +++ 
b/wren-ai-service/src/pipelines/indexing/project_meta.py @@ -67,23 +67,40 @@ class ProjectMeta(BasicPipeline): def __init__( self, document_store_provider: DocumentStoreProvider, + description: str = "", **kwargs, ) -> None: - store = document_store_provider.get_store(dataset_name="project_meta") + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + self._document_store_provider = document_store_provider + self._store = self._document_store_provider.get_store( + dataset_name="project_meta" + ) + self._description = description + + self._components = self._update_components() + self._final = "write" - self._components = { + def update_components(self, document_store_provider: DocumentStoreProvider, **_): + super().update_components( + document_store_provider=document_store_provider, update_components=False + ) + self._store = self._document_store_provider.get_store( + dataset_name="project_meta" + ) + self._components = self._update_components() + + def _update_components(self): + return { "validator": MDLValidator(), - "cleaner": DocumentCleaner([store]), + "cleaner": DocumentCleaner([self._store]), "writer": AsyncDocumentWriter( - document_store=store, + document_store=self._store, policy=DuplicatePolicy.OVERWRITE, ), } - self._final = "write" - - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) @observe(name="Project Meta Indexing") async def run( diff --git a/wren-ai-service/src/pipelines/indexing/sql_pairs.py b/wren-ai-service/src/pipelines/indexing/sql_pairs.py index a92fb36df1..a2c6d3e7ca 100644 --- a/wren-ai-service/src/pipelines/indexing/sql_pairs.py +++ b/wren-ai-service/src/pipelines/indexing/sql_pairs.py @@ -169,26 +169,47 @@ def __init__( embedder_provider: EmbedderProvider, document_store_provider: DocumentStoreProvider, sql_pairs_path: str = "sql_pairs.json", + description: str = "", **kwargs, ) -> None: - store = 
document_store_provider.get_store(dataset_name="sql_pairs") + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + self._embedder_provider = embedder_provider + self._document_store_provider = document_store_provider + self._store = self._document_store_provider.get_store(dataset_name="sql_pairs") + self._description = description + + self._components = self._update_components() - self._components = { - "cleaner": SqlPairsCleaner(store), - "embedder": embedder_provider.get_document_embedder(), + self._external_pairs = _load_sql_pairs(sql_pairs_path) + + def update_components( + self, + embedder_provider: EmbedderProvider, + document_store_provider: DocumentStoreProvider, + **_, + ): + super().update_components( + embedder_provider=embedder_provider, + document_store_provider=document_store_provider, + update_components=False, + ) + self._store = self._document_store_provider.get_store(dataset_name="sql_pairs") + self._components = self._update_components() + + def _update_components(self): + return { + "cleaner": SqlPairsCleaner(self._store), + "embedder": self._embedder_provider.get_document_embedder(), "document_converter": SqlPairsConverter(), "writer": AsyncDocumentWriter( - document_store=store, + document_store=self._store, policy=DuplicatePolicy.OVERWRITE, ), } - self._external_pairs = _load_sql_pairs(sql_pairs_path) - - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="SQL Pairs Indexing") async def run( self, diff --git a/wren-ai-service/src/pipelines/indexing/table_description.py b/wren-ai-service/src/pipelines/indexing/table_description.py index 6da100868f..69c542d517 100644 --- a/wren-ai-service/src/pipelines/indexing/table_description.py +++ b/wren-ai-service/src/pipelines/indexing/table_description.py @@ -120,28 +120,51 @@ def __init__( self, embedder_provider: EmbedderProvider, document_store_provider: DocumentStoreProvider, + description: 
str = "", **kwargs, ) -> None: - table_description_store = document_store_provider.get_store( + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + self._embedder_provider = embedder_provider + self._document_store_provider = document_store_provider + self._table_description_store = self._document_store_provider.get_store( dataset_name="table_descriptions" ) + self._description = description - self._components = { - "cleaner": DocumentCleaner([table_description_store]), + self._components = self._update_components() + self._configs = {} + self._final = "write" + + def update_components( + self, + embedder_provider: EmbedderProvider, + document_store_provider: DocumentStoreProvider, + **_, + ): + super().update_components( + embedder_provider=embedder_provider, + document_store_provider=document_store_provider, + update_components=False, + ) + self._table_description_store = self._document_store_provider.get_store( + dataset_name="table_descriptions" + ) + self._components = self._update_components() + + def _update_components(self): + return { + "cleaner": DocumentCleaner([self._table_description_store]), "validator": MDLValidator(), - "embedder": embedder_provider.get_document_embedder(), + "embedder": self._embedder_provider.get_document_embedder(), "chunker": TableDescriptionChunker(), "writer": AsyncDocumentWriter( - document_store=table_description_store, + document_store=self._table_description_store, policy=DuplicatePolicy.OVERWRITE, ), } - self._configs = {} - self._final = "write" - - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) @observe(name="Table Description Indexing") async def run( diff --git a/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py b/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py index 6c8dd7bbe3..7c881490c2 100644 --- a/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py +++ 
b/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py @@ -457,43 +457,80 @@ def __init__( document_store_provider: DocumentStoreProvider, table_retrieval_size: int = 10, table_column_retrieval_size: int = 100, + description: str = "", **kwargs, ): - self._components = { - "embedder": embedder_provider.get_text_embedder(), - "table_retriever": document_store_provider.get_retriever( - document_store_provider.get_store(dataset_name="table_descriptions"), - top_k=table_retrieval_size, + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + self._llm_provider = llm_provider + self._table_retrieval_size = table_retrieval_size + self._table_column_retrieval_size = table_column_retrieval_size + self._document_store_provider = document_store_provider + self._embedder_provider = embedder_provider + self._description = description + self._components = self._update_components() + self._configs = self._update_configs() + + def _update_configs(self): + _model = (self._llm_provider.model,) + if "gpt-4o" in _model or "gpt-4o-mini" in _model: + _encoding = tiktoken.get_encoding("o200k_base") + else: + _encoding = tiktoken.get_encoding("cl100k_base") + + return { + "encoding": _encoding, + "context_window_size": self._llm_provider.context_window_size, + } + + def _update_components(self): + return { + "embedder": self._embedder_provider.get_text_embedder(), + "table_retriever": self._document_store_provider.get_retriever( + self._document_store_provider.get_store( + dataset_name="table_descriptions" + ), + top_k=self._table_retrieval_size, ), - "dbschema_retriever": document_store_provider.get_retriever( - document_store_provider.get_store(), - top_k=table_column_retrieval_size, + "dbschema_retriever": self._document_store_provider.get_retriever( + self._document_store_provider.get_store(), + top_k=self._table_column_retrieval_size, ), - "table_columns_selection_generator": llm_provider.get_generator( + 
"table_columns_selection_generator": self._llm_provider.get_generator( system_prompt=table_columns_selection_system_prompt, generation_kwargs=RETRIEVAL_MODEL_KWARGS, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=table_columns_selection_user_prompt_template ), } - # for the first time, we need to load the encodings - _model = llm_provider.get_model() - if "gpt-4o" in _model or "gpt-4o-mini" in _model: - _encoding = tiktoken.get_encoding("o200k_base") - else: - _encoding = tiktoken.get_encoding("cl100k_base") - - self._configs = { - "encoding": _encoding, - "context_window_size": llm_provider.get_context_window_size(), - } - - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + def update_components( + self, + llm_provider: LLMProvider, + embedder_provider: EmbedderProvider, + document_store_provider: DocumentStoreProvider, + **_, + ): + super().update_components( + llm_provider=llm_provider, + embedder_provider=embedder_provider, + document_store_provider=document_store_provider, + update_components=False, + ) + self._table_retriever = self._document_store_provider.get_retriever( + self._document_store_provider.get_store(dataset_name="table_descriptions"), + top_k=self._table_retrieval_size, + ) + self._dbschema_retriever = self._document_store_provider.get_retriever( + self._document_store_provider.get_store(), + top_k=self._table_column_retrieval_size, ) + self._components = self._update_components() + self._configs = self._update_configs() @observe(name="Ask Retrieval") async def run( diff --git a/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py b/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py index 0dcbc839ab..1ba78ce5f5 100644 --- a/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py +++ 
b/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py @@ -123,29 +123,51 @@ def __init__( embedder_provider: EmbedderProvider, document_store_provider: DocumentStoreProvider, historical_question_retrieval_similarity_threshold: float = 0.9, + description: str = "", **kwargs, ) -> None: - view_questions_store = document_store_provider.get_store( - dataset_name="view_questions" + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) - self._components = { - "view_questions_store": view_questions_store, - "embedder": embedder_provider.get_text_embedder(), - "view_questions_retriever": document_store_provider.get_retriever( - document_store=view_questions_store, - ), - "score_filter": ScoreFilter(), - # TODO: add a llm filter to filter out low scoring document, in case ScoreFilter is not accurate enough - "output_formatter": OutputFormatter(), - } + self._view_questions_store = document_store_provider.get_store( + dataset_name="view_questions" + ) + self._embedder_provider = embedder_provider + self._document_store_provider = document_store_provider + self._description = description + self._components = self._update_components() self._configs = { "historical_question_retrieval_similarity_threshold": historical_question_retrieval_similarity_threshold, } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + def update_components( + self, + embedder_provider: EmbedderProvider, + document_store_provider: DocumentStoreProvider, + **_, + ): + super().update_components( + embedder_provider=embedder_provider, + document_store_provider=document_store_provider, + update_components=False, + ) + self._view_questions_store = self._document_store_provider.get_store( + dataset_name="view_questions" ) + self._components = self._update_components() + + def _update_components(self): + return { + "view_questions_store": self._view_questions_store, + "embedder": 
self._embedder_provider.get_text_embedder(), + "view_questions_retriever": self._document_store_provider.get_retriever( + document_store=self._view_questions_store, + ), + "score_filter": ScoreFilter(), + # TODO: add a llm filter to filter out low scoring document, in case ScoreFilter is not accurate enough + "output_formatter": OutputFormatter(), + } @observe(name="Historical Question") async def run(self, query: str, project_id: Optional[str] = None): diff --git a/wren-ai-service/src/pipelines/retrieval/instructions.py b/wren-ai-service/src/pipelines/retrieval/instructions.py index 86c17e93de..8b0cca7a23 100644 --- a/wren-ai-service/src/pipelines/retrieval/instructions.py +++ b/wren-ai-service/src/pipelines/retrieval/instructions.py @@ -189,27 +189,51 @@ def __init__( document_store_provider: DocumentStoreProvider, similarity_threshold: float = 0.7, top_k: int = 10, + description: str = "", **kwargs, ) -> None: - store = document_store_provider.get_store(dataset_name="instructions") - self._components = { - "store": store, - "embedder": embedder_provider.get_text_embedder(), - "retriever": document_store_provider.get_retriever( - document_store=store, - ), - "scope_filter": ScopeFilter(), - "score_filter": ScoreFilter(), - "output_formatter": OutputFormatter(), - } + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + self._embedder_provider = embedder_provider + self._document_store_provider = document_store_provider + self._store = document_store_provider.get_store(dataset_name="instructions") + self._description = description + + self._components = self._update_components() self._configs = { "similarity_threshold": similarity_threshold, "top_k": top_k, } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + def update_components( + self, + embedder_provider: EmbedderProvider, + document_store_provider: DocumentStoreProvider, + **_, + ): + super().update_components( + 
embedder_provider=embedder_provider, + document_store_provider=document_store_provider, + update_components=False, + ) + self._store = self._document_store_provider.get_store( + dataset_name="instructions" ) + self._components = self._update_components() + + def _update_components(self): + return { + "store": self._store, + "embedder": self._embedder_provider.get_text_embedder(), + "retriever": self._document_store_provider.get_retriever( + document_store=self._store, + ), + "scope_filter": ScopeFilter(), + "score_filter": ScoreFilter(), + "output_formatter": OutputFormatter(), + } @observe(name="Instructions Retrieval") async def run( diff --git a/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py b/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py index e6dadd32d0..51e1d56c73 100644 --- a/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py +++ b/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py @@ -80,20 +80,30 @@ class PreprocessSqlData(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **kwargs, ): - _model = llm_provider.get_model() + super().__init__(Driver({}, sys.modules[__name__], adapter=base.DictResult())) + + self._llm_provider = llm_provider + self._description = description + self._configs = self._update_configs() + + def _update_configs(self): + _model = (self._llm_provider.model,) if _model == "gpt-4o-mini" or _model == "gpt-4o": _encoding = tiktoken.get_encoding("o200k_base") else: _encoding = tiktoken.get_encoding("cl100k_base") - self._configs = { + return { "encoding": _encoding, - "context_window_size": llm_provider.get_context_window_size(), + "context_window_size": self._llm_provider.context_window_size, } - super().__init__(Driver({}, sys.modules[__name__], adapter=base.DictResult())) + def update_components(self, llm_provider: LLMProvider, **_): + self._llm_provider = llm_provider + self._configs = self._update_configs() @observe(name="Preprocess SQL 
Data") def run( diff --git a/wren-ai-service/src/pipelines/retrieval/sql_executor.py b/wren-ai-service/src/pipelines/retrieval/sql_executor.py index b41151469f..dae00cedbd 100644 --- a/wren-ai-service/src/pipelines/retrieval/sql_executor.py +++ b/wren-ai-service/src/pipelines/retrieval/sql_executor.py @@ -64,16 +64,18 @@ class SQLExecutor(BasicPipeline): def __init__( self, engine: Engine, + description: str = "", **kwargs, ): - self._components = { - "data_fetcher": DataFetcher(engine=engine), - } - super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) + self._description = description + self._components = { + "data_fetcher": DataFetcher(engine=engine), + } + @observe(name="SQL Execution") async def run( self, sql: str, project_id: str | None = None, limit: int = 500 diff --git a/wren-ai-service/src/pipelines/retrieval/sql_functions.py b/wren-ai-service/src/pipelines/retrieval/sql_functions.py index 016aa9b1e6..43abced26a 100644 --- a/wren-ai-service/src/pipelines/retrieval/sql_functions.py +++ b/wren-ai-service/src/pipelines/retrieval/sql_functions.py @@ -85,10 +85,17 @@ def __init__( engine: Engine, document_store_provider: DocumentStoreProvider, ttl: int = 60 * 60 * 24, + description: str = "", **kwargs, ) -> None: - self._retriever = document_store_provider.get_retriever( - document_store_provider.get_store("project_meta") + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + self._description = description + self._document_store_provider = document_store_provider + self._retriever = self._document_store_provider.get_retriever( + self._document_store_provider.get_store("project_meta") ) self._cache = TTLCache(maxsize=100, ttl=ttl) self._components = { @@ -96,8 +103,10 @@ def __init__( "ttl_cache": self._cache, } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + def update_components(self, document_store_provider: 
DocumentStoreProvider, **_): + self._document_store_provider = document_store_provider + self._retriever = self._document_store_provider.get_retriever( + self._document_store_provider.get_store("project_meta") ) @observe(name="SQL Functions Retrieval") diff --git a/wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py b/wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py index 3fe44f32eb..7bcf71061e 100644 --- a/wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py +++ b/wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py @@ -120,27 +120,49 @@ def __init__( document_store_provider: DocumentStoreProvider, sql_pairs_similarity_threshold: float = 0.7, sql_pairs_retrieval_max_size: int = 10, + description: str = "", **kwargs, ) -> None: - store = document_store_provider.get_store(dataset_name="sql_pairs") - self._components = { - "store": store, - "embedder": embedder_provider.get_text_embedder(), - "retriever": document_store_provider.get_retriever( - document_store=store, - ), - "score_filter": ScoreFilter(), - # TODO: add a llm filter to filter out low scoring document, in case ScoreFilter is not accurate enough - "output_formatter": OutputFormatter(), - } + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + self._embedder_provider = embedder_provider + self._document_store_provider = document_store_provider + self._store = self._document_store_provider.get_store(dataset_name="sql_pairs") + self._description = description + self._components = self._update_components() + self._configs = { "sql_pairs_similarity_threshold": sql_pairs_similarity_threshold, "sql_pairs_retrieval_max_size": sql_pairs_retrieval_max_size, } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + def update_components( + self, + embedder_provider: EmbedderProvider, + document_store_provider: DocumentStoreProvider, + **_, + ): + super().update_components( + 
embedder_provider=embedder_provider, + document_store_provider=document_store_provider, + update_components=False, ) + self._store = self._document_store_provider.get_store(dataset_name="sql_pairs") + self._components = self._update_components() + + def _update_components(self): + return { + "store": self._store, + "embedder": self._embedder_provider.get_text_embedder(), + "retriever": self._document_store_provider.get_retriever( + document_store=self._store, + ), + "score_filter": ScoreFilter(), + # TODO: add a llm filter to filter out low scoring document, in case ScoreFilter is not accurate enough + "output_formatter": OutputFormatter(), + } @observe(name="SqlPairs Retrieval") async def run(self, query: str, project_id: Optional[str] = None): diff --git a/wren-ai-service/src/providers/__init__.py b/wren-ai-service/src/providers/__init__.py index fb491c8277..78ec07a2ef 100644 --- a/wren-ai-service/src/providers/__init__.py +++ b/wren-ai-service/src/providers/__init__.py @@ -107,9 +107,10 @@ def build_fallback_params(all_models: dict) -> dict: ] returned[model_name] = { - "provider": entry["provider"], - "model": model["model"], - "kwargs": model["kwargs"], + "provider": entry.get("provider"), + "model": model.get("model"), + "alias": model.get("alias", model.get("model")), + "kwargs": model.get("kwargs"), "context_window_size": model.get("context_window_size", 100000), "fallback_model_list": fallback_model_list, **model_additional_params, @@ -163,8 +164,9 @@ def embedder_processor(entry: dict) -> dict: k: v for k, v in model.items() if k not in ["model", "kwargs", "alias"] } returned[identifier] = { - "provider": entry["provider"], - "model": model["model"], + "provider": entry.get("provider"), + "model": model.get("model"), + "alias": model.get("alias", model.get("model")), **model_additional_params, **others, } @@ -277,6 +279,7 @@ def pipeline_processor(entry: dict) -> dict: "embedder": "openai_embedder.text-embedding-3-large", "document_store": "qdrant", 
"engine": "wren_ui", + "description": "Indexing pipeline", } } @@ -292,6 +295,7 @@ def pipeline_processor(entry: dict) -> dict: "embedder": pipe.get("embedder"), "document_store": pipe.get("document_store"), "engine": pipe.get("engine"), + "description": pipe.get("description"), } for pipe in entry["pipes"] } @@ -381,6 +385,7 @@ def get(type: str, components: dict, instantiated_providers: dict): def componentize(components: dict, instantiated_providers: dict): return PipelineComponent( + description=components.get("description", ""), embedder_provider=get("embedder", components, instantiated_providers), llm_provider=get("llm", components, instantiated_providers), document_store_provider=get( @@ -392,4 +397,4 @@ def componentize(components: dict, instantiated_providers: dict): return { pipe_name: componentize(components, instantiated_providers) for pipe_name, components in config.pipelines.items() - } + }, instantiated_providers diff --git a/wren-ai-service/src/providers/document_store/qdrant.py b/wren-ai-service/src/providers/document_store/qdrant.py index b90961c456..165ade3ba1 100644 --- a/wren-ai-service/src/providers/document_store/qdrant.py +++ b/wren-ai-service/src/providers/document_store/qdrant.py @@ -389,7 +389,8 @@ def __init__( self._api_key = Secret.from_token(api_key) if api_key else None self._timeout = timeout self._embedding_model_dim = embedding_model_dim - self._reset_document_store(recreate_index) + if recreate_index: + self._reset_document_store(recreate_index) def _reset_document_store(self, recreate_index: bool): self.get_store(recreate_index=recreate_index) diff --git a/wren-ai-service/src/providers/embedder/litellm.py b/wren-ai-service/src/providers/embedder/litellm.py index 4d051e3284..24b8a27331 100644 --- a/wren-ai-service/src/providers/embedder/litellm.py +++ b/wren-ai-service/src/providers/embedder/litellm.py @@ -36,15 +36,17 @@ def __init__( self, model: str, api_key: Optional[str] = None, - api_base_url: Optional[str] = None, + 
api_base: Optional[str] = None, + api_version: Optional[str] = None, timeout: Optional[float] = None, **kwargs, ): self._api_key = api_key self._model = model - self._api_base_url = api_base_url + self._api_base = api_base + self._api_version = api_version self._timeout = timeout - self._kwargs = kwargs + self._model_kwargs = kwargs @component.output_types(embedding=List[float], meta=Dict[str, Any]) @backoff.on_exception(backoff.expo, openai.APIError, max_time=60.0, max_tries=3) @@ -63,9 +65,10 @@ async def run(self, text: str): model=self._model, input=[text_to_embed], api_key=self._api_key, - api_base=self._api_base_url, + api_base=self._api_base, + api_version=self._api_version, timeout=self._timeout, - **self._kwargs, + **self._model_kwargs, ) meta = { @@ -83,16 +86,18 @@ def __init__( model: str, batch_size: int = 32, api_key: Optional[str] = None, - api_base_url: Optional[str] = None, + api_base: Optional[str] = None, + api_version: Optional[str] = None, timeout: Optional[float] = None, **kwargs, ): self._api_key = api_key self._model = model self._batch_size = batch_size - self._api_base_url = api_base_url + self._api_base = api_base + self._api_version = api_version self._timeout = timeout - self._kwargs = kwargs + self._model_kwargs = kwargs async def _embed_batch( self, texts_to_embed: List[str], batch_size: int @@ -102,9 +107,10 @@ async def embed_single_batch(batch: List[str]) -> Any: model=self._model, input=batch, api_key=self._api_key, - api_base=self._api_base_url, + api_base=self._api_base, + api_version=self._api_version, timeout=self._timeout, - **self._kwargs, + **self._model_kwargs, ) batches = [ @@ -171,31 +177,37 @@ def __init__( str ] = None, # e.g. EMBEDDER_OPENAI_API_KEY, EMBEDDER_ANTHROPIC_API_KEY, etc. 
api_base: Optional[str] = None, + api_version: Optional[str] = None, + alias: Optional[str] = None, timeout: float = 120.0, **kwargs, ): self._api_key = os.getenv(api_key_name) if api_key_name else None self._api_base = remove_trailing_slash(api_base) if api_base else None - self._embedding_model = model + self._api_version = api_version + self._model = model + self._alias = alias self._timeout = timeout if "provider" in kwargs: del kwargs["provider"] - self._kwargs = kwargs + self._model_kwargs = kwargs def get_text_embedder(self): return AsyncTextEmbedder( api_key=self._api_key, api_base_url=self._api_base, - model=self._embedding_model, + api_version=self._api_version, + model=self._model, timeout=self._timeout, - **self._kwargs, + **self._model_kwargs, ) def get_document_embedder(self): return AsyncDocumentEmbedder( api_key=self._api_key, - api_base_url=self._api_base, - model=self._embedding_model, + api_base=self._api_base, + api_version=self._api_version, + model=self._model, timeout=self._timeout, - **self._kwargs, + **self._model_kwargs, ) diff --git a/wren-ai-service/src/providers/llm/litellm.py b/wren-ai-service/src/providers/llm/litellm.py index 3748a945da..97752a4eda 100644 --- a/wren-ai-service/src/providers/llm/litellm.py +++ b/wren-ai-service/src/providers/llm/litellm.py @@ -29,6 +29,7 @@ def __init__( ] = None, # e.g. OPENAI_API_KEY, LLM_ANTHROPIC_API_KEY, etc. 
# Response schema for the GET /configs endpoint.
# NOTE: field order and defaults are kept exactly as serialized today;
# comments (not docstrings) are used so the pydantic JSON schema is unchanged.
class Configs(BaseModel):
    class Providers(BaseModel):
        # One entry per configured LLM (model name, alias, limits, kwargs).
        class LLMProvider(BaseModel):
            model: str
            alias: str
            context_window_size: int
            timeout: float = 600
            api_base: Optional[str] = None
            api_version: Optional[str] = None
            kwargs: Optional[dict[str, Any]] = None

        # Embedding model description; `dimension` comes from the document store.
        class EmbedderProvider(BaseModel):
            model: str
            alias: str
            dimension: int
            timeout: float = 600
            kwargs: Optional[dict[str, Any]] = None
            api_base: Optional[str] = None
            api_version: Optional[str] = None

        llm: list[LLMProvider]
        embedder: list[EmbedderProvider]

    # Per-pipeline view: which models it uses and whether DB row data
    # ends up inside the LLM prompt.
    class Pipeline(BaseModel):
        has_db_data_in_llm_prompt: bool
        llm: Optional[str] = None
        embedder: Optional[str] = None
        description: Optional[str] = None

    env_vars: dict[str, str]
    providers: Providers
    pipelines: dict[str, Pipeline]
# Pipelines whose LLM prompts contain data fetched from the user's database.
# Hoisted to a module-level frozenset so the set is built once, not per call
# (the original rebuilt `set([...])` on every invocation).
_PIPES_WITH_DB_DATA = frozenset(
    {
        "sql_answer",
        "chart_adjustment",
        "chart_generation",
    }
)


def has_db_data_in_llm_prompt(pipe_name: str) -> bool:
    """Return True if the named pipeline sends DB row data to the LLM prompt."""
    return pipe_name in _PIPES_WITH_DB_DATA