From e98b034529b7b39ac4b9d4de2ba32994a7590923 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Tue, 2 Sep 2025 17:29:36 +0800 Subject: [PATCH 1/7] prototype --- wren-ai-service/src/__main__.py | 31 ++++++++- wren-ai-service/src/core/pipeline.py | 9 +++ wren-ai-service/src/globals.py | 16 ++++- .../pipelines/generation/chart_adjustment.py | 28 ++++---- .../pipelines/generation/chart_generation.py | 28 ++++---- .../pipelines/generation/data_assistance.py | 18 +++-- .../generation/followup_sql_generation.py | 20 +++--- .../followup_sql_generation_reasoning.py | 18 +++-- .../generation/intent_classification.py | 46 ++++++++----- .../generation/misleading_assistance.py | 18 +++-- .../generation/question_recommendation.py | 20 +++--- .../generation/relationship_recommendation.py | 20 +++--- .../generation/semantics_description.py | 20 +++--- .../src/pipelines/generation/sql_answer.py | 18 +++-- .../pipelines/generation/sql_correction.py | 20 +++--- .../pipelines/generation/sql_generation.py | 20 +++--- .../generation/sql_generation_reasoning.py | 20 ++++-- .../src/pipelines/generation/sql_question.py | 18 +++-- .../pipelines/generation/sql_regeneration.py | 21 +++--- .../generation/sql_tables_extraction.py | 18 +++-- .../generation/user_guide_assistance.py | 25 ++++--- .../retrieval/db_schema_retrieval.py | 66 ++++++++++++------- .../retrieval/preprocess_sql_data.py | 16 +++-- wren-ai-service/src/providers/__init__.py | 2 +- wren-ai-service/src/utils.py | 6 ++ .../web/v1/services/semantics_preparation.py | 4 +- 26 files changed, 353 insertions(+), 193 deletions(-) diff --git a/wren-ai-service/src/__main__.py b/wren-ai-service/src/__main__.py index de141c3fde..01779387f5 100644 --- a/wren-ai-service/src/__main__.py +++ b/wren-ai-service/src/__main__.py @@ -1,7 +1,7 @@ from contextlib import asynccontextmanager import uvicorn -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import ORJSONResponse, RedirectResponse @@ -9,11 +9,13 @@ from src.config import settings from src.globals import ( + create_pipe_component_service_mapping, create_service_container, create_service_metadata, ) from src.providers import generate_components from src.utils import ( + SinglePipeComponentRequest, init_langfuse, setup_custom_logger, ) @@ -28,9 +30,13 @@ @asynccontextmanager async def lifespan(app: FastAPI): # startup events - pipe_components = generate_components(settings.components) + pipe_components, instantiated_providers = generate_components(settings.components) app.state.service_container = create_service_container(pipe_components, settings) + app.state.pipe_component_service_mapping = create_pipe_component_service_mapping( + app.state.service_container + ) app.state.service_metadata = create_service_metadata(pipe_components) + app.state.instantiated_providers = instantiated_providers init_langfuse(settings) yield @@ -86,6 +92,27 @@ def health(): return {"status": "ok"} +@app.get("/pipe_components") +def get_pipe_components(): + return sorted(list(app.state.pipe_component_service_mapping.keys())) + + +@app.post("/pipe_components") +def update_pipe_components(pipe_components_request: list[SinglePipeComponentRequest]): + try: + for payload in pipe_components_request: + for service in app.state.pipe_component_service_mapping[ + payload.pipeline_name + ]: + service._pipelines[payload.pipeline_name].update_llm_provider( + app.state.instantiated_providers["llm"][payload.llm_config] + ) + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Error updating pipe components: {e}" + ) + + if __name__ == "__main__": uvicorn.run( "src.__main__:app", diff --git a/wren-ai-service/src/core/pipeline.py b/wren-ai-service/src/core/pipeline.py index 02a477845c..897c8709b7 100644 --- a/wren-ai-service/src/core/pipeline.py +++ b/wren-ai-service/src/core/pipeline.py @@ -14,11 +14,20 @@ class BasicPipeline(metaclass=ABCMeta): def __init__(self, pipe: Pipeline | AsyncDriver | Driver): self._pipe = pipe + self._llm_provider = None + self._components = {} @abstractmethod def run(self, *args, **kwargs) -> Dict[str, Any]: ... + def _update_components(self) -> dict: + ... + + def update_llm_provider(self, llm_provider: LLMProvider): + self._llm_provider = llm_provider + self._components = self._update_components() + @dataclass class PipelineComponent(Mapping): diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py index d6c40e05b4..903247d30a 100644 --- a/wren-ai-service/src/globals.py +++ b/wren-ai-service/src/globals.py @@ -1,4 +1,5 @@ import logging +from collections import defaultdict from dataclasses import asdict, dataclass import toml @@ -104,8 +105,8 @@ def create_service_container( "table_description": indexing.TableDescription( **pipe_components["table_description_indexing"], ), - "sql_pairs": _sql_pair_indexing_pipeline, - "instructions": _instructions_indexing_pipeline, + "sql_pairs_indexing": _sql_pair_indexing_pipeline, + "instructions_indexing": _instructions_indexing_pipeline, "project_meta": indexing.ProjectMeta( **pipe_components["project_meta_indexing"], ), @@ -231,7 +232,7 @@ def create_service_container( ), sql_pairs_service=services.SqlPairsService( pipelines={ - "sql_pairs": _sql_pair_indexing_pipeline, + "sql_pairs_indexing": _sql_pair_indexing_pipeline, }, **query_cache, ), @@ -262,6 +263,15 @@ def create_service_container( ) +def create_pipe_component_service_mapping(service_container: ServiceContainer): + _pipe_component_service_mapping = defaultdict(set) + for _, service in service_container.__dict__.items(): + for pipe_name in service._pipelines.keys(): + _pipe_component_service_mapping[pipe_name].add(service) + + return _pipe_component_service_mapping + + # Create a dependency that will be used to access the ServiceContainer def get_service_container(): from src.__main__ import app diff --git a/wren-ai-service/src/pipelines/generation/chart_adjustment.py b/wren-ai-service/src/pipelines/generation/chart_adjustment.py index 2e4e7ba8a3..f84105c9c0 100644 --- a/wren-ai-service/src/pipelines/generation/chart_adjustment.py +++ b/wren-ai-service/src/pipelines/generation/chart_adjustment.py @@ -152,18 +152,8 @@ def __init__( llm_provider: LLMProvider, **kwargs, ): - self._components = { - "prompt_builder": PromptBuilder( - template=chart_adjustment_user_prompt_template - ), - "generator": llm_provider.get_generator( - system_prompt=chart_adjustment_system_prompt, - generation_kwargs=CHART_ADJUSTMENT_MODEL_KWARGS, - ), - "generator_name": llm_provider.get_model(), - "chart_data_preprocessor": ChartDataPreprocessor(), - "post_processor": ChartGenerationPostProcessor(), - } + self._llm_provider = llm_provider + self._components = self._update_components() with open("src/pipelines/generation/utils/vega-lite-schema-v5.json", "r") as f: _vega_schema = orjson.loads(f.read()) @@ -175,6 +165,20 @@ def __init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) + def _update_components(self): + return { + "prompt_builder": PromptBuilder( + template=chart_adjustment_user_prompt_template + ), + "generator": self._llm_provider.get_generator( + system_prompt=chart_adjustment_system_prompt, + generation_kwargs=CHART_ADJUSTMENT_MODEL_KWARGS, + ), + "generator_name": self._llm_provider.get_model(), + "chart_data_preprocessor": ChartDataPreprocessor(), + "post_processor": ChartGenerationPostProcessor(), + } + @observe(name="Chart Adjustment") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/chart_generation.py b/wren-ai-service/src/pipelines/generation/chart_generation.py index 6daca5ec17..b077f5af39 100644 --- a/wren-ai-service/src/pipelines/generation/chart_generation.py +++ b/wren-ai-service/src/pipelines/generation/chart_generation.py @@ -125,18 +125,8 @@ def __init__( llm_provider: LLMProvider, **kwargs, ): - self._components = { - "prompt_builder": PromptBuilder( - template=chart_generation_user_prompt_template - ), - "generator": llm_provider.get_generator( - system_prompt=chart_generation_system_prompt, - generation_kwargs=CHART_GENERATION_MODEL_KWARGS, - ), - "generator_name": llm_provider.get_model(), - "chart_data_preprocessor": ChartDataPreprocessor(), - "post_processor": ChartGenerationPostProcessor(), - } + self._llm_provider = llm_provider + self._components = self._update_components() with open("src/pipelines/generation/utils/vega-lite-schema-v5.json", "r") as f: _vega_schema = orjson.loads(f.read()) @@ -149,6 +139,20 @@ def __init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) + def _update_components(self): + return { + "prompt_builder": PromptBuilder( + template=chart_generation_user_prompt_template + ), + "generator": self._llm_provider.get_generator( + system_prompt=chart_generation_system_prompt, + generation_kwargs=CHART_GENERATION_MODEL_KWARGS, + ), + "generator_name": self._llm_provider.get_model(), + "chart_data_preprocessor": ChartDataPreprocessor(), + "post_processor": ChartGenerationPostProcessor(), + } + @observe(name="Chart Generation") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/data_assistance.py b/wren-ai-service/src/pipelines/generation/data_assistance.py index 51b91197f9..9b8899ad82 100644 --- a/wren-ai-service/src/pipelines/generation/data_assistance.py +++ b/wren-ai-service/src/pipelines/generation/data_assistance.py @@ -96,21 +96,25 @@ def __init__( **kwargs, ): self._user_queues = {} - self._components = { - "generator": llm_provider.get_generator( + self._llm_provider = llm_provider + self._components = self._update_components() + + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + def _update_components(self): + return { + "generator": self._llm_provider.get_generator( system_prompt=data_assistance_system_prompt, streaming_callback=self._streaming_callback, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.get_model(), "prompt_builder": PromptBuilder( template=data_assistance_user_prompt_template ), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - def _streaming_callback(self, chunk, query_id): if query_id not in self._user_queues: self._user_queues[ diff --git a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py index 5f889ff00c..72d1446687 100644 --- a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py @@ -159,23 +159,27 @@ def __init__( self._retriever = document_store_provider.get_retriever( document_store_provider.get_store("project_meta") ) + self._llm_provider = llm_provider + self._engine = engine + self._components = self._update_components() - self._components = { - "generator": llm_provider.get_generator( + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + def _update_components(self): + return { + "generator": self._llm_provider.get_generator( system_prompt=sql_generation_system_prompt, generation_kwargs=SQL_GENERATION_MODEL_KWARGS, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.get_model(), "prompt_builder": PromptBuilder( template=text_to_sql_with_followup_user_prompt_template ), - "post_processor": SQLGenPostProcessor(engine=engine), + "post_processor": SQLGenPostProcessor(engine=self._engine), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="Follow-Up SQL Generation") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py b/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py index 42b28c5b8f..791da07c7f 100644 --- a/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py +++ b/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py @@ -118,21 +118,25 @@ def __init__( **kwargs, ): self._user_queues = {} - self._components = { - "generator": llm_provider.get_generator( + self._llm_provider = llm_provider + self._components = self._update_components() + + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + def _update_components(self): + return { + "generator": self._llm_provider.get_generator( system_prompt=sql_generation_reasoning_system_prompt, streaming_callback=self._streaming_callback, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.get_model(), "prompt_builder": PromptBuilder( template=sql_generation_reasoning_user_prompt_template ), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - def _streaming_callback(self, chunk, query_id): if query_id not in self._user_queues: self._user_queues[query_id] = asyncio.Queue() diff --git a/wren-ai-service/src/pipelines/generation/intent_classification.py b/wren-ai-service/src/pipelines/generation/intent_classification.py index 4d6cd313cd..b73e7b1544 100644 --- a/wren-ai-service/src/pipelines/generation/intent_classification.py +++ b/wren-ai-service/src/pipelines/generation/intent_classification.py @@ -346,34 +346,44 @@ def __init__( table_column_retrieval_size: Optional[int] = 100, **kwargs, ): - self._components = { - "embedder": embedder_provider.get_text_embedder(), - "table_retriever": document_store_provider.get_retriever( - document_store_provider.get_store(dataset_name="table_descriptions"), - top_k=table_retrieval_size, + self._llm_provider = llm_provider + self._table_retrieval_size = table_retrieval_size + self._table_column_retrieval_size = table_column_retrieval_size + self._document_store_provider = document_store_provider + self._embedder_provider = embedder_provider + self._components = self._update_components() + + self._configs = { + "wren_ai_docs": wren_ai_docs, + } + + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + def _update_components(self): + return { + "embedder": self._embedder_provider.get_text_embedder(), + "table_retriever": self._document_store_provider.get_retriever( + self._document_store_provider.get_store( + dataset_name="table_descriptions" + ), + top_k=self._table_retrieval_size, ), - "dbschema_retriever": document_store_provider.get_retriever( - document_store_provider.get_store(), - top_k=table_column_retrieval_size, + "dbschema_retriever": self._document_store_provider.get_retriever( + self._document_store_provider.get_store(), + top_k=self._table_column_retrieval_size, ), - "generator": llm_provider.get_generator( + "generator": self._llm_provider.get_generator( system_prompt=intent_classification_system_prompt, generation_kwargs=INTENT_CLASSIFICAION_MODEL_KWARGS, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.get_model(), "prompt_builder": PromptBuilder( template=intent_classification_user_prompt_template ), } - self._configs = { - "wren_ai_docs": wren_ai_docs, - } - - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="Intent Classification") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/misleading_assistance.py b/wren-ai-service/src/pipelines/generation/misleading_assistance.py index a35738ecf5..e53be8f636 100644 --- a/wren-ai-service/src/pipelines/generation/misleading_assistance.py +++ b/wren-ai-service/src/pipelines/generation/misleading_assistance.py @@ -96,21 +96,25 @@ def __init__( **kwargs, ): self._user_queues = {} - self._components = { - "generator": llm_provider.get_generator( + self._llm_provider = llm_provider + self._components = self._update_components() + + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + def _update_components(self): + return { + "generator": self._llm_provider.get_generator( system_prompt=misleading_assistance_system_prompt, streaming_callback=self._streaming_callback, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.get_model(), "prompt_builder": PromptBuilder( template=misleading_assistance_user_prompt_template ), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - def _streaming_callback(self, chunk, query_id): if query_id not in self._user_queues: self._user_queues[ diff --git a/wren-ai-service/src/pipelines/generation/question_recommendation.py b/wren-ai-service/src/pipelines/generation/question_recommendation.py index a6e7c17b02..a4acf01b6d 100644 --- a/wren-ai-service/src/pipelines/generation/question_recommendation.py +++ b/wren-ai-service/src/pipelines/generation/question_recommendation.py @@ -237,14 +237,8 @@ def __init__( llm_provider: LLMProvider, **_, ): - self._components = { - "prompt_builder": PromptBuilder(template=user_prompt_template), - "generator": llm_provider.get_generator( - system_prompt=system_prompt, - generation_kwargs=QUESTION_RECOMMENDATION_MODEL_KWARGS, - ), - "generator_name": llm_provider.get_model(), - } + self._llm_provider = llm_provider + self._components = self._update_components() self._final = "normalized" @@ -252,6 +246,16 @@ def __init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) + def _update_components(self): + return { + "prompt_builder": PromptBuilder(template=user_prompt_template), + "generator": self._llm_provider.get_generator( + system_prompt=system_prompt, + generation_kwargs=QUESTION_RECOMMENDATION_MODEL_KWARGS, + ), + "generator_name": self._llm_provider.get_model(), + } + @observe(name="Question Recommendation") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/relationship_recommendation.py b/wren-ai-service/src/pipelines/generation/relationship_recommendation.py index e0d0bed675..e46ab6df3f 100644 --- a/wren-ai-service/src/pipelines/generation/relationship_recommendation.py +++ b/wren-ai-service/src/pipelines/generation/relationship_recommendation.py @@ -204,14 +204,8 @@ def __init__( llm_provider: LLMProvider, **_, ): - self._components = { - "prompt_builder": PromptBuilder(template=user_prompt_template), - "generator": llm_provider.get_generator( - system_prompt=system_prompt, - generation_kwargs=RELATIONSHIP_RECOMMENDATION_MODEL_KWARGS, - ), - "generator_name": llm_provider.get_model(), - } + self._llm_provider = llm_provider + self._components = self._update_components() self._final = "validated" @@ -219,6 +213,16 @@ def __init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) + def _update_components(self): + return { + "prompt_builder": PromptBuilder(template=user_prompt_template), + "generator": self._llm_provider.get_generator( + system_prompt=system_prompt, + generation_kwargs=RELATIONSHIP_RECOMMENDATION_MODEL_KWARGS, + ), + "generator_name": self._llm_provider.get_model(), + } + @observe(name="Relationship Recommendation") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/semantics_description.py b/wren-ai-service/src/pipelines/generation/semantics_description.py index acc5fc8594..ce1136959c 100644 --- a/wren-ai-service/src/pipelines/generation/semantics_description.py +++ b/wren-ai-service/src/pipelines/generation/semantics_description.py @@ -218,20 +218,24 @@ class SemanticResult(BaseModel): class SemanticsDescription(BasicPipeline): def __init__(self, llm_provider: LLMProvider, **_): - self._components = { - "prompt_builder": PromptBuilder(template=user_prompt_template), - "generator": llm_provider.get_generator( - system_prompt=system_prompt, - generation_kwargs=SEMANTICS_DESCRIPTION_MODEL_KWARGS, - ), - "generator_name": llm_provider.get_model(), - } + self._llm_provider = llm_provider + self._components = self._update_components() self._final = "output" super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) + def _update_components(self): + return { + "prompt_builder": PromptBuilder(template=user_prompt_template), + "generator": self._llm_provider.get_generator( + system_prompt=system_prompt, + generation_kwargs=SEMANTICS_DESCRIPTION_MODEL_KWARGS, + ), + "generator_name": self._llm_provider.get_model(), + } + @observe(name="Semantics Description Generation") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/sql_answer.py b/wren-ai-service/src/pipelines/generation/sql_answer.py index 81289081b5..a7f5e2b3b7 100644 --- a/wren-ai-service/src/pipelines/generation/sql_answer.py +++ b/wren-ai-service/src/pipelines/generation/sql_answer.py @@ -96,21 +96,25 @@ def __init__( **kwargs, ): self._user_queues = {} - self._components = { + self._llm_provider = llm_provider + self._components = self._update_components() + + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + def _update_components(self): + return { "prompt_builder": PromptBuilder( template=sql_to_answer_user_prompt_template ), - "generator": llm_provider.get_generator( + "generator": self._llm_provider.get_generator( system_prompt=sql_to_answer_system_prompt, streaming_callback=self._streaming_callback, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.get_model(), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - def _streaming_callback(self, chunk, query_id): if query_id not in self._user_queues: self._user_queues[ diff --git a/wren-ai-service/src/pipelines/generation/sql_correction.py b/wren-ai-service/src/pipelines/generation/sql_correction.py index 7a9bcc092b..beb526e980 100644 --- a/wren-ai-service/src/pipelines/generation/sql_correction.py +++ b/wren-ai-service/src/pipelines/generation/sql_correction.py @@ -133,26 +133,30 @@ def __init__( engine: Engine, **kwargs, ): + self._llm_provider = llm_provider + self._engine = engine + self._components = self._update_components() self._retriever = document_store_provider.get_retriever( document_store_provider.get_store("project_meta") ) - self._components = { - "generator": llm_provider.get_generator( + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + def _update_components(self): + return { + "generator": self._llm_provider.get_generator( system_prompt=sql_correction_system_prompt, generation_kwargs=SQL_GENERATION_MODEL_KWARGS, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.get_model(), "prompt_builder": PromptBuilder( template=sql_correction_user_prompt_template ), - "post_processor": SQLGenPostProcessor(engine=engine), + "post_processor": SQLGenPostProcessor(engine=self._engine), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="SQL Correction") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/sql_generation.py b/wren-ai-service/src/pipelines/generation/sql_generation.py index 27d3bc5eab..2a21d0fb2c 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation.py @@ -156,23 +156,27 @@ def __init__( self._retriever = document_store_provider.get_retriever( document_store_provider.get_store("project_meta") ) + self._llm_provider = llm_provider + self._engine = engine + self._components = self._update_components() - self._components = { - "generator": llm_provider.get_generator( + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + def _update_components(self): + return { + "generator": self._llm_provider.get_generator( system_prompt=sql_generation_system_prompt, generation_kwargs=SQL_GENERATION_MODEL_KWARGS, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.get_model(), "prompt_builder": PromptBuilder( template=sql_generation_user_prompt_template ), - "post_processor": SQLGenPostProcessor(engine=engine), + "post_processor": SQLGenPostProcessor(engine=self._engine), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="SQL Generation") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py b/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py index 00b731cb2c..e831773b95 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py @@ -103,20 +103,28 @@ def __init__( **kwargs, ): self._user_queues = {} - self._components = { - "generator": llm_provider.get_generator( + self._llm_provider = llm_provider + self._components = self._update_components() + + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + def _update_components(self): + return { + "generator": self._llm_provider.get_generator( system_prompt=sql_generation_reasoning_system_prompt, streaming_callback=self._streaming_callback, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.get_model(), "prompt_builder": PromptBuilder( template=sql_generation_reasoning_user_prompt_template ), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) + def update_llm_provider(self, llm_provider: LLMProvider): + self._llm_provider = llm_provider + self._components = self._update_components() def _streaming_callback(self, chunk, query_id): if query_id not in self._user_queues: diff --git a/wren-ai-service/src/pipelines/generation/sql_question.py b/wren-ai-service/src/pipelines/generation/sql_question.py index 81f3d3e62b..1c3d7253b5 100644 --- a/wren-ai-service/src/pipelines/generation/sql_question.py +++ b/wren-ai-service/src/pipelines/generation/sql_question.py @@ -98,19 +98,23 @@ def __init__( llm_provider: LLMProvider, **kwargs, ): - self._components = { - "generator": llm_provider.get_generator( + self._llm_provider = llm_provider + self._components = self._update_components() + + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + def _update_components(self): + return { + "generator": self._llm_provider.get_generator( system_prompt=sql_question_system_prompt, generation_kwargs=SQL_QUESTION_MODEL_KWARGS, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.get_model(), "prompt_builder": PromptBuilder(template=sql_question_user_prompt_template), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="Sql Question Generation") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/sql_regeneration.py b/wren-ai-service/src/pipelines/generation/sql_regeneration.py index 0dbf2fd808..76f67c3075 100644 --- a/wren-ai-service/src/pipelines/generation/sql_regeneration.py +++ b/wren-ai-service/src/pipelines/generation/sql_regeneration.py @@ -157,22 +157,27 @@ def __init__( engine: Engine, **kwargs, ): - self._components = { - "generator": llm_provider.get_generator( + self._llm_provider = llm_provider + self._engine = engine + self._components = self._update_components() + + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + def _update_components(self): + return { + "generator": self._llm_provider.get_generator( system_prompt=sql_regeneration_system_prompt, generation_kwargs=SQL_GENERATION_MODEL_KWARGS, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.get_model(), "prompt_builder": PromptBuilder( template=sql_regeneration_user_prompt_template ), - "post_processor": SQLGenPostProcessor(engine=engine), + "post_processor": SQLGenPostProcessor(engine=self._engine), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="SQL Regeneration") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/sql_tables_extraction.py b/wren-ai-service/src/pipelines/generation/sql_tables_extraction.py index 7060e750b3..2656612435 100644 --- a/wren-ai-service/src/pipelines/generation/sql_tables_extraction.py +++ b/wren-ai-service/src/pipelines/generation/sql_tables_extraction.py @@ -102,21 +102,25 @@ def __init__( llm_provider: LLMProvider, **kwargs, ): - self._components = { - "generator": llm_provider.get_generator( + self._llm_provider = llm_provider + self._components = self._update_components() + + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + def _update_components(self): + return { + "generator": self._llm_provider.get_generator( system_prompt=sql_tables_extraction_system_prompt, generation_kwargs=SQL_TABLES_EXTRACTION_MODEL_KWARGS, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.get_model(), "prompt_builder": PromptBuilder( template=sql_tables_extraction_user_prompt_template ), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="Sql Tables Extraction") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/user_guide_assistance.py b/wren-ai-service/src/pipelines/generation/user_guide_assistance.py index be437f883b..3dd4a51600 100644 --- a/wren-ai-service/src/pipelines/generation/user_guide_assistance.py +++ b/wren-ai-service/src/pipelines/generation/user_guide_assistance.py @@ -86,16 +86,9 @@ def __init__( **kwargs, ): self._user_queues = {} - self._components = { - "generator": llm_provider.get_generator( - system_prompt=user_guide_assistance_system_prompt, - streaming_callback=self._streaming_callback, - ), - "generator_name": llm_provider.get_model(), - "prompt_builder": PromptBuilder( - template=user_guide_assistance_user_prompt_template - ), - } + self._llm_provider = llm_provider + self._components = self._update_components() + self._configs = { "wren_ai_docs": wren_ai_docs, } @@ -104,6 +97,18 @@ def __init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) + def _update_components(self): + return { + "generator": self._llm_provider.get_generator( + system_prompt=user_guide_assistance_system_prompt, + streaming_callback=self._streaming_callback, + ), + "generator_name": self._llm_provider.get_model(), + "prompt_builder": PromptBuilder( + template=user_guide_assistance_user_prompt_template + ), + } + def _streaming_callback(self, chunk, query_id): if query_id not in self._user_queues: self._user_queues[ diff --git a/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py b/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py index 6c8dd7bbe3..c90d829dcd 100644 --- a/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py +++ b/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py @@ -459,41 +459,57 @@ def __init__( table_column_retrieval_size: int = 100, **kwargs, ): - self._components = { - "embedder": embedder_provider.get_text_embedder(), - "table_retriever": document_store_provider.get_retriever( - document_store_provider.get_store(dataset_name="table_descriptions"), - top_k=table_retrieval_size, + self._llm_provider = llm_provider + self._table_retrieval_size = table_retrieval_size + self._table_column_retrieval_size = table_column_retrieval_size + self._document_store_provider = document_store_provider + self._embedder_provider = embedder_provider + self._components = self._update_components() + self._configs = self._update_configs() + + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + def _update_configs(self): + _model = self._llm_provider.get_model() + if "gpt-4o" in _model or "gpt-4o-mini" in _model: + _encoding = tiktoken.get_encoding("o200k_base") + else: + _encoding = tiktoken.get_encoding("cl100k_base") + + return { + "encoding": _encoding, + "context_window_size": self._llm_provider.get_context_window_size(), + } + + def _update_components(self): + return { + "embedder": self._embedder_provider.get_text_embedder(), + "table_retriever": self._document_store_provider.get_retriever( + self._document_store_provider.get_store( + dataset_name="table_descriptions" + ), + top_k=self._table_retrieval_size, ), - "dbschema_retriever": document_store_provider.get_retriever( - document_store_provider.get_store(), - top_k=table_column_retrieval_size, + "dbschema_retriever": self._document_store_provider.get_retriever( + self._document_store_provider.get_store(), + top_k=self._table_column_retrieval_size, ), - "table_columns_selection_generator": llm_provider.get_generator( + "table_columns_selection_generator": self._llm_provider.get_generator( system_prompt=table_columns_selection_system_prompt, generation_kwargs=RETRIEVAL_MODEL_KWARGS, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.get_model(), "prompt_builder": PromptBuilder( template=table_columns_selection_user_prompt_template ), } - # for the first time, we need to load the encodings - _model = llm_provider.get_model() - if "gpt-4o" in _model or "gpt-4o-mini" in _model: - _encoding = tiktoken.get_encoding("o200k_base") - else: - _encoding = tiktoken.get_encoding("cl100k_base") - - self._configs = { - "encoding": _encoding, - "context_window_size": llm_provider.get_context_window_size(), - } - - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) + def update_llm_provider(self, llm_provider: LLMProvider): + self._llm_provider = llm_provider + self._components = self._update_components() + self._configs = self._update_configs() @observe(name="Ask Retrieval") async def run( diff --git a/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py b/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py index e6dadd32d0..5bbdbd4f9a 100644 --- a/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py +++ b/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py @@ -82,18 +82,26 @@ def __init__( llm_provider: LLMProvider, **kwargs, ): - _model = llm_provider.get_model() + self._llm_provider = llm_provider + self._configs = self._update_configs() + + super().__init__(Driver({}, sys.modules[__name__], adapter=base.DictResult())) + + def _update_configs(self): + _model = self._llm_provider.get_model() if _model == "gpt-4o-mini" or _model == "gpt-4o": _encoding = tiktoken.get_encoding("o200k_base") else: _encoding = tiktoken.get_encoding("cl100k_base") - self._configs = { + return { "encoding": _encoding, - "context_window_size": llm_provider.get_context_window_size(), + "context_window_size": self._llm_provider.get_context_window_size(), } - super().__init__(Driver({}, sys.modules[__name__], adapter=base.DictResult())) + def update_llm_provider(self, llm_provider: LLMProvider): + self._llm_provider = llm_provider + self._configs = self._update_configs() @observe(name="Preprocess SQL Data") def run( diff --git a/wren-ai-service/src/providers/__init__.py b/wren-ai-service/src/providers/__init__.py index fb491c8277..2f3d7dc943 100644 --- a/wren-ai-service/src/providers/__init__.py +++ b/wren-ai-service/src/providers/__init__.py @@ -392,4 +392,4 @@ def componentize(components: dict, instantiated_providers: dict): return { pipe_name: componentize(components, instantiated_providers) for pipe_name, components in config.pipelines.items() - } + }, instantiated_providers diff --git a/wren-ai-service/src/utils.py b/wren-ai-service/src/utils.py index d368080c3c..4662b901d4 100644 --- a/wren-ai-service/src/utils.py +++ b/wren-ai-service/src/utils.py @@ -7,6 +7,7 @@ import requests from dotenv import load_dotenv from langfuse.decorators import langfuse_context +from pydantic import BaseModel from src.config import Settings @@ -218,3 +219,8 @@ def extract_braces_content(resp: str) -> str: """ match = re.search(r"```json\s*(\{.*?\})\s*```", resp, re.DOTALL) return match.group(1) if match else resp + + +class SinglePipeComponentRequest(BaseModel): + pipeline_name: str + llm_config: str diff --git a/wren-ai-service/src/web/v1/services/semantics_preparation.py b/wren-ai-service/src/web/v1/services/semantics_preparation.py index 2ff6215cbe..ba019fac9b 100644 --- a/wren-ai-service/src/web/v1/services/semantics_preparation.py +++ b/wren-ai-service/src/web/v1/services/semantics_preparation.py @@ -84,7 +84,7 @@ async def prepare_semantics( "db_schema", "historical_question", "table_description", - "sql_pairs", + "sql_pairs_indexing", "project_meta", ] ] @@ -153,7 +153,7 @@ async def delete_semantics(self, project_id: str, **kwargs): project_id=project_id, delete_all=True, ) - for name in ["sql_pairs", "instructions"] + for name in ["sql_pairs_indexing", "instructions_indexing"] ] await asyncio.gather(*tasks) From 59d63e7555727b5e6dbcda31bfc48aaf84b3537c Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Wed, 17 Sep 2025 10:54:50 +0800 Subject: [PATCH 2/7] update --- wren-ai-service/src/__main__.py | 56 +++++++++++++++---- wren-ai-service/src/core/pipeline.py | 8 +++ wren-ai-service/src/globals.py | 30 +++++++--- .../pipelines/generation/chart_adjustment.py | 7 ++- .../pipelines/generation/chart_generation.py | 8 +-- .../pipelines/generation/data_assistance.py | 8 +-- .../generation/followup_sql_generation.py | 8 +-- .../followup_sql_generation_reasoning.py | 8 +-- .../generation/intent_classification.py | 8 +-- .../generation/misleading_assistance.py | 8 +-- .../generation/question_recommendation.py | 8 +-- .../generation/relationship_recommendation.py | 8 +-- .../generation/semantics_description.py | 8 +-- .../src/pipelines/generation/sql_answer.py | 8 +-- .../pipelines/generation/sql_correction.py | 8 +-- .../pipelines/generation/sql_generation.py | 8 +-- .../generation/sql_generation_reasoning.py | 8 +-- .../src/pipelines/generation/sql_question.py | 6 +- .../pipelines/generation/sql_regeneration.py | 8 +-- .../generation/sql_tables_extraction.py | 6 +- .../generation/user_guide_assistance.py | 8 +-- .../src/pipelines/indexing/db_schema.py | 7 ++- .../pipelines/indexing/historical_question.py | 8 +-- .../src/pipelines/indexing/instructions.py | 8 +-- .../src/pipelines/indexing/project_meta.py | 8 +-- .../src/pipelines/indexing/sql_pairs.py | 8 +-- .../pipelines/indexing/table_description.py | 8 +-- .../retrieval/db_schema_retrieval.py | 8 +-- .../historical_question_retrieval.py | 8 +-- .../src/pipelines/retrieval/instructions.py | 8 +-- .../retrieval/preprocess_sql_data.py | 4 +- .../src/pipelines/retrieval/sql_executor.py | 8 +-- .../src/pipelines/retrieval/sql_functions.py | 8 +-- .../retrieval/sql_pairs_retrieval.py | 8 +-- wren-ai-service/src/utils.py | 12 ++++ 35 files changed, 207 insertions(+), 137 deletions(-) diff --git a/wren-ai-service/src/__main__.py b/wren-ai-service/src/__main__.py index 01779387f5..9ae26d699a 100644 --- a/wren-ai-service/src/__main__.py +++ b/wren-ai-service/src/__main__.py @@ -9,7 +9,7 @@ from src.config import settings from src.globals import ( - create_pipe_component_service_mapping, + create_pipe_components, create_service_container, create_service_metadata, ) @@ -32,9 +32,7 @@ async def lifespan(app: FastAPI): # startup events pipe_components, instantiated_providers = generate_components(settings.components) app.state.service_container = create_service_container(pipe_components, settings) - app.state.pipe_component_service_mapping = create_pipe_component_service_mapping( - app.state.service_container - ) + app.state.pipe_components = create_pipe_components(app.state.service_container) app.state.service_metadata = create_service_metadata(pipe_components) app.state.instantiated_providers = instantiated_providers init_langfuse(settings) @@ -92,18 +90,52 @@ def health(): return {"status": "ok"} -@app.get("/pipe_components") +@app.get("/configs") def get_pipe_components(): - return sorted(list(app.state.pipe_component_service_mapping.keys())) - - -@app.post("/pipe_components") + _configs = { + "env_vars": {}, + "providers": { + "llm": [], + "embedder": [], + }, + "pipelines": {}, + } + + _llm_configs = [] + for model_name, model_config in app.state.instantiated_providers["llm"].items(): + print(f"model_name: {model_name}") + print(f"model: {model_config._model}") + + _embedder_configs = [] + for model_name, model_config in app.state.instantiated_providers[ + "embedder" + ].items(): + pass + + for pipe_name, pipe_component in app.state.pipe_components.items(): + llm_provider = pipe_component.get("llm", None) + embedder_provider = pipe_component.get("embedder", None) + if llm_provider or embedder_provider: + _configs["pipelines"][pipe_name] = { + "has_db_data_in_llm_prompt": pipe_component.get( + "has_db_data_in_llm_prompt", False + ), + } + if llm_provider: + _configs["pipelines"][pipe_name]["llm"] = llm_provider + if embedder_provider: + _configs["pipelines"][pipe_name]["embedder"] = embedder_provider + + return _configs + + +@app.post("/configs") def update_pipe_components(pipe_components_request: list[SinglePipeComponentRequest]): try: for payload in pipe_components_request: - for service in app.state.pipe_component_service_mapping[ - payload.pipeline_name - ]: + for service in app.state.pipe_components[payload.pipeline_name].get( + "services", [] + ): service._pipelines[payload.pipeline_name].update_llm_provider( app.state.instantiated_providers["llm"][payload.llm_config] ) diff --git a/wren-ai-service/src/core/pipeline.py b/wren-ai-service/src/core/pipeline.py index 897c8709b7..22292d32dc 100644 --- a/wren-ai-service/src/core/pipeline.py +++ b/wren-ai-service/src/core/pipeline.py @@ -15,6 +15,7 @@ class BasicPipeline(metaclass=ABCMeta): def __init__(self, pipe: Pipeline | AsyncDriver | Driver): self._pipe = pipe self._llm_provider = None + self._embedder_provider = None self._components = {} @abstractmethod @@ -28,6 +29,13 @@ def update_llm_provider(self, llm_provider: LLMProvider): self._llm_provider = llm_provider self._components = self._update_components() + def update_embedder_provider(self, embedder_provider: EmbedderProvider): + self._embedder_provider = embedder_provider + self._components = self._update_components() + + def __str__(self): + return f"BasicPipeline(llm_provider={self._llm_provider}, embedder_provider={self._embedder_provider})" + @dataclass class PipelineComponent(Mapping): diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py index 903247d30a..164baff976 100644 --- a/wren-ai-service/src/globals.py +++ b/wren-ai-service/src/globals.py @@ -1,5 +1,4 @@ import logging -from collections import defaultdict from dataclasses import asdict, dataclass import toml @@ -8,7 +7,7 @@ from src.core.pipeline import PipelineComponent from src.core.provider import EmbedderProvider, LLMProvider from src.pipelines import generation, indexing, retrieval -from src.utils import fetch_wren_ai_docs +from src.utils import fetch_wren_ai_docs, has_db_data_in_llm_prompt from src.web.v1 import services logger = logging.getLogger("wren-ai-service") @@ -263,13 +262,30 @@ def create_service_container( ) -def create_pipe_component_service_mapping(service_container: ServiceContainer): - _pipe_component_service_mapping = defaultdict(set) +def create_pipe_components(service_container: ServiceContainer): + _pipe_components = {} for _, service in service_container.__dict__.items(): - for pipe_name in service._pipelines.keys(): - _pipe_component_service_mapping[pipe_name].add(service) + for pipe_name, pipe in service._pipelines.items(): + print(f"pipe_name: {pipe_name}, pipe: {pipe}") + if pipe_name not in _pipe_components: + _pipe_components[pipe_name] = {} + if hasattr(pipe, "_llm_provider") and pipe._llm_provider is not None: + _pipe_components[pipe_name]["llm"] = pipe._llm_provider.get_model() + if ( + hasattr(pipe, "_embedder_provider") + and pipe._embedder_provider is not None + ): + _pipe_components[pipe_name][ + "embedder" + ] = pipe._embedder_provider.get_model() + if "services" not in _pipe_components[pipe_name]: + _pipe_components[pipe_name]["services"] = set() + _pipe_components[pipe_name]["services"].add(service) + _pipe_components[pipe_name][ + "has_db_data_in_llm_prompt" + ] = has_db_data_in_llm_prompt(pipe_name) - return _pipe_component_service_mapping + return _pipe_components # Create a dependency that will be used to access the ServiceContainer diff --git a/wren-ai-service/src/pipelines/generation/chart_adjustment.py b/wren-ai-service/src/pipelines/generation/chart_adjustment.py index f84105c9c0..7480082966 100644 --- a/wren-ai-service/src/pipelines/generation/chart_adjustment.py +++ b/wren-ai-service/src/pipelines/generation/chart_adjustment.py @@ -152,6 +152,10 @@ def __init__( llm_provider: LLMProvider, **kwargs, ): + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + self._llm_provider = llm_provider self._components = self._update_components() @@ -161,9 +165,6 @@ def __init__( self._configs = { "vega_schema": _vega_schema, } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) def _update_components(self): return { diff --git a/wren-ai-service/src/pipelines/generation/chart_generation.py b/wren-ai-service/src/pipelines/generation/chart_generation.py index b077f5af39..b1639e26ea 100644 --- a/wren-ai-service/src/pipelines/generation/chart_generation.py +++ b/wren-ai-service/src/pipelines/generation/chart_generation.py @@ -125,6 +125,10 @@ def __init__( llm_provider: LLMProvider, **kwargs, ): + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + self._llm_provider = llm_provider self._components = self._update_components() @@ -135,10 +139,6 @@ def __init__( "vega_schema": _vega_schema, } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - def _update_components(self): return { "prompt_builder": PromptBuilder( diff --git a/wren-ai-service/src/pipelines/generation/data_assistance.py b/wren-ai-service/src/pipelines/generation/data_assistance.py index 9b8899ad82..7a96de2c2a 100644 --- a/wren-ai-service/src/pipelines/generation/data_assistance.py +++ b/wren-ai-service/src/pipelines/generation/data_assistance.py @@ -95,14 +95,14 @@ def __init__( llm_provider: LLMProvider, **kwargs, ): - self._user_queues = {} - self._llm_provider = llm_provider - self._components = self._update_components() - super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) + self._user_queues = {} + self._llm_provider = llm_provider + self._components = self._update_components() + def _update_components(self): return { "generator": self._llm_provider.get_generator( diff --git a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py index 72d1446687..bde753697c 100644 --- a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py @@ -156,6 +156,10 @@ def __init__( engine: Engine, **kwargs, ): + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + self._retriever = document_store_provider.get_retriever( document_store_provider.get_store("project_meta") ) @@ -163,10 +167,6 @@ def __init__( self._engine = engine self._components = self._update_components() - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - def _update_components(self): return { "generator": self._llm_provider.get_generator( diff --git a/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py b/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py index 791da07c7f..b225888a89 100644 --- a/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py +++ b/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py @@ -117,14 +117,14 @@ def __init__( llm_provider: LLMProvider, **kwargs, ): - self._user_queues = {} - self._llm_provider = llm_provider - self._components = self._update_components() - super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) + self._user_queues = {} + self._llm_provider = llm_provider + self._components = self._update_components() + def _update_components(self): return { "generator": self._llm_provider.get_generator( diff --git a/wren-ai-service/src/pipelines/generation/intent_classification.py b/wren-ai-service/src/pipelines/generation/intent_classification.py index b73e7b1544..3698281c44 100644 --- a/wren-ai-service/src/pipelines/generation/intent_classification.py +++ b/wren-ai-service/src/pipelines/generation/intent_classification.py @@ -346,6 +346,10 @@ def __init__( table_column_retrieval_size: Optional[int] = 100, **kwargs, ): + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + self._llm_provider = llm_provider self._table_retrieval_size = table_retrieval_size self._table_column_retrieval_size = table_column_retrieval_size @@ -357,10 +361,6 @@ def __init__( "wren_ai_docs": wren_ai_docs, } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - def _update_components(self): return { "embedder": self._embedder_provider.get_text_embedder(), diff --git a/wren-ai-service/src/pipelines/generation/misleading_assistance.py b/wren-ai-service/src/pipelines/generation/misleading_assistance.py index e53be8f636..286a97ae15 100644 --- a/wren-ai-service/src/pipelines/generation/misleading_assistance.py +++ b/wren-ai-service/src/pipelines/generation/misleading_assistance.py @@ -95,14 +95,14 @@ def __init__( llm_provider: LLMProvider, **kwargs, ): - self._user_queues = {} - self._llm_provider = llm_provider - self._components = self._update_components() - super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) + self._user_queues = {} + self._llm_provider = llm_provider + self._components = self._update_components() + def _update_components(self): return { "generator": self._llm_provider.get_generator( diff --git a/wren-ai-service/src/pipelines/generation/question_recommendation.py b/wren-ai-service/src/pipelines/generation/question_recommendation.py index a4acf01b6d..14ecbf9634 100644 --- a/wren-ai-service/src/pipelines/generation/question_recommendation.py +++ b/wren-ai-service/src/pipelines/generation/question_recommendation.py @@ -237,15 +237,15 @@ def __init__( llm_provider: LLMProvider, **_, ): + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + self._llm_provider = llm_provider self._components = self._update_components() self._final = "normalized" - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - def _update_components(self): return { "prompt_builder": PromptBuilder(template=user_prompt_template), diff --git a/wren-ai-service/src/pipelines/generation/relationship_recommendation.py b/wren-ai-service/src/pipelines/generation/relationship_recommendation.py index e46ab6df3f..517d02fa85 100644 --- a/wren-ai-service/src/pipelines/generation/relationship_recommendation.py +++ b/wren-ai-service/src/pipelines/generation/relationship_recommendation.py @@ -204,15 +204,15 @@ def __init__( llm_provider: LLMProvider, **_, ): + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + self._llm_provider = llm_provider self._components = self._update_components() self._final = "validated" - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - def _update_components(self): return { "prompt_builder": PromptBuilder(template=user_prompt_template), diff --git a/wren-ai-service/src/pipelines/generation/semantics_description.py b/wren-ai-service/src/pipelines/generation/semantics_description.py index ce1136959c..37f3e3f637 100644 --- a/wren-ai-service/src/pipelines/generation/semantics_description.py +++ b/wren-ai-service/src/pipelines/generation/semantics_description.py @@ -218,14 +218,14 @@ class SemanticResult(BaseModel): class SemanticsDescription(BasicPipeline): def __init__(self, llm_provider: LLMProvider, **_): - self._llm_provider = llm_provider - self._components = self._update_components() - self._final = "output" - super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) + self._llm_provider = llm_provider + self._components = self._update_components() + self._final = "output" + def _update_components(self): return { "prompt_builder": PromptBuilder(template=user_prompt_template), diff --git a/wren-ai-service/src/pipelines/generation/sql_answer.py b/wren-ai-service/src/pipelines/generation/sql_answer.py index a7f5e2b3b7..f8d530b256 100644 --- a/wren-ai-service/src/pipelines/generation/sql_answer.py +++ b/wren-ai-service/src/pipelines/generation/sql_answer.py @@ -95,14 +95,14 @@ def __init__( llm_provider: LLMProvider, **kwargs, ): - self._user_queues = {} - self._llm_provider = llm_provider - self._components = self._update_components() - super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) + self._user_queues = {} + self._llm_provider = llm_provider + self._components = self._update_components() + def _update_components(self): return { "prompt_builder": PromptBuilder( diff --git a/wren-ai-service/src/pipelines/generation/sql_correction.py b/wren-ai-service/src/pipelines/generation/sql_correction.py index beb526e980..850837f9a1 100644 --- a/wren-ai-service/src/pipelines/generation/sql_correction.py +++ b/wren-ai-service/src/pipelines/generation/sql_correction.py @@ -133,6 +133,10 @@ def __init__( engine: Engine, **kwargs, ): + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + self._llm_provider = llm_provider self._engine = engine self._components = self._update_components() @@ -140,10 +144,6 @@ def __init__( document_store_provider.get_store("project_meta") ) - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - def _update_components(self): return { "generator": self._llm_provider.get_generator( diff --git a/wren-ai-service/src/pipelines/generation/sql_generation.py b/wren-ai-service/src/pipelines/generation/sql_generation.py index 2a21d0fb2c..a2f9832279 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation.py @@ -153,6 +153,10 @@ def __init__( engine: Engine, **kwargs, ): + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + self._retriever = document_store_provider.get_retriever( document_store_provider.get_store("project_meta") ) @@ -160,10 +164,6 @@ def __init__( self._engine = engine self._components = self._update_components() - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - def _update_components(self): return { "generator": self._llm_provider.get_generator( diff --git a/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py b/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py index e831773b95..21c141437d 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py @@ -102,14 +102,14 @@ def __init__( llm_provider: LLMProvider, **kwargs, ): - self._user_queues = {} - self._llm_provider = llm_provider - self._components = self._update_components() - super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) + self._user_queues = {} + self._llm_provider = llm_provider + self._components = self._update_components() + def _update_components(self): return { "generator": self._llm_provider.get_generator( diff --git a/wren-ai-service/src/pipelines/generation/sql_question.py b/wren-ai-service/src/pipelines/generation/sql_question.py index 1c3d7253b5..e9a8cceef0 100644 --- a/wren-ai-service/src/pipelines/generation/sql_question.py +++ b/wren-ai-service/src/pipelines/generation/sql_question.py @@ -98,13 +98,13 @@ def __init__( llm_provider: LLMProvider, **kwargs, ): - self._llm_provider = llm_provider - self._components = self._update_components() - super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) + self._llm_provider = llm_provider + self._components = self._update_components() + def _update_components(self): return { "generator": self._llm_provider.get_generator( diff --git a/wren-ai-service/src/pipelines/generation/sql_regeneration.py b/wren-ai-service/src/pipelines/generation/sql_regeneration.py index 76f67c3075..3d1bfaa9f5 100644 --- a/wren-ai-service/src/pipelines/generation/sql_regeneration.py +++ b/wren-ai-service/src/pipelines/generation/sql_regeneration.py @@ -157,14 +157,14 @@ def __init__( engine: Engine, **kwargs, ): - self._llm_provider = llm_provider - self._engine = engine - self._components = self._update_components() - super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) + self._llm_provider = llm_provider + self._engine = engine + self._components = self._update_components() + def _update_components(self): return { "generator": self._llm_provider.get_generator( diff --git a/wren-ai-service/src/pipelines/generation/sql_tables_extraction.py b/wren-ai-service/src/pipelines/generation/sql_tables_extraction.py index 2656612435..1b6170e5f4 100644 --- a/wren-ai-service/src/pipelines/generation/sql_tables_extraction.py +++ b/wren-ai-service/src/pipelines/generation/sql_tables_extraction.py @@ -102,13 +102,13 @@ def __init__( llm_provider: LLMProvider, **kwargs, ): - self._llm_provider = llm_provider - self._components = self._update_components() - super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) + self._llm_provider = llm_provider + self._components = self._update_components() + def _update_components(self): return { "generator": self._llm_provider.get_generator( diff --git a/wren-ai-service/src/pipelines/generation/user_guide_assistance.py b/wren-ai-service/src/pipelines/generation/user_guide_assistance.py index 3dd4a51600..60233f1e94 100644 --- a/wren-ai-service/src/pipelines/generation/user_guide_assistance.py +++ b/wren-ai-service/src/pipelines/generation/user_guide_assistance.py @@ -85,6 +85,10 @@ def __init__( wren_ai_docs: list[dict], **kwargs, ): + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + self._user_queues = {} self._llm_provider = llm_provider self._components = self._update_components() @@ -93,10 +97,6 @@ def __init__( "wren_ai_docs": wren_ai_docs, } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - def _update_components(self): return { "generator": self._llm_provider.get_generator( diff --git a/wren-ai-service/src/pipelines/indexing/db_schema.py b/wren-ai-service/src/pipelines/indexing/db_schema.py index 394d087b46..ea0596f2a9 100644 --- a/wren-ai-service/src/pipelines/indexing/db_schema.py +++ b/wren-ai-service/src/pipelines/indexing/db_schema.py @@ -344,6 +344,10 @@ def __init__( column_batch_size: int = 50, **kwargs, ) -> None: + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + dbschema_store = document_store_provider.get_store() self._components = { @@ -362,9 +366,6 @@ def __init__( self._final = "write" helper.load_helpers() - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) @observe(name="DB Schema Indexing") async def run( diff --git a/wren-ai-service/src/pipelines/indexing/historical_question.py b/wren-ai-service/src/pipelines/indexing/historical_question.py index ff30d91f0f..b4ec96741e 100644 --- a/wren-ai-service/src/pipelines/indexing/historical_question.py +++ b/wren-ai-service/src/pipelines/indexing/historical_question.py @@ -139,6 +139,10 @@ def __init__( document_store_provider: DocumentStoreProvider, **kwargs, ) -> None: + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + # keep the store name as it is for now, might change in the future store = document_store_provider.get_store(dataset_name="view_questions") @@ -155,10 +159,6 @@ def __init__( self._configs = {} self._final = "write" - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="Historical Question Indexing") async def run( self, mdl_str: str, project_id: Optional[str] = None diff --git a/wren-ai-service/src/pipelines/indexing/instructions.py b/wren-ai-service/src/pipelines/indexing/instructions.py index b23f3cf2ab..fd473eb2fe 100644 --- a/wren-ai-service/src/pipelines/indexing/instructions.py +++ b/wren-ai-service/src/pipelines/indexing/instructions.py @@ -131,6 +131,10 @@ def __init__( document_store_provider: DocumentStoreProvider, **kwargs, ) -> None: + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + store = document_store_provider.get_store(dataset_name="instructions") self._components = { @@ -143,10 +147,6 @@ def __init__( ), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="Instructions Indexing") async def run( self, diff --git a/wren-ai-service/src/pipelines/indexing/project_meta.py b/wren-ai-service/src/pipelines/indexing/project_meta.py index 2ee4a08074..e7b9dd5443 100644 --- a/wren-ai-service/src/pipelines/indexing/project_meta.py +++ b/wren-ai-service/src/pipelines/indexing/project_meta.py @@ -69,6 +69,10 @@ def __init__( document_store_provider: DocumentStoreProvider, **kwargs, ) -> None: + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + store = document_store_provider.get_store(dataset_name="project_meta") self._components = { @@ -81,10 +85,6 @@ def __init__( } self._final = "write" - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="Project Meta Indexing") async def run( self, mdl_str: str, project_id: Optional[str] = None diff --git a/wren-ai-service/src/pipelines/indexing/sql_pairs.py b/wren-ai-service/src/pipelines/indexing/sql_pairs.py index a92fb36df1..6676891f80 100644 --- a/wren-ai-service/src/pipelines/indexing/sql_pairs.py +++ b/wren-ai-service/src/pipelines/indexing/sql_pairs.py @@ -171,6 +171,10 @@ def __init__( sql_pairs_path: str = "sql_pairs.json", **kwargs, ) -> None: + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + store = document_store_provider.get_store(dataset_name="sql_pairs") self._components = { @@ -185,10 +189,6 @@ def __init__( self._external_pairs = _load_sql_pairs(sql_pairs_path) - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="SQL Pairs Indexing") async def run( self, diff --git a/wren-ai-service/src/pipelines/indexing/table_description.py b/wren-ai-service/src/pipelines/indexing/table_description.py index 6da100868f..40b2646aca 100644 --- a/wren-ai-service/src/pipelines/indexing/table_description.py +++ b/wren-ai-service/src/pipelines/indexing/table_description.py @@ -122,6 +122,10 @@ def __init__( document_store_provider: DocumentStoreProvider, **kwargs, ) -> None: + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + table_description_store = document_store_provider.get_store( dataset_name="table_descriptions" ) @@ -139,10 +143,6 @@ def __init__( self._configs = {} self._final = "write" - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="Table Description Indexing") async def run( self, mdl_str: str, project_id: Optional[str] = None diff --git a/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py b/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py index c90d829dcd..995db02c2e 100644 --- a/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py +++ b/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py @@ -459,6 +459,10 @@ def __init__( table_column_retrieval_size: int = 100, **kwargs, ): + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + self._llm_provider = llm_provider self._table_retrieval_size = table_retrieval_size self._table_column_retrieval_size = table_column_retrieval_size @@ -467,10 +471,6 @@ def __init__( self._components = self._update_components() self._configs = self._update_configs() - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - def _update_configs(self): _model = self._llm_provider.get_model() if "gpt-4o" in _model or "gpt-4o-mini" in _model: diff --git a/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py b/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py index 0dcbc839ab..4942380aa9 100644 --- a/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py +++ b/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py @@ -125,6 +125,10 @@ def __init__( historical_question_retrieval_similarity_threshold: float = 0.9, **kwargs, ) -> None: + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + view_questions_store = document_store_provider.get_store( dataset_name="view_questions" ) @@ -143,10 +147,6 @@ def __init__( "historical_question_retrieval_similarity_threshold": historical_question_retrieval_similarity_threshold, } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="Historical Question") async def run(self, query: str, project_id: Optional[str] = None): logger.info("HistoricalQuestion Retrieval pipeline is running...") diff --git a/wren-ai-service/src/pipelines/retrieval/instructions.py b/wren-ai-service/src/pipelines/retrieval/instructions.py index 86c17e93de..552b4d06fd 100644 --- a/wren-ai-service/src/pipelines/retrieval/instructions.py +++ b/wren-ai-service/src/pipelines/retrieval/instructions.py @@ -191,6 +191,10 @@ def __init__( top_k: int = 10, **kwargs, ) -> None: + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + store = document_store_provider.get_store(dataset_name="instructions") self._components = { "store": store, @@ -207,10 +211,6 @@ def __init__( "top_k": top_k, } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="Instructions Retrieval") async def run( self, query: str, project_id: Optional[str] = None, scope: str = "sql" diff --git a/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py b/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py index 5bbdbd4f9a..7f185a642a 100644 --- a/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py +++ b/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py @@ -82,11 +82,11 @@ def __init__( llm_provider: LLMProvider, **kwargs, ): + super().__init__(Driver({}, sys.modules[__name__], adapter=base.DictResult())) + self._llm_provider = llm_provider self._configs = self._update_configs() - super().__init__(Driver({}, sys.modules[__name__], adapter=base.DictResult())) - def _update_configs(self): _model = self._llm_provider.get_model() if _model == "gpt-4o-mini" or _model == "gpt-4o": diff --git a/wren-ai-service/src/pipelines/retrieval/sql_executor.py b/wren-ai-service/src/pipelines/retrieval/sql_executor.py index b41151469f..a239e6d57d 100644 --- a/wren-ai-service/src/pipelines/retrieval/sql_executor.py +++ b/wren-ai-service/src/pipelines/retrieval/sql_executor.py @@ -66,14 +66,14 @@ def __init__( engine: Engine, **kwargs, ): - self._components = { - "data_fetcher": DataFetcher(engine=engine), - } - super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) + self._components = { + "data_fetcher": DataFetcher(engine=engine), + } + @observe(name="SQL Execution") async def run( self, sql: str, project_id: str | None = None, limit: int = 500 diff --git a/wren-ai-service/src/pipelines/retrieval/sql_functions.py b/wren-ai-service/src/pipelines/retrieval/sql_functions.py index 016aa9b1e6..78d179151f 100644 --- a/wren-ai-service/src/pipelines/retrieval/sql_functions.py +++ b/wren-ai-service/src/pipelines/retrieval/sql_functions.py @@ -87,6 +87,10 @@ def __init__( ttl: int = 60 * 60 * 24, **kwargs, ) -> None: + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + self._retriever = document_store_provider.get_retriever( document_store_provider.get_store("project_meta") ) @@ -96,10 +100,6 @@ def __init__( "ttl_cache": self._cache, } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="SQL Functions Retrieval") async def run( self, diff --git a/wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py b/wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py index 3fe44f32eb..414d3c9524 100644 --- a/wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py +++ b/wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py @@ -122,6 +122,10 @@ def __init__( sql_pairs_retrieval_max_size: int = 10, **kwargs, ) -> None: + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + store = document_store_provider.get_store(dataset_name="sql_pairs") self._components = { "store": store, @@ -138,10 +142,6 @@ def __init__( "sql_pairs_retrieval_max_size": sql_pairs_retrieval_max_size, } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="SqlPairs Retrieval") async def run(self, query: str, project_id: Optional[str] = None): logger.info("SqlPairs Retrieval pipeline is running...") diff --git a/wren-ai-service/src/utils.py b/wren-ai-service/src/utils.py index 4662b901d4..7abfb40a58 100644 --- a/wren-ai-service/src/utils.py +++ b/wren-ai-service/src/utils.py @@ -224,3 +224,15 @@ def extract_braces_content(resp: str) -> str: class SinglePipeComponentRequest(BaseModel): pipeline_name: str llm_config: str + + +def has_db_data_in_llm_prompt(pipe_name: str) -> bool: + pipes_containing_db_data = set( + [ + "sql_answer", + "chart_adjustment", + "chart_generation", + ] + ) + + return pipe_name in pipes_containing_db_data From 5201083adb778bb6c80b0f77a6ca24c6a7340b35 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Thu, 18 Sep 2025 08:56:43 +0800 Subject: [PATCH 3/7] update --- wren-ai-service/src/__main__.py | 104 +++++++++++++----- wren-ai-service/src/core/pipeline.py | 5 + wren-ai-service/src/core/provider.py | 2 +- wren-ai-service/src/globals.py | 2 +- .../pipelines/generation/chart_adjustment.py | 2 + .../pipelines/generation/chart_generation.py | 2 + .../pipelines/generation/data_assistance.py | 2 + .../generation/followup_sql_generation.py | 2 + .../followup_sql_generation_reasoning.py | 2 + .../generation/intent_classification.py | 2 + .../generation/misleading_assistance.py | 2 + .../generation/question_recommendation.py | 2 + .../generation/relationship_recommendation.py | 2 + .../generation/semantics_description.py | 8 +- .../src/pipelines/generation/sql_answer.py | 2 + .../pipelines/generation/sql_correction.py | 2 + .../src/pipelines/generation/sql_diagnosis.py | 20 ++-- .../pipelines/generation/sql_generation.py | 2 + .../generation/sql_generation_reasoning.py | 2 + .../src/pipelines/generation/sql_question.py | 2 + .../pipelines/generation/sql_regeneration.py | 2 + .../generation/sql_tables_extraction.py | 2 + .../generation/user_guide_assistance.py | 2 + .../src/pipelines/indexing/db_schema.py | 2 + .../pipelines/indexing/historical_question.py | 2 + .../src/pipelines/indexing/instructions.py | 2 + .../src/pipelines/indexing/project_meta.py | 2 + .../src/pipelines/indexing/sql_pairs.py | 2 + .../pipelines/indexing/table_description.py | 2 + .../retrieval/db_schema_retrieval.py | 2 + .../historical_question_retrieval.py | 3 + .../src/pipelines/retrieval/instructions.py | 3 + .../retrieval/preprocess_sql_data.py | 2 + .../src/pipelines/retrieval/sql_executor.py | 2 + .../src/pipelines/retrieval/sql_functions.py | 2 + .../retrieval/sql_pairs_retrieval.py | 3 + wren-ai-service/src/providers/__init__.py | 15 ++- .../src/providers/embedder/litellm.py | 46 +++++--- wren-ai-service/src/providers/llm/litellm.py | 4 +- wren-ai-service/src/utils.py | 38 ++++++- 40 files changed, 241 insertions(+), 66 deletions(-) diff --git a/wren-ai-service/src/__main__.py b/wren-ai-service/src/__main__.py index 9ae26d699a..1b4e2c7702 100644 --- a/wren-ai-service/src/__main__.py +++ b/wren-ai-service/src/__main__.py @@ -1,7 +1,7 @@ from contextlib import asynccontextmanager import uvicorn -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import ORJSONResponse, RedirectResponse @@ -14,8 +14,9 @@ create_service_metadata, ) from src.providers import generate_components +from src.providers.embedder.litellm import LitellmEmbedderProvider from src.utils import ( - SinglePipeComponentRequest, + Configs, init_langfuse, setup_custom_logger, ) @@ -91,7 +92,7 @@ def health(): @app.get("/configs") -def get_pipe_components(): +def get_configs(): _configs = { "env_vars": {}, "providers": { @@ -101,49 +102,94 @@ def get_pipe_components(): "pipelines": {}, } + _llm_model_alias_mapping = {} + _embedder_model_alias_mapping = {} + _llm_configs = [] - for model_name, model_config in app.state.instantiated_providers["llm"].items(): - print(f"model_name: {model_name}") - print(f"model: {model_config._model}") + for _, model_config in app.state.instantiated_providers["llm"].items(): + _llm_configs.append( + { + "model": model_config._model, + "alias": model_config._alias, + "api_base": model_config._api_base, + "api_version": model_config._api_version, + "context_window_size": model_config._context_window_size, + "timeout": model_config._timeout, + "kwargs": model_config._model_kwargs, + } + ) + _llm_model_alias_mapping[model_config._model] = model_config._alias + _configs["providers"]["llm"] = _llm_configs _embedder_configs = [] - for model_name, model_config in app.state.instantiated_providers[ - "embedder" - ].items(): - pass + # we only support one embedding model now + for _, model_config in app.state.instantiated_providers["embedder"].items(): + _embedder_configs.append( + { + "model": model_config._model, + "alias": model_config._alias, + "dimension": app.state.instantiated_providers["document_store"][ + "qdrant" + ]._embedding_model_dim, + "api_base": model_config._api_base, + "api_version": model_config._api_version, + "timeout": model_config._timeout, + "kwargs": model_config._model_kwargs, + } + ) + _embedder_model_alias_mapping[model_config._model] = model_config._alias + break + _configs["providers"]["embedder"] = _embedder_configs for pipe_name, pipe_component in app.state.pipe_components.items(): - llm_provider = pipe_component.get("llm", None) - embedder_provider = pipe_component.get("embedder", None) - if llm_provider or embedder_provider: + llm_model = pipe_component.get("llm", None) + embedding_model = pipe_component.get("embedder", None) + description = pipe_component.get("description", "") + if llm_model or embedding_model: _configs["pipelines"][pipe_name] = { "has_db_data_in_llm_prompt": pipe_component.get( "has_db_data_in_llm_prompt", False ), + "description": description, } - if llm_provider: - _configs["pipelines"][pipe_name]["llm"] = llm_provider - if embedder_provider: - _configs["pipelines"][pipe_name]["embedder"] = embedder_provider + if llm_model: + if llm_model_alias := _llm_model_alias_mapping.get(llm_model): + _configs["pipelines"][pipe_name]["llm"] = llm_model_alias + else: + _configs["pipelines"][pipe_name]["llm"] = llm_model + if embedding_model: + if embedding_model_alias := _embedder_model_alias_mapping.get( + embedding_model + ): + _configs["pipelines"][pipe_name]["embedder"] = embedding_model_alias + else: + _configs["pipelines"][pipe_name]["embedder"] = embedding_model return _configs @app.post("/configs") -def update_pipe_components(pipe_components_request: list[SinglePipeComponentRequest]): - try: - for payload in pipe_components_request: - for service in app.state.pipe_components[payload.pipeline_name].get( - "services", [] - ): - service._pipelines[payload.pipeline_name].update_llm_provider( - app.state.instantiated_providers["llm"][payload.llm_config] - ) - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Error updating pipe components: {e}" +def update_configs(configs_request: Configs): + embedder_providers = {} + for embedder_provider in configs_request.providers.embedder: + identifier = f"litellm_embedder.{embedder_provider.alias if embedder_provider.alias else embedder_provider.model}" + embedder_providers[identifier] = LitellmEmbedderProvider( + **embedder_provider.__dict__ ) + # try: + # for payload in pipe_components_request: + # for service in app.state.pipe_components[payload.pipeline_name].get( + # "services", [] + # ): + # service._pipelines[payload.pipeline_name].update_llm_provider( + # app.state.instantiated_providers["llm"][payload.llm_config] + # ) + # except Exception as e: + # raise HTTPException( + # status_code=500, detail=f"Error updating pipe components: {e}" + # ) + if __name__ == "__main__": uvicorn.run( diff --git a/wren-ai-service/src/core/pipeline.py b/wren-ai-service/src/core/pipeline.py index 22292d32dc..aeabf9b150 100644 --- a/wren-ai-service/src/core/pipeline.py +++ b/wren-ai-service/src/core/pipeline.py @@ -14,6 +14,7 @@ class BasicPipeline(metaclass=ABCMeta): def __init__(self, pipe: Pipeline | AsyncDriver | Driver): self._pipe = pipe + self._description = "" self._llm_provider = None self._embedder_provider = None self._components = {} @@ -39,6 +40,7 @@ def __str__(self): @dataclass class PipelineComponent(Mapping): + description: str = None llm_provider: LLMProvider = None embedder_provider: EmbedderProvider = None document_store_provider: DocumentStoreProvider = None @@ -52,3 +54,6 @@ def __iter__(self): def __len__(self): return len(self.__dict__) + + def __str__(self): + return f"PipelineComponent(description={self.description}, llm_provider={self.llm_provider}, embedder_provider={self.embedder_provider}, document_store_provider={self.document_store_provider}, engine={self.engine})" diff --git a/wren-ai-service/src/core/provider.py b/wren-ai-service/src/core/provider.py index 4246dd4908..822777424c 100644 --- a/wren-ai-service/src/core/provider.py +++ b/wren-ai-service/src/core/provider.py @@ -28,7 +28,7 @@ def get_document_embedder(self, *args, **kwargs): ... def get_model(self): - return self._embedding_model + return self._model class DocumentStoreProvider(metaclass=ABCMeta): diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py index 164baff976..8c2e67c7f0 100644 --- a/wren-ai-service/src/globals.py +++ b/wren-ai-service/src/globals.py @@ -266,7 +266,6 @@ def create_pipe_components(service_container: ServiceContainer): _pipe_components = {} for _, service in service_container.__dict__.items(): for pipe_name, pipe in service._pipelines.items(): - print(f"pipe_name: {pipe_name}, pipe: {pipe}") if pipe_name not in _pipe_components: _pipe_components[pipe_name] = {} if hasattr(pipe, "_llm_provider") and pipe._llm_provider is not None: @@ -284,6 +283,7 @@ def create_pipe_components(service_container: ServiceContainer): _pipe_components[pipe_name][ "has_db_data_in_llm_prompt" ] = has_db_data_in_llm_prompt(pipe_name) + _pipe_components[pipe_name]["description"] = pipe._description return _pipe_components diff --git a/wren-ai-service/src/pipelines/generation/chart_adjustment.py b/wren-ai-service/src/pipelines/generation/chart_adjustment.py index 7480082966..168671e4f4 100644 --- a/wren-ai-service/src/pipelines/generation/chart_adjustment.py +++ b/wren-ai-service/src/pipelines/generation/chart_adjustment.py @@ -150,6 +150,7 @@ class ChartAdjustment(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **kwargs, ): super().__init__( @@ -157,6 +158,7 @@ def __init__( ) self._llm_provider = llm_provider + self._description = description self._components = self._update_components() with open("src/pipelines/generation/utils/vega-lite-schema-v5.json", "r") as f: diff --git a/wren-ai-service/src/pipelines/generation/chart_generation.py b/wren-ai-service/src/pipelines/generation/chart_generation.py index b1639e26ea..9d18e77c5a 100644 --- a/wren-ai-service/src/pipelines/generation/chart_generation.py +++ b/wren-ai-service/src/pipelines/generation/chart_generation.py @@ -123,6 +123,7 @@ class ChartGeneration(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **kwargs, ): super().__init__( @@ -130,6 +131,7 @@ def __init__( ) self._llm_provider = llm_provider + self._description = description self._components = self._update_components() with open("src/pipelines/generation/utils/vega-lite-schema-v5.json", "r") as f: diff --git a/wren-ai-service/src/pipelines/generation/data_assistance.py b/wren-ai-service/src/pipelines/generation/data_assistance.py index 7a96de2c2a..1906552689 100644 --- a/wren-ai-service/src/pipelines/generation/data_assistance.py +++ b/wren-ai-service/src/pipelines/generation/data_assistance.py @@ -93,6 +93,7 @@ class DataAssistance(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **kwargs, ): super().__init__( @@ -101,6 +102,7 @@ def __init__( self._user_queues = {} self._llm_provider = llm_provider + self._description = description self._components = self._update_components() def _update_components(self): diff --git a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py index bde753697c..a4c4b344be 100644 --- a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py @@ -154,6 +154,7 @@ def __init__( llm_provider: LLMProvider, document_store_provider: DocumentStoreProvider, engine: Engine, + description: str = "", **kwargs, ): super().__init__( @@ -165,6 +166,7 @@ def __init__( ) self._llm_provider = llm_provider self._engine = engine + self._description = description self._components = self._update_components() def _update_components(self): diff --git a/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py b/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py index b225888a89..05a1f4231a 100644 --- a/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py +++ b/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py @@ -115,6 +115,7 @@ class FollowUpSQLGenerationReasoning(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **kwargs, ): super().__init__( @@ -123,6 +124,7 @@ def __init__( self._user_queues = {} self._llm_provider = llm_provider + self._description = description self._components = self._update_components() def _update_components(self): diff --git a/wren-ai-service/src/pipelines/generation/intent_classification.py b/wren-ai-service/src/pipelines/generation/intent_classification.py index 3698281c44..b4f7d516ff 100644 --- a/wren-ai-service/src/pipelines/generation/intent_classification.py +++ b/wren-ai-service/src/pipelines/generation/intent_classification.py @@ -344,6 +344,7 @@ def __init__( wren_ai_docs: list[dict], table_retrieval_size: Optional[int] = 50, table_column_retrieval_size: Optional[int] = 100, + description: str = "", **kwargs, ): super().__init__( @@ -355,6 +356,7 @@ def __init__( self._table_column_retrieval_size = table_column_retrieval_size self._document_store_provider = document_store_provider self._embedder_provider = embedder_provider + self._description = description self._components = self._update_components() self._configs = { diff --git a/wren-ai-service/src/pipelines/generation/misleading_assistance.py b/wren-ai-service/src/pipelines/generation/misleading_assistance.py index 286a97ae15..8fe4cc41cc 100644 --- a/wren-ai-service/src/pipelines/generation/misleading_assistance.py +++ b/wren-ai-service/src/pipelines/generation/misleading_assistance.py @@ -93,6 +93,7 @@ class MisleadingAssistance(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **kwargs, ): super().__init__( @@ -101,6 +102,7 @@ def __init__( self._user_queues = {} self._llm_provider = llm_provider + self._description = description self._components = self._update_components() def _update_components(self): diff --git a/wren-ai-service/src/pipelines/generation/question_recommendation.py b/wren-ai-service/src/pipelines/generation/question_recommendation.py index 14ecbf9634..5c78a0af10 100644 --- a/wren-ai-service/src/pipelines/generation/question_recommendation.py +++ b/wren-ai-service/src/pipelines/generation/question_recommendation.py @@ -235,6 +235,7 @@ class QuestionRecommendation(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **_, ): super().__init__( @@ -242,6 +243,7 @@ def __init__( ) self._llm_provider = llm_provider + self._description = description self._components = self._update_components() self._final = "normalized" diff --git a/wren-ai-service/src/pipelines/generation/relationship_recommendation.py b/wren-ai-service/src/pipelines/generation/relationship_recommendation.py index 517d02fa85..dc7a9ee43b 100644 --- a/wren-ai-service/src/pipelines/generation/relationship_recommendation.py +++ b/wren-ai-service/src/pipelines/generation/relationship_recommendation.py @@ -202,6 +202,7 @@ class RelationshipRecommendation(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **_, ): super().__init__( @@ -209,6 +210,7 @@ def __init__( ) self._llm_provider = llm_provider + self._description = description self._components = self._update_components() self._final = "validated" diff --git a/wren-ai-service/src/pipelines/generation/semantics_description.py b/wren-ai-service/src/pipelines/generation/semantics_description.py index 37f3e3f637..67a5d8acd8 100644 --- a/wren-ai-service/src/pipelines/generation/semantics_description.py +++ b/wren-ai-service/src/pipelines/generation/semantics_description.py @@ -217,12 +217,18 @@ class SemanticResult(BaseModel): class SemanticsDescription(BasicPipeline): - def __init__(self, llm_provider: LLMProvider, **_): + def __init__( + self, + llm_provider: LLMProvider, + description: str = "", + **_, + ): super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) self._llm_provider = llm_provider + self._description = description self._components = self._update_components() self._final = "output" diff --git a/wren-ai-service/src/pipelines/generation/sql_answer.py b/wren-ai-service/src/pipelines/generation/sql_answer.py index f8d530b256..147c5c2198 100644 --- a/wren-ai-service/src/pipelines/generation/sql_answer.py +++ b/wren-ai-service/src/pipelines/generation/sql_answer.py @@ -93,6 +93,7 @@ class SQLAnswer(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **kwargs, ): super().__init__( @@ -101,6 +102,7 @@ def __init__( self._user_queues = {} self._llm_provider = llm_provider + self._description = description self._components = self._update_components() def _update_components(self): diff --git a/wren-ai-service/src/pipelines/generation/sql_correction.py b/wren-ai-service/src/pipelines/generation/sql_correction.py index 850837f9a1..36e085fca5 100644 --- a/wren-ai-service/src/pipelines/generation/sql_correction.py +++ b/wren-ai-service/src/pipelines/generation/sql_correction.py @@ -131,6 +131,7 @@ def __init__( llm_provider: LLMProvider, document_store_provider: DocumentStoreProvider, engine: Engine, + description: str = "", **kwargs, ): super().__init__( @@ -139,6 +140,7 @@ def __init__( self._llm_provider = llm_provider self._engine = engine + self._description = description self._components = self._update_components() self._retriever = document_store_provider.get_retriever( document_store_provider.get_store("project_meta") diff --git a/wren-ai-service/src/pipelines/generation/sql_diagnosis.py b/wren-ai-service/src/pipelines/generation/sql_diagnosis.py index 3f22b9d512..d52e402dc3 100644 --- a/wren-ai-service/src/pipelines/generation/sql_diagnosis.py +++ b/wren-ai-service/src/pipelines/generation/sql_diagnosis.py @@ -117,23 +117,29 @@ class SQLDiagnosis(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **kwargs, ): - self._components = { - "generator": llm_provider.get_generator( + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + self._llm_provider = llm_provider + self._description = description + self._components = self._update_components() + + def _update_components(self): + return { + "generator": self._llm_provider.get_generator( system_prompt=sql_diagnosis_system_prompt, generation_kwargs=SQL_DIAGNOSIS_MODEL_KWARGS, ), - "generator_name": llm_provider.get_model(), + "generator_name": self._llm_provider.get_model(), "prompt_builder": PromptBuilder( template=sql_diagnosis_user_prompt_template ), } - super().__init__( - AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) - ) - @observe(name="SQL Diagnosis") async def run( self, diff --git a/wren-ai-service/src/pipelines/generation/sql_generation.py b/wren-ai-service/src/pipelines/generation/sql_generation.py index a2f9832279..ecf4c2bb48 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation.py @@ -151,6 +151,7 @@ def __init__( llm_provider: LLMProvider, document_store_provider: DocumentStoreProvider, engine: Engine, + description: str = "", **kwargs, ): super().__init__( @@ -162,6 +163,7 @@ def __init__( ) self._llm_provider = llm_provider self._engine = engine + self._description = description self._components = self._update_components() def _update_components(self): diff --git a/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py b/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py index 21c141437d..08d788a76d 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py @@ -100,6 +100,7 @@ class SQLGenerationReasoning(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **kwargs, ): super().__init__( @@ -108,6 +109,7 @@ def __init__( self._user_queues = {} self._llm_provider = llm_provider + self._description = description self._components = self._update_components() def _update_components(self): diff --git a/wren-ai-service/src/pipelines/generation/sql_question.py b/wren-ai-service/src/pipelines/generation/sql_question.py index e9a8cceef0..5217f35e2d 100644 --- a/wren-ai-service/src/pipelines/generation/sql_question.py +++ b/wren-ai-service/src/pipelines/generation/sql_question.py @@ -96,6 +96,7 @@ class SQLQuestion(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **kwargs, ): super().__init__( @@ -103,6 +104,7 @@ def __init__( ) self._llm_provider = llm_provider + self._description = description self._components = self._update_components() def _update_components(self): diff --git a/wren-ai-service/src/pipelines/generation/sql_regeneration.py b/wren-ai-service/src/pipelines/generation/sql_regeneration.py index 3d1bfaa9f5..ede7d89a21 100644 --- a/wren-ai-service/src/pipelines/generation/sql_regeneration.py +++ b/wren-ai-service/src/pipelines/generation/sql_regeneration.py @@ -155,6 +155,7 @@ def __init__( self, llm_provider: LLMProvider, engine: Engine, + description: str = "", **kwargs, ): super().__init__( @@ -163,6 +164,7 @@ def __init__( self._llm_provider = llm_provider self._engine = engine + self._description = description self._components = self._update_components() def _update_components(self): diff --git a/wren-ai-service/src/pipelines/generation/sql_tables_extraction.py b/wren-ai-service/src/pipelines/generation/sql_tables_extraction.py index 1b6170e5f4..d0eec76ae7 100644 --- a/wren-ai-service/src/pipelines/generation/sql_tables_extraction.py +++ b/wren-ai-service/src/pipelines/generation/sql_tables_extraction.py @@ -100,6 +100,7 @@ class SQLTablesExtraction(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **kwargs, ): super().__init__( @@ -107,6 +108,7 @@ def __init__( ) self._llm_provider = llm_provider + self._description = description self._components = self._update_components() def _update_components(self): diff --git a/wren-ai-service/src/pipelines/generation/user_guide_assistance.py b/wren-ai-service/src/pipelines/generation/user_guide_assistance.py index 60233f1e94..c622396f44 100644 --- a/wren-ai-service/src/pipelines/generation/user_guide_assistance.py +++ b/wren-ai-service/src/pipelines/generation/user_guide_assistance.py @@ -83,6 +83,7 @@ def __init__( self, llm_provider: LLMProvider, wren_ai_docs: list[dict], + description: str = "", **kwargs, ): super().__init__( @@ -91,6 +92,7 @@ def __init__( self._user_queues = {} self._llm_provider = llm_provider + self._description = description self._components = self._update_components() self._configs = { diff --git a/wren-ai-service/src/pipelines/indexing/db_schema.py b/wren-ai-service/src/pipelines/indexing/db_schema.py index ea0596f2a9..cf91207a2e 100644 --- a/wren-ai-service/src/pipelines/indexing/db_schema.py +++ b/wren-ai-service/src/pipelines/indexing/db_schema.py @@ -342,6 +342,7 @@ def __init__( embedder_provider: EmbedderProvider, document_store_provider: DocumentStoreProvider, column_batch_size: int = 50, + description: str = "", **kwargs, ) -> None: super().__init__( @@ -349,6 +350,7 @@ def __init__( ) dbschema_store = document_store_provider.get_store() + self._description = description self._components = { "cleaner": DocumentCleaner([dbschema_store]), diff --git a/wren-ai-service/src/pipelines/indexing/historical_question.py b/wren-ai-service/src/pipelines/indexing/historical_question.py index b4ec96741e..2e25862cb4 100644 --- a/wren-ai-service/src/pipelines/indexing/historical_question.py +++ b/wren-ai-service/src/pipelines/indexing/historical_question.py @@ -137,6 +137,7 @@ def __init__( self, embedder_provider: EmbedderProvider, document_store_provider: DocumentStoreProvider, + description: str = "", **kwargs, ) -> None: super().__init__( @@ -145,6 +146,7 @@ def __init__( # keep the store name as it is for now, might change in the future store = document_store_provider.get_store(dataset_name="view_questions") + self._description = description self._components = { "cleaner": DocumentCleaner([store]), diff --git a/wren-ai-service/src/pipelines/indexing/instructions.py b/wren-ai-service/src/pipelines/indexing/instructions.py index fd473eb2fe..881abc7938 100644 --- a/wren-ai-service/src/pipelines/indexing/instructions.py +++ b/wren-ai-service/src/pipelines/indexing/instructions.py @@ -129,6 +129,7 @@ def __init__( self, embedder_provider: EmbedderProvider, document_store_provider: DocumentStoreProvider, + description: str = "", **kwargs, ) -> None: super().__init__( @@ -136,6 +137,7 @@ def __init__( ) store = document_store_provider.get_store(dataset_name="instructions") + self._description = description self._components = { "cleaner": InstructionsCleaner(store), diff --git a/wren-ai-service/src/pipelines/indexing/project_meta.py b/wren-ai-service/src/pipelines/indexing/project_meta.py index e7b9dd5443..ef1e1a6518 100644 --- a/wren-ai-service/src/pipelines/indexing/project_meta.py +++ b/wren-ai-service/src/pipelines/indexing/project_meta.py @@ -67,6 +67,7 @@ class ProjectMeta(BasicPipeline): def __init__( self, document_store_provider: DocumentStoreProvider, + description: str = "", **kwargs, ) -> None: super().__init__( @@ -74,6 +75,7 @@ def __init__( ) store = document_store_provider.get_store(dataset_name="project_meta") + self._description = description self._components = { "validator": MDLValidator(), diff --git a/wren-ai-service/src/pipelines/indexing/sql_pairs.py b/wren-ai-service/src/pipelines/indexing/sql_pairs.py index 6676891f80..9e7e4b44d9 100644 --- a/wren-ai-service/src/pipelines/indexing/sql_pairs.py +++ b/wren-ai-service/src/pipelines/indexing/sql_pairs.py @@ -169,6 +169,7 @@ def __init__( embedder_provider: EmbedderProvider, document_store_provider: DocumentStoreProvider, sql_pairs_path: str = "sql_pairs.json", + description: str = "", **kwargs, ) -> None: super().__init__( @@ -176,6 +177,7 @@ def __init__( ) store = document_store_provider.get_store(dataset_name="sql_pairs") + self._description = description self._components = { "cleaner": SqlPairsCleaner(store), diff --git a/wren-ai-service/src/pipelines/indexing/table_description.py b/wren-ai-service/src/pipelines/indexing/table_description.py index 40b2646aca..fb0a558fd1 100644 --- a/wren-ai-service/src/pipelines/indexing/table_description.py +++ b/wren-ai-service/src/pipelines/indexing/table_description.py @@ -120,6 +120,7 @@ def __init__( self, embedder_provider: EmbedderProvider, document_store_provider: DocumentStoreProvider, + description: str = "", **kwargs, ) -> None: super().__init__( @@ -129,6 +130,7 @@ def __init__( table_description_store = document_store_provider.get_store( dataset_name="table_descriptions" ) + self._description = description self._components = { "cleaner": DocumentCleaner([table_description_store]), diff --git a/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py b/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py index 995db02c2e..11acb43929 100644 --- a/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py +++ b/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py @@ -457,6 +457,7 @@ def __init__( document_store_provider: DocumentStoreProvider, table_retrieval_size: int = 10, table_column_retrieval_size: int = 100, + description: str = "", **kwargs, ): super().__init__( @@ -468,6 +469,7 @@ def __init__( self._table_column_retrieval_size = table_column_retrieval_size self._document_store_provider = document_store_provider self._embedder_provider = embedder_provider + self._description = description self._components = self._update_components() self._configs = self._update_configs() diff --git a/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py b/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py index 4942380aa9..65501f5084 100644 --- a/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py +++ b/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py @@ -123,6 +123,7 @@ def __init__( embedder_provider: EmbedderProvider, document_store_provider: DocumentStoreProvider, historical_question_retrieval_similarity_threshold: float = 0.9, + description: str = "", **kwargs, ) -> None: super().__init__( @@ -132,6 +133,8 @@ def __init__( view_questions_store = document_store_provider.get_store( dataset_name="view_questions" ) + self._description = description + self._components = { "view_questions_store": view_questions_store, "embedder": embedder_provider.get_text_embedder(), diff --git a/wren-ai-service/src/pipelines/retrieval/instructions.py b/wren-ai-service/src/pipelines/retrieval/instructions.py index 552b4d06fd..0cfe765a2c 100644 --- a/wren-ai-service/src/pipelines/retrieval/instructions.py +++ b/wren-ai-service/src/pipelines/retrieval/instructions.py @@ -189,6 +189,7 @@ def __init__( document_store_provider: DocumentStoreProvider, similarity_threshold: float = 0.7, top_k: int = 10, + description: str = "", **kwargs, ) -> None: super().__init__( @@ -196,6 +197,8 @@ def __init__( ) store = document_store_provider.get_store(dataset_name="instructions") + self._description = description + self._components = { "store": store, "embedder": embedder_provider.get_text_embedder(), diff --git a/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py b/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py index 7f185a642a..a0f6c2f3f8 100644 --- a/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py +++ b/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py @@ -80,11 +80,13 @@ class PreprocessSqlData(BasicPipeline): def __init__( self, llm_provider: LLMProvider, + description: str = "", **kwargs, ): super().__init__(Driver({}, sys.modules[__name__], adapter=base.DictResult())) self._llm_provider = llm_provider + self._description = description self._configs = self._update_configs() def _update_configs(self): diff --git a/wren-ai-service/src/pipelines/retrieval/sql_executor.py b/wren-ai-service/src/pipelines/retrieval/sql_executor.py index a239e6d57d..dae00cedbd 100644 --- a/wren-ai-service/src/pipelines/retrieval/sql_executor.py +++ b/wren-ai-service/src/pipelines/retrieval/sql_executor.py @@ -64,12 +64,14 @@ class SQLExecutor(BasicPipeline): def __init__( self, engine: Engine, + description: str = "", **kwargs, ): super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) + self._description = description self._components = { "data_fetcher": DataFetcher(engine=engine), } diff --git a/wren-ai-service/src/pipelines/retrieval/sql_functions.py b/wren-ai-service/src/pipelines/retrieval/sql_functions.py index 78d179151f..b7b5f008bc 100644 --- a/wren-ai-service/src/pipelines/retrieval/sql_functions.py +++ b/wren-ai-service/src/pipelines/retrieval/sql_functions.py @@ -85,12 +85,14 @@ def __init__( engine: Engine, document_store_provider: DocumentStoreProvider, ttl: int = 60 * 60 * 24, + description: str = "", **kwargs, ) -> None: super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) + self._description = description self._retriever = document_store_provider.get_retriever( document_store_provider.get_store("project_meta") ) diff --git a/wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py b/wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py index 414d3c9524..57accf57e5 100644 --- a/wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py +++ b/wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py @@ -120,6 +120,7 @@ def __init__( document_store_provider: DocumentStoreProvider, sql_pairs_similarity_threshold: float = 0.7, sql_pairs_retrieval_max_size: int = 10, + description: str = "", **kwargs, ) -> None: super().__init__( @@ -127,6 +128,8 @@ def __init__( ) store = document_store_provider.get_store(dataset_name="sql_pairs") + self._description = description + self._components = { "store": store, "embedder": embedder_provider.get_text_embedder(), diff --git a/wren-ai-service/src/providers/__init__.py b/wren-ai-service/src/providers/__init__.py index 2f3d7dc943..cf356ea2c6 100644 --- a/wren-ai-service/src/providers/__init__.py +++ b/wren-ai-service/src/providers/__init__.py @@ -107,9 +107,10 @@ def build_fallback_params(all_models: dict) -> dict: ] returned[model_name] = { - "provider": entry["provider"], - "model": model["model"], - "kwargs": model["kwargs"], + "provider": entry.get("provider"), + "model": model.get("model"), + "alias": model.get("alias"), + "kwargs": model.get("kwargs"), "context_window_size": model.get("context_window_size", 100000), "fallback_model_list": fallback_model_list, **model_additional_params, @@ -163,8 +164,9 @@ def embedder_processor(entry: dict) -> dict: k: v for k, v in model.items() if k not in ["model", "kwargs", "alias"] } returned[identifier] = { - "provider": entry["provider"], - "model": model["model"], + "provider": entry.get("provider"), + "model": model.get("model"), + "alias": model.get("alias"), **model_additional_params, **others, } @@ -277,6 +279,7 @@ def pipeline_processor(entry: dict) -> dict: "embedder": "openai_embedder.text-embedding-3-large", "document_store": "qdrant", "engine": "wren_ui", + "description": "Indexing pipeline", } } @@ -292,6 +295,7 @@ def pipeline_processor(entry: dict) -> dict: "embedder": pipe.get("embedder"), "document_store": pipe.get("document_store"), "engine": pipe.get("engine"), + "description": pipe.get("description"), } for pipe in entry["pipes"] } @@ -381,6 +385,7 @@ def get(type: str, components: dict, instantiated_providers: dict): def componentize(components: dict, instantiated_providers: dict): return PipelineComponent( + description=components.get("description", ""), embedder_provider=get("embedder", components, instantiated_providers), llm_provider=get("llm", components, instantiated_providers), document_store_provider=get( diff --git a/wren-ai-service/src/providers/embedder/litellm.py b/wren-ai-service/src/providers/embedder/litellm.py index 4d051e3284..24b8a27331 100644 --- a/wren-ai-service/src/providers/embedder/litellm.py +++ b/wren-ai-service/src/providers/embedder/litellm.py @@ -36,15 +36,17 @@ def __init__( self, model: str, api_key: Optional[str] = None, - api_base_url: Optional[str] = None, + api_base: Optional[str] = None, + api_version: Optional[str] = None, timeout: Optional[float] = None, **kwargs, ): self._api_key = api_key self._model = model - self._api_base_url = api_base_url + self._api_base = api_base + self._api_version = api_version self._timeout = timeout - self._kwargs = kwargs + self._model_kwargs = kwargs @component.output_types(embedding=List[float], meta=Dict[str, Any]) @backoff.on_exception(backoff.expo, openai.APIError, max_time=60.0, max_tries=3) @@ -63,9 +65,10 @@ async def run(self, text: str): model=self._model, input=[text_to_embed], api_key=self._api_key, - api_base=self._api_base_url, + api_base=self._api_base, + api_version=self._api_version, timeout=self._timeout, - **self._kwargs, + **self._model_kwargs, ) meta = { @@ -83,16 +86,18 @@ def __init__( model: str, batch_size: int = 32, api_key: Optional[str] = None, - api_base_url: Optional[str] = None, + api_base: Optional[str] = None, + api_version: Optional[str] = None, timeout: Optional[float] = None, **kwargs, ): self._api_key = api_key self._model = model self._batch_size = batch_size - self._api_base_url = api_base_url + self._api_base = api_base + self._api_version = api_version self._timeout = timeout - self._kwargs = kwargs + self._model_kwargs = kwargs async def _embed_batch( self, texts_to_embed: List[str], batch_size: int @@ -102,9 +107,10 @@ async def embed_single_batch(batch: List[str]) -> Any: model=self._model, input=batch, api_key=self._api_key, - api_base=self._api_base_url, + api_base=self._api_base, + api_version=self._api_version, timeout=self._timeout, - **self._kwargs, + **self._model_kwargs, ) batches = [ @@ -171,31 +177,37 @@ def __init__( str ] = None, # e.g. EMBEDDER_OPENAI_API_KEY, EMBEDDER_ANTHROPIC_API_KEY, etc. api_base: Optional[str] = None, + api_version: Optional[str] = None, + alias: Optional[str] = None, timeout: float = 120.0, **kwargs, ): self._api_key = os.getenv(api_key_name) if api_key_name else None self._api_base = remove_trailing_slash(api_base) if api_base else None - self._embedding_model = model + self._api_version = api_version + self._model = model + self._alias = alias self._timeout = timeout if "provider" in kwargs: del kwargs["provider"] - self._kwargs = kwargs + self._model_kwargs = kwargs def get_text_embedder(self): return AsyncTextEmbedder( api_key=self._api_key, api_base_url=self._api_base, - model=self._embedding_model, + api_version=self._api_version, + model=self._model, timeout=self._timeout, - **self._kwargs, + **self._model_kwargs, ) def get_document_embedder(self): return AsyncDocumentEmbedder( api_key=self._api_key, - api_base_url=self._api_base, - model=self._embedding_model, + api_base=self._api_base, + api_version=self._api_version, + model=self._model, timeout=self._timeout, - **self._kwargs, + **self._model_kwargs, ) diff --git a/wren-ai-service/src/providers/llm/litellm.py b/wren-ai-service/src/providers/llm/litellm.py index 3748a945da..97752a4eda 100644 --- a/wren-ai-service/src/providers/llm/litellm.py +++ b/wren-ai-service/src/providers/llm/litellm.py @@ -29,6 +29,7 @@ def __init__( ] = None, # e.g. OPENAI_API_KEY, LLM_ANTHROPIC_API_KEY, etc. api_base: Optional[str] = None, api_version: Optional[str] = None, + alias: Optional[str] = None, kwargs: Optional[Dict[str, Any]] = None, timeout: float = 120.0, context_window_size: int = 100000, @@ -37,7 +38,8 @@ def __init__( **_, ): self._model = model - # TODO: remove _api_key, _api_base, _api_version in the future, as it is not used in litellm + self._alias = alias + self._api_key_name = api_key_name self._api_key = os.getenv(api_key_name) if api_key_name else None self._api_base = remove_trailing_slash(api_base) if api_base else None self._api_version = api_version diff --git a/wren-ai-service/src/utils.py b/wren-ai-service/src/utils.py index 7abfb40a58..92413de0f4 100644 --- a/wren-ai-service/src/utils.py +++ b/wren-ai-service/src/utils.py @@ -3,11 +3,12 @@ import os import re from pathlib import Path +from typing import Any, Optional import requests from dotenv import load_dotenv from langfuse.decorators import langfuse_context -from pydantic import BaseModel +from pydantic import BaseModel, Field from src.config import Settings @@ -221,9 +222,38 @@ def extract_braces_content(resp: str) -> str: return match.group(1) if match else resp -class SinglePipeComponentRequest(BaseModel): - pipeline_name: str - llm_config: str +class Configs(BaseModel): + class Providers(BaseModel): + class LLMProvider(BaseModel): + model: str + alias: str + api_base: Optional[str] + api_version: Optional[str] + context_window_size: int + timeout: float = 600 + kwargs: Optional[dict[str, Any]] + + class EmbedderProvider(BaseModel): + model: str + alias: str + dimension: int + api_base: Optional[str] + api_version: Optional[str] + timeout: float = 600 + kwargs: Optional[dict[str, Any]] + + llm: list[LLMProvider] = Field(default_factory=list) + embedder: list[EmbedderProvider] = Field(default_factory=list) + + class Pipeline(BaseModel): + has_db_data_in_llm_prompt: bool + llm: str + embedder: str + description: str + + env_vars: dict[str, str] + providers: Providers + pipelines: dict[str, Pipeline] def has_db_data_in_llm_prompt(pipe_name: str) -> bool: From 7ffb9773bfc430c37c8dfe4b300196c611cd824a Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Thu, 18 Sep 2025 10:28:51 +0800 Subject: [PATCH 4/7] update --- wren-ai-service/src/__main__.py | 66 +++++++++++-------- wren-ai-service/src/core/provider.py | 24 +++++-- wren-ai-service/src/globals.py | 14 ++-- .../pipelines/generation/chart_adjustment.py | 2 +- .../pipelines/generation/chart_generation.py | 2 +- .../pipelines/generation/data_assistance.py | 2 +- .../generation/followup_sql_generation.py | 2 +- .../followup_sql_generation_reasoning.py | 2 +- .../generation/intent_classification.py | 2 +- .../generation/misleading_assistance.py | 2 +- .../generation/question_recommendation.py | 2 +- .../generation/relationship_recommendation.py | 2 +- .../generation/semantics_description.py | 2 +- .../src/pipelines/generation/sql_answer.py | 2 +- .../pipelines/generation/sql_correction.py | 2 +- .../src/pipelines/generation/sql_diagnosis.py | 2 +- .../pipelines/generation/sql_generation.py | 2 +- .../generation/sql_generation_reasoning.py | 2 +- .../src/pipelines/generation/sql_question.py | 2 +- .../pipelines/generation/sql_regeneration.py | 2 +- .../generation/sql_tables_extraction.py | 2 +- .../generation/user_guide_assistance.py | 2 +- .../retrieval/db_schema_retrieval.py | 6 +- .../retrieval/preprocess_sql_data.py | 4 +- wren-ai-service/src/providers/__init__.py | 5 +- wren-ai-service/src/utils.py | 6 +- 26 files changed, 94 insertions(+), 69 deletions(-) diff --git a/wren-ai-service/src/__main__.py b/wren-ai-service/src/__main__.py index 1b4e2c7702..4877804fc0 100644 --- a/wren-ai-service/src/__main__.py +++ b/wren-ai-service/src/__main__.py @@ -15,6 +15,7 @@ ) from src.providers import generate_components from src.providers.embedder.litellm import LitellmEmbedderProvider +from src.providers.llm.litellm import LitellmLLMProvider from src.utils import ( Configs, init_langfuse, @@ -107,36 +108,38 @@ def get_configs(): _llm_configs = [] for _, model_config in app.state.instantiated_providers["llm"].items(): - _llm_configs.append( - { - "model": model_config._model, - "alias": model_config._alias, - "api_base": model_config._api_base, - "api_version": model_config._api_version, - "context_window_size": model_config._context_window_size, - "timeout": model_config._timeout, - "kwargs": model_config._model_kwargs, - } - ) + _llm_config = { + "model": model_config._model, + "alias": model_config._alias, + "context_window_size": model_config._context_window_size, + "timeout": model_config._timeout, + "kwargs": model_config._model_kwargs, + } + if model_config._api_base: + _llm_config["api_base"] = model_config._api_base + if model_config._api_version: + _llm_config["api_version"] = model_config._api_version + _llm_configs.append(_llm_config) _llm_model_alias_mapping[model_config._model] = model_config._alias _configs["providers"]["llm"] = _llm_configs _embedder_configs = [] # we only support one embedding model now for _, model_config in app.state.instantiated_providers["embedder"].items(): - _embedder_configs.append( - { - "model": model_config._model, - "alias": model_config._alias, - "dimension": app.state.instantiated_providers["document_store"][ - "qdrant" - ]._embedding_model_dim, - "api_base": model_config._api_base, - "api_version": model_config._api_version, - "timeout": model_config._timeout, - "kwargs": model_config._model_kwargs, - } - ) + _embedder_config = { + "model": model_config._model, + "alias": model_config._alias, + "dimension": app.state.instantiated_providers["document_store"][ + "qdrant" + ]._embedding_model_dim, + "timeout": model_config._timeout, + "kwargs": model_config._model_kwargs, + } + if model_config._api_base: + _embedder_config["api_base"] = model_config._api_base + if model_config._api_version: + _embedder_config["api_version"] = model_config._api_version + _embedder_configs.append(_embedder_config) _embedder_model_alias_mapping[model_config._model] = model_config._alias break _configs["providers"]["embedder"] = _embedder_configs @@ -170,12 +173,19 @@ def get_configs(): @app.post("/configs") def update_configs(configs_request: Configs): - embedder_providers = {} - for embedder_provider in configs_request.providers.embedder: - identifier = f"litellm_embedder.{embedder_provider.alias if embedder_provider.alias else embedder_provider.model}" - embedder_providers[identifier] = LitellmEmbedderProvider( + # override current instantiated_providers + app.state.instantiated_providers["embedder"] = { + f"litellm_embedder.{embedder_provider.alias}": LitellmEmbedderProvider( **embedder_provider.__dict__ ) + for embedder_provider in configs_request.providers.embedder + } + app.state.instantiated_providers["llm"] = { + f"litellm_llm.{llm_provider.alias}": LitellmLLMProvider(**llm_provider.__dict__) + for llm_provider in configs_request.providers.llm + } + + print(f"pipe_components: {app.state.pipe_components}") # try: # for payload in pipe_components_request: diff --git a/wren-ai-service/src/core/provider.py b/wren-ai-service/src/core/provider.py index 822777424c..9b89d89a59 100644 --- a/wren-ai-service/src/core/provider.py +++ b/wren-ai-service/src/core/provider.py @@ -8,13 +8,20 @@ class LLMProvider(metaclass=ABCMeta): def get_generator(self, *args, **kwargs): ... - def get_model(self): + @property + def alias(self): + return self._alias + + @property + def model(self): return self._model - def get_model_kwargs(self): + @property + def model_kwargs(self): return self._model_kwargs - def get_context_window_size(self): + @property + def context_window_size(self): return self._context_window_size @@ -27,9 +34,18 @@ def get_text_embedder(self, *args, **kwargs): def get_document_embedder(self, *args, **kwargs): ... - def get_model(self): + @property + def alias(self): + return self._alias + + @property + def model(self): return self._model + @property + def model_kwargs(self): + return self._model_kwargs + class DocumentStoreProvider(metaclass=ABCMeta): @abstractmethod diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py index 8c2e67c7f0..fab06456fa 100644 --- a/wren-ai-service/src/globals.py +++ b/wren-ai-service/src/globals.py @@ -269,21 +269,19 @@ def create_pipe_components(service_container: ServiceContainer): if pipe_name not in _pipe_components: _pipe_components[pipe_name] = {} if hasattr(pipe, "_llm_provider") and pipe._llm_provider is not None: - _pipe_components[pipe_name]["llm"] = pipe._llm_provider.get_model() + _pipe_components[pipe_name]["llm"] = pipe._llm_provider.alias if ( hasattr(pipe, "_embedder_provider") and pipe._embedder_provider is not None ): - _pipe_components[pipe_name][ - "embedder" - ] = pipe._embedder_provider.get_model() + _pipe_components[pipe_name]["embedder"] = pipe._embedder_provider.alias if "services" not in _pipe_components[pipe_name]: _pipe_components[pipe_name]["services"] = set() _pipe_components[pipe_name]["services"].add(service) _pipe_components[pipe_name][ "has_db_data_in_llm_prompt" ] = has_db_data_in_llm_prompt(pipe_name) - _pipe_components[pipe_name]["description"] = pipe._description + _pipe_components[pipe_name]["description"] = pipe._description or "" return _pipe_components @@ -315,8 +313,8 @@ def _convert_pipe_metadata( ) -> dict: llm_metadata = ( { - "llm_model": llm_provider.get_model(), - "llm_model_kwargs": llm_provider.get_model_kwargs(), + "llm_model": llm_provider.model, + "llm_model_kwargs": llm_provider.model_kwargs, } if llm_provider else {} @@ -324,7 +322,7 @@ def _convert_pipe_metadata( embedding_metadata = ( { - "embedding_model": embedder_provider.get_model(), + "embedding_model": embedder_provider.model, } if embedder_provider else {} diff --git a/wren-ai-service/src/pipelines/generation/chart_adjustment.py b/wren-ai-service/src/pipelines/generation/chart_adjustment.py index 168671e4f4..e0abace0d6 100644 --- a/wren-ai-service/src/pipelines/generation/chart_adjustment.py +++ b/wren-ai-service/src/pipelines/generation/chart_adjustment.py @@ -177,7 +177,7 @@ def _update_components(self): system_prompt=chart_adjustment_system_prompt, generation_kwargs=CHART_ADJUSTMENT_MODEL_KWARGS, ), - "generator_name": self._llm_provider.get_model(), + "generator_name": self._llm_provider.model, "chart_data_preprocessor": ChartDataPreprocessor(), "post_processor": ChartGenerationPostProcessor(), } diff --git a/wren-ai-service/src/pipelines/generation/chart_generation.py b/wren-ai-service/src/pipelines/generation/chart_generation.py index 9d18e77c5a..c091179cb1 100644 --- a/wren-ai-service/src/pipelines/generation/chart_generation.py +++ b/wren-ai-service/src/pipelines/generation/chart_generation.py @@ -150,7 +150,7 @@ def _update_components(self): system_prompt=chart_generation_system_prompt, generation_kwargs=CHART_GENERATION_MODEL_KWARGS, ), - "generator_name": self._llm_provider.get_model(), + "generator_name": self._llm_provider.model, "chart_data_preprocessor": ChartDataPreprocessor(), "post_processor": ChartGenerationPostProcessor(), } diff --git a/wren-ai-service/src/pipelines/generation/data_assistance.py b/wren-ai-service/src/pipelines/generation/data_assistance.py index 1906552689..451c1e96c8 100644 --- a/wren-ai-service/src/pipelines/generation/data_assistance.py +++ b/wren-ai-service/src/pipelines/generation/data_assistance.py @@ -111,7 +111,7 @@ def _update_components(self): system_prompt=data_assistance_system_prompt, streaming_callback=self._streaming_callback, ), - "generator_name": self._llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=data_assistance_user_prompt_template ), diff --git a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py index a4c4b344be..947944a663 100644 --- a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py @@ -175,7 +175,7 @@ def _update_components(self): system_prompt=sql_generation_system_prompt, generation_kwargs=SQL_GENERATION_MODEL_KWARGS, ), - "generator_name": self._llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=text_to_sql_with_followup_user_prompt_template ), diff --git a/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py b/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py index 05a1f4231a..baf2cc6e39 100644 --- a/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py +++ b/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py @@ -133,7 +133,7 @@ def _update_components(self): system_prompt=sql_generation_reasoning_system_prompt, streaming_callback=self._streaming_callback, ), - "generator_name": self._llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=sql_generation_reasoning_user_prompt_template ), diff --git a/wren-ai-service/src/pipelines/generation/intent_classification.py b/wren-ai-service/src/pipelines/generation/intent_classification.py index b4f7d516ff..176b1e71c9 100644 --- a/wren-ai-service/src/pipelines/generation/intent_classification.py +++ b/wren-ai-service/src/pipelines/generation/intent_classification.py @@ -380,7 +380,7 @@ def _update_components(self): system_prompt=intent_classification_system_prompt, generation_kwargs=INTENT_CLASSIFICAION_MODEL_KWARGS, ), - "generator_name": self._llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=intent_classification_user_prompt_template ), diff --git a/wren-ai-service/src/pipelines/generation/misleading_assistance.py b/wren-ai-service/src/pipelines/generation/misleading_assistance.py index 8fe4cc41cc..ca59b7075d 100644 --- a/wren-ai-service/src/pipelines/generation/misleading_assistance.py +++ b/wren-ai-service/src/pipelines/generation/misleading_assistance.py @@ -111,7 +111,7 @@ def _update_components(self): system_prompt=misleading_assistance_system_prompt, streaming_callback=self._streaming_callback, ), - "generator_name": self._llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=misleading_assistance_user_prompt_template ), diff --git a/wren-ai-service/src/pipelines/generation/question_recommendation.py b/wren-ai-service/src/pipelines/generation/question_recommendation.py index 5c78a0af10..1ec08039c5 100644 --- a/wren-ai-service/src/pipelines/generation/question_recommendation.py +++ b/wren-ai-service/src/pipelines/generation/question_recommendation.py @@ -255,7 +255,7 @@ def _update_components(self): system_prompt=system_prompt, generation_kwargs=QUESTION_RECOMMENDATION_MODEL_KWARGS, ), - "generator_name": self._llm_provider.get_model(), + "generator_name": self._llm_provider.model, } @observe(name="Question Recommendation") diff --git a/wren-ai-service/src/pipelines/generation/relationship_recommendation.py b/wren-ai-service/src/pipelines/generation/relationship_recommendation.py index dc7a9ee43b..e80dbc94c8 100644 --- a/wren-ai-service/src/pipelines/generation/relationship_recommendation.py +++ b/wren-ai-service/src/pipelines/generation/relationship_recommendation.py @@ -222,7 +222,7 @@ def _update_components(self): system_prompt=system_prompt, generation_kwargs=RELATIONSHIP_RECOMMENDATION_MODEL_KWARGS, ), - "generator_name": self._llm_provider.get_model(), + "generator_name": self._llm_provider.model, } @observe(name="Relationship Recommendation") diff --git a/wren-ai-service/src/pipelines/generation/semantics_description.py b/wren-ai-service/src/pipelines/generation/semantics_description.py index 67a5d8acd8..7d072c5e9e 100644 --- a/wren-ai-service/src/pipelines/generation/semantics_description.py +++ b/wren-ai-service/src/pipelines/generation/semantics_description.py @@ -239,7 +239,7 @@ def _update_components(self): system_prompt=system_prompt, generation_kwargs=SEMANTICS_DESCRIPTION_MODEL_KWARGS, ), - "generator_name": self._llm_provider.get_model(), + "generator_name": self._llm_provider.model, } @observe(name="Semantics Description Generation") diff --git a/wren-ai-service/src/pipelines/generation/sql_answer.py b/wren-ai-service/src/pipelines/generation/sql_answer.py index 147c5c2198..58d833696e 100644 --- a/wren-ai-service/src/pipelines/generation/sql_answer.py +++ b/wren-ai-service/src/pipelines/generation/sql_answer.py @@ -114,7 +114,7 @@ def _update_components(self): system_prompt=sql_to_answer_system_prompt, streaming_callback=self._streaming_callback, ), - "generator_name": self._llm_provider.get_model(), + "generator_name": self._llm_provider.model, } def _streaming_callback(self, chunk, query_id): diff --git a/wren-ai-service/src/pipelines/generation/sql_correction.py b/wren-ai-service/src/pipelines/generation/sql_correction.py index 36e085fca5..aa1f611056 100644 --- a/wren-ai-service/src/pipelines/generation/sql_correction.py +++ b/wren-ai-service/src/pipelines/generation/sql_correction.py @@ -152,7 +152,7 @@ def _update_components(self): system_prompt=sql_correction_system_prompt, generation_kwargs=SQL_GENERATION_MODEL_KWARGS, ), - "generator_name": self._llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=sql_correction_user_prompt_template ), diff --git a/wren-ai-service/src/pipelines/generation/sql_diagnosis.py b/wren-ai-service/src/pipelines/generation/sql_diagnosis.py index d52e402dc3..4d79911ee0 100644 --- a/wren-ai-service/src/pipelines/generation/sql_diagnosis.py +++ b/wren-ai-service/src/pipelines/generation/sql_diagnosis.py @@ -134,7 +134,7 @@ def _update_components(self): system_prompt=sql_diagnosis_system_prompt, generation_kwargs=SQL_DIAGNOSIS_MODEL_KWARGS, ), - "generator_name": self._llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=sql_diagnosis_user_prompt_template ), diff --git a/wren-ai-service/src/pipelines/generation/sql_generation.py b/wren-ai-service/src/pipelines/generation/sql_generation.py index ecf4c2bb48..1327375fc0 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation.py @@ -172,7 +172,7 @@ def _update_components(self): system_prompt=sql_generation_system_prompt, generation_kwargs=SQL_GENERATION_MODEL_KWARGS, ), - "generator_name": self._llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=sql_generation_user_prompt_template ), diff --git a/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py b/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py index 08d788a76d..05fa1ba0e6 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py @@ -118,7 +118,7 @@ def _update_components(self): system_prompt=sql_generation_reasoning_system_prompt, streaming_callback=self._streaming_callback, ), - "generator_name": self._llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=sql_generation_reasoning_user_prompt_template ), diff --git a/wren-ai-service/src/pipelines/generation/sql_question.py b/wren-ai-service/src/pipelines/generation/sql_question.py index 5217f35e2d..1b4f59d901 100644 --- a/wren-ai-service/src/pipelines/generation/sql_question.py +++ b/wren-ai-service/src/pipelines/generation/sql_question.py @@ -113,7 +113,7 @@ def _update_components(self): system_prompt=sql_question_system_prompt, generation_kwargs=SQL_QUESTION_MODEL_KWARGS, ), - "generator_name": self._llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder(template=sql_question_user_prompt_template), } diff --git a/wren-ai-service/src/pipelines/generation/sql_regeneration.py b/wren-ai-service/src/pipelines/generation/sql_regeneration.py index ede7d89a21..7de247fcd6 100644 --- a/wren-ai-service/src/pipelines/generation/sql_regeneration.py +++ b/wren-ai-service/src/pipelines/generation/sql_regeneration.py @@ -173,7 +173,7 @@ def _update_components(self): system_prompt=sql_regeneration_system_prompt, generation_kwargs=SQL_GENERATION_MODEL_KWARGS, ), - "generator_name": self._llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=sql_regeneration_user_prompt_template ), diff --git a/wren-ai-service/src/pipelines/generation/sql_tables_extraction.py b/wren-ai-service/src/pipelines/generation/sql_tables_extraction.py index d0eec76ae7..e8e6b07eb4 100644 --- a/wren-ai-service/src/pipelines/generation/sql_tables_extraction.py +++ b/wren-ai-service/src/pipelines/generation/sql_tables_extraction.py @@ -117,7 +117,7 @@ def _update_components(self): system_prompt=sql_tables_extraction_system_prompt, generation_kwargs=SQL_TABLES_EXTRACTION_MODEL_KWARGS, ), - "generator_name": self._llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=sql_tables_extraction_user_prompt_template ), diff --git a/wren-ai-service/src/pipelines/generation/user_guide_assistance.py b/wren-ai-service/src/pipelines/generation/user_guide_assistance.py index c622396f44..fd72ee3c85 100644 --- a/wren-ai-service/src/pipelines/generation/user_guide_assistance.py +++ b/wren-ai-service/src/pipelines/generation/user_guide_assistance.py @@ -105,7 +105,7 @@ def _update_components(self): system_prompt=user_guide_assistance_system_prompt, streaming_callback=self._streaming_callback, ), - "generator_name": self._llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=user_guide_assistance_user_prompt_template ), diff --git a/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py b/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py index 11acb43929..c12210b142 100644 --- a/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py +++ b/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py @@ -474,7 +474,7 @@ def __init__( self._configs = self._update_configs() def _update_configs(self): - _model = self._llm_provider.get_model() + _model = (self._llm_provider.model,) if "gpt-4o" in _model or "gpt-4o-mini" in _model: _encoding = tiktoken.get_encoding("o200k_base") else: @@ -482,7 +482,7 @@ def _update_configs(self): return { "encoding": _encoding, - "context_window_size": self._llm_provider.get_context_window_size(), + "context_window_size": self._llm_provider.context_window_size, } def _update_components(self): @@ -502,7 +502,7 @@ def _update_components(self): system_prompt=table_columns_selection_system_prompt, generation_kwargs=RETRIEVAL_MODEL_KWARGS, ), - "generator_name": self._llm_provider.get_model(), + "generator_name": self._llm_provider.model, "prompt_builder": PromptBuilder( template=table_columns_selection_user_prompt_template ), diff --git a/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py b/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py index a0f6c2f3f8..1cae34515d 100644 --- a/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py +++ b/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py @@ -90,7 +90,7 @@ def __init__( self._configs = self._update_configs() def _update_configs(self): - _model = self._llm_provider.get_model() + _model = (self._llm_provider.model,) if _model == "gpt-4o-mini" or _model == "gpt-4o": _encoding = tiktoken.get_encoding("o200k_base") else: @@ -98,7 +98,7 @@ def _update_configs(self): return { "encoding": _encoding, - "context_window_size": self._llm_provider.get_context_window_size(), + "context_window_size": self._llm_provider.context_window_size, } def update_llm_provider(self, llm_provider: LLMProvider): diff --git a/wren-ai-service/src/providers/__init__.py b/wren-ai-service/src/providers/__init__.py index cf356ea2c6..393ac3614b 100644 --- a/wren-ai-service/src/providers/__init__.py +++ b/wren-ai-service/src/providers/__init__.py @@ -109,7 +109,7 @@ def build_fallback_params(all_models: dict) -> dict: returned[model_name] = { "provider": entry.get("provider"), "model": model.get("model"), - "alias": model.get("alias"), + "alias": model.get("alias", model.get("model")), "kwargs": model.get("kwargs"), "context_window_size": model.get("context_window_size", 100000), "fallback_model_list": fallback_model_list, @@ -166,7 +166,7 @@ def embedder_processor(entry: dict) -> dict: returned[identifier] = { "provider": entry.get("provider"), "model": model.get("model"), - "alias": model.get("alias"), + "alias": model.get("alias", model.get("model")), **model_additional_params, **others, } @@ -378,6 +378,7 @@ def generate_components(configs: list[dict]) -> dict[str, PipelineComponent]: } for type, configs in config.providers.items() } + print(f"instantiated_providers: {instantiated_providers}") def get(type: str, components: dict, instantiated_providers: dict): identifier = components.get(type) diff --git a/wren-ai-service/src/utils.py b/wren-ai-service/src/utils.py index 92413de0f4..bc797bc0d6 100644 --- a/wren-ai-service/src/utils.py +++ b/wren-ai-service/src/utils.py @@ -247,9 +247,9 @@ class EmbedderProvider(BaseModel): class Pipeline(BaseModel): has_db_data_in_llm_prompt: bool - llm: str - embedder: str - description: str + llm: Optional[str] = None + embedder: Optional[str] = None + description: Optional[str] = None env_vars: dict[str, str] providers: Providers From 3ebe8fac06e21526c883e2cf42c04c3de2d4f8f7 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Thu, 18 Sep 2025 11:40:05 +0800 Subject: [PATCH 5/7] update --- wren-ai-service/src/__main__.py | 57 ++++++++++--------- wren-ai-service/src/core/pipeline.py | 14 +++-- .../generation/sql_generation_reasoning.py | 4 -- .../src/pipelines/indexing/db_schema.py | 27 +++++---- .../pipelines/indexing/historical_question.py | 21 ++++--- .../src/pipelines/indexing/instructions.py | 16 ++++-- .../src/pipelines/indexing/project_meta.py | 16 ++++-- .../src/pipelines/indexing/sql_pairs.py | 19 ++++--- .../pipelines/indexing/table_description.py | 19 ++++--- .../retrieval/db_schema_retrieval.py | 9 ++- .../historical_question_retrieval.py | 23 ++++---- .../src/pipelines/retrieval/instructions.py | 25 ++++---- .../retrieval/preprocess_sql_data.py | 2 +- .../src/pipelines/retrieval/sql_functions.py | 11 +++- .../retrieval/sql_pairs_retrieval.py | 25 ++++---- wren-ai-service/src/providers/__init__.py | 1 - wren-ai-service/src/utils.py | 12 ++-- 17 files changed, 181 insertions(+), 120 deletions(-) diff --git a/wren-ai-service/src/__main__.py b/wren-ai-service/src/__main__.py index 4877804fc0..6801313c23 100644 --- a/wren-ai-service/src/__main__.py +++ b/wren-ai-service/src/__main__.py @@ -1,7 +1,7 @@ from contextlib import asynccontextmanager import uvicorn -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import ORJSONResponse, RedirectResponse @@ -173,32 +173,37 @@ def get_configs(): @app.post("/configs") def update_configs(configs_request: Configs): - # override current instantiated_providers - app.state.instantiated_providers["embedder"] = { - f"litellm_embedder.{embedder_provider.alias}": LitellmEmbedderProvider( - **embedder_provider.__dict__ - ) - for embedder_provider in configs_request.providers.embedder - } - app.state.instantiated_providers["llm"] = { - f"litellm_llm.{llm_provider.alias}": LitellmLLMProvider(**llm_provider.__dict__) - for llm_provider in configs_request.providers.llm - } + try: + # override current instantiated_providers + app.state.instantiated_providers["embedder"] = { + f"litellm_embedder.{embedder_provider.alias}": LitellmEmbedderProvider( + **embedder_provider.__dict__ + ) + for embedder_provider in configs_request.providers.embedder + } + app.state.instantiated_providers["llm"] = { + f"litellm_llm.{llm_provider.alias}": LitellmLLMProvider( + **llm_provider.__dict__ + ) + for llm_provider in configs_request.providers.llm + } - print(f"pipe_components: {app.state.pipe_components}") - - # try: - # for payload in pipe_components_request: - # for service in app.state.pipe_components[payload.pipeline_name].get( - # "services", [] - # ): - # service._pipelines[payload.pipeline_name].update_llm_provider( - # app.state.instantiated_providers["llm"][payload.llm_config] - # ) - # except Exception as e: - # raise HTTPException( - # status_code=500, detail=f"Error updating pipe components: {e}" - # ) + # override current pipe_components + for pipe_name, pipe_component in app.state.pipe_components.items(): + if pipe_name in configs_request.pipelines: + pipe_config = configs_request.pipelines[pipe_name] + pipe_component.update(pipe_config) + + # changing llm models and embedding models based on configs_request + for pipeline_name, pipe_config in configs_request.pipelines.items(): + for service in app.state.pipe_components[pipeline_name].get("services", []): + service._pipelines[pipeline_name].update_components( + llm_provider=app.state.instantiated_providers["llm"][ + f"litellm_llm.{pipe_config.llm}" + ] + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error updating configs: {e}") if __name__ == "__main__": diff --git a/wren-ai-service/src/core/pipeline.py b/wren-ai-service/src/core/pipeline.py index aeabf9b150..8bcc5ac6d9 100644 --- a/wren-ai-service/src/core/pipeline.py +++ b/wren-ai-service/src/core/pipeline.py @@ -1,7 +1,7 @@ from abc import ABCMeta, abstractmethod from collections.abc import Mapping from dataclasses import dataclass -from typing import Any, Dict +from typing import Any, Dict, Optional from hamilton.async_driver import AsyncDriver from hamilton.driver import Driver @@ -17,6 +17,7 @@ def __init__(self, pipe: Pipeline | AsyncDriver | Driver): self._description = "" self._llm_provider = None self._embedder_provider = None + self._document_store_provider = None self._components = {} @abstractmethod @@ -26,12 +27,15 @@ def run(self, *args, **kwargs) -> Dict[str, Any]: def _update_components(self) -> dict: ... - def update_llm_provider(self, llm_provider: LLMProvider): + def update_components( + self, + llm_provider: Optional[LLMProvider] = None, + embedder_provider: Optional[EmbedderProvider] = None, + document_store_provider: Optional[DocumentStoreProvider] = None, + ): self._llm_provider = llm_provider - self._components = self._update_components() - - def update_embedder_provider(self, embedder_provider: EmbedderProvider): self._embedder_provider = embedder_provider + self._document_store_provider = document_store_provider self._components = self._update_components() def __str__(self): diff --git a/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py b/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py index 05fa1ba0e6..ed4c57aeb5 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py @@ -124,10 +124,6 @@ def _update_components(self): ), } - def update_llm_provider(self, llm_provider: LLMProvider): - self._llm_provider = llm_provider - self._components = self._update_components() - def _streaming_callback(self, chunk, query_id): if query_id not in self._user_queues: self._user_queues[query_id] = asyncio.Queue() diff --git a/wren-ai-service/src/pipelines/indexing/db_schema.py b/wren-ai-service/src/pipelines/indexing/db_schema.py index cf91207a2e..e0cd12070f 100644 --- a/wren-ai-service/src/pipelines/indexing/db_schema.py +++ b/wren-ai-service/src/pipelines/indexing/db_schema.py @@ -349,19 +349,12 @@ def __init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) - dbschema_store = document_store_provider.get_store() + self._embedder_provider = embedder_provider + self._document_store_provider = document_store_provider + self._dbschema_store = self._document_store_provider.get_store() self._description = description + self._components = self._update_components() - self._components = { - "cleaner": DocumentCleaner([dbschema_store]), - "validator": MDLValidator(), - "embedder": embedder_provider.get_document_embedder(), - "chunker": DDLChunker(), - "writer": AsyncDocumentWriter( - document_store=dbschema_store, - policy=DuplicatePolicy.OVERWRITE, - ), - } self._configs = { "column_batch_size": column_batch_size, } @@ -369,6 +362,18 @@ def __init__( helper.load_helpers() + def _update_components(self): + return { + "cleaner": DocumentCleaner([self._dbschema_store]), + "validator": MDLValidator(), + "embedder": self._embedder_provider.get_document_embedder(), + "chunker": DDLChunker(), + "writer": AsyncDocumentWriter( + document_store=self._dbschema_store, + policy=DuplicatePolicy.OVERWRITE, + ), + } + @observe(name="DB Schema Indexing") async def run( self, mdl_str: str, project_id: Optional[str] = None diff --git a/wren-ai-service/src/pipelines/indexing/historical_question.py b/wren-ai-service/src/pipelines/indexing/historical_question.py index 2e25862cb4..69158f3527 100644 --- a/wren-ai-service/src/pipelines/indexing/historical_question.py +++ b/wren-ai-service/src/pipelines/indexing/historical_question.py @@ -145,21 +145,28 @@ def __init__( ) # keep the store name as it is for now, might change in the future - store = document_store_provider.get_store(dataset_name="view_questions") + self._embedder_provider = embedder_provider + self._document_store_provider = document_store_provider + self._store = self._document_store_provider.get_store( + dataset_name="view_questions" + ) self._description = description + self._components = self._update_components() + + self._configs = {} + self._final = "write" - self._components = { - "cleaner": DocumentCleaner([store]), + def _update_components(self): + return { + "cleaner": DocumentCleaner([self._store]), "validator": MDLValidator(), - "embedder": embedder_provider.get_document_embedder(), + "embedder": self._embedder_provider.get_document_embedder(), "chunker": ViewChunker(), "writer": AsyncDocumentWriter( - document_store=store, + document_store=self._store, policy=DuplicatePolicy.OVERWRITE, ), } - self._configs = {} - self._final = "write" @observe(name="Historical Question Indexing") async def run( diff --git a/wren-ai-service/src/pipelines/indexing/instructions.py b/wren-ai-service/src/pipelines/indexing/instructions.py index 881abc7938..015b661169 100644 --- a/wren-ai-service/src/pipelines/indexing/instructions.py +++ b/wren-ai-service/src/pipelines/indexing/instructions.py @@ -136,15 +136,21 @@ def __init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) - store = document_store_provider.get_store(dataset_name="instructions") + self._embedder_provider = embedder_provider + self._document_store_provider = document_store_provider + self._store = self._document_store_provider.get_store( + dataset_name="instructions" + ) self._description = description + self._components = self._update_components() - self._components = { - "cleaner": InstructionsCleaner(store), - "embedder": embedder_provider.get_document_embedder(), + def _update_components(self): + return { + "cleaner": InstructionsCleaner(self._store), + "embedder": self._embedder_provider.get_document_embedder(), "document_converter": InstructionsConverter(), "writer": AsyncDocumentWriter( - document_store=store, + document_store=self._store, policy=DuplicatePolicy.OVERWRITE, ), } diff --git a/wren-ai-service/src/pipelines/indexing/project_meta.py b/wren-ai-service/src/pipelines/indexing/project_meta.py index ef1e1a6518..0cb0097910 100644 --- a/wren-ai-service/src/pipelines/indexing/project_meta.py +++ b/wren-ai-service/src/pipelines/indexing/project_meta.py @@ -74,18 +74,24 @@ def __init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) - store = document_store_provider.get_store(dataset_name="project_meta") + self._document_store_provider = document_store_provider + self._store = self._document_store_provider.get_store( + dataset_name="project_meta" + ) self._description = description - self._components = { + self._components = self._update_components() + self._final = "write" + + def _update_components(self): + return { "validator": MDLValidator(), - "cleaner": DocumentCleaner([store]), + "cleaner": DocumentCleaner([self._store]), "writer": AsyncDocumentWriter( - document_store=store, + document_store=self._store, policy=DuplicatePolicy.OVERWRITE, ), } - self._final = "write" @observe(name="Project Meta Indexing") async def run( diff --git a/wren-ai-service/src/pipelines/indexing/sql_pairs.py b/wren-ai-service/src/pipelines/indexing/sql_pairs.py index 9e7e4b44d9..c14cc16a25 100644 --- a/wren-ai-service/src/pipelines/indexing/sql_pairs.py +++ b/wren-ai-service/src/pipelines/indexing/sql_pairs.py @@ -176,21 +176,26 @@ def __init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) - store = document_store_provider.get_store(dataset_name="sql_pairs") + self._embedder_provider = embedder_provider + self._document_store_provider = document_store_provider + self._store = self._document_store_provider.get_store(dataset_name="sql_pairs") self._description = description - self._components = { - "cleaner": SqlPairsCleaner(store), - "embedder": embedder_provider.get_document_embedder(), + self._components = self._update_components() + + self._external_pairs = _load_sql_pairs(sql_pairs_path) + + def _update_components(self): + return { + "cleaner": SqlPairsCleaner(self._store), + "embedder": self._embedder_provider.get_document_embedder(), "document_converter": SqlPairsConverter(), "writer": AsyncDocumentWriter( - document_store=store, + document_store=self._store, policy=DuplicatePolicy.OVERWRITE, ), } - self._external_pairs = _load_sql_pairs(sql_pairs_path) - @observe(name="SQL Pairs Indexing") async def run( self, diff --git a/wren-ai-service/src/pipelines/indexing/table_description.py b/wren-ai-service/src/pipelines/indexing/table_description.py index fb0a558fd1..4fc20ea9a1 100644 --- a/wren-ai-service/src/pipelines/indexing/table_description.py +++ b/wren-ai-service/src/pipelines/indexing/table_description.py @@ -127,23 +127,28 @@ def __init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) - table_description_store = document_store_provider.get_store( + self._embedder_provider = embedder_provider + self._document_store_provider = document_store_provider + self._table_description_store = self._document_store_provider.get_store( dataset_name="table_descriptions" ) self._description = description - self._components = { - "cleaner": DocumentCleaner([table_description_store]), + self._components = self._update_components() + self._configs = {} + self._final = "write" + + def _update_components(self): + return { + "cleaner": DocumentCleaner([self._table_description_store]), "validator": MDLValidator(), - "embedder": embedder_provider.get_document_embedder(), + "embedder": self._embedder_provider.get_document_embedder(), "chunker": TableDescriptionChunker(), "writer": AsyncDocumentWriter( - document_store=table_description_store, + document_store=self._table_description_store, policy=DuplicatePolicy.OVERWRITE, ), } - self._configs = {} - self._final = "write" @observe(name="Table Description Indexing") async def run( diff --git a/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py b/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py index c12210b142..ae1508f5fe 100644 --- a/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py +++ b/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py @@ -508,9 +508,12 @@ def _update_components(self): ), } - def update_llm_provider(self, llm_provider: LLMProvider): - self._llm_provider = llm_provider - self._components = self._update_components() + def update_components( + self, llm_provider: LLMProvider, embedder_provider: EmbedderProvider + ): + super().update_components( + llm_provider=llm_provider, embedder_provider=embedder_provider + ) self._configs = self._update_configs() @observe(name="Ask Retrieval") diff --git a/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py b/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py index 65501f5084..2c815cb1e6 100644 --- a/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py +++ b/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py @@ -130,26 +130,29 @@ def __init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) - view_questions_store = document_store_provider.get_store( + self._view_questions_store = document_store_provider.get_store( dataset_name="view_questions" ) + self._embedder_provider = embedder_provider + self._document_store_provider = document_store_provider self._description = description + self._components = self._update_components() + self._configs = { + "historical_question_retrieval_similarity_threshold": historical_question_retrieval_similarity_threshold, + } - self._components = { - "view_questions_store": view_questions_store, - "embedder": embedder_provider.get_text_embedder(), - "view_questions_retriever": document_store_provider.get_retriever( - document_store=view_questions_store, + def _update_components(self): + return { + "view_questions_store": self._view_questions_store, + "embedder": self._embedder_provider.get_text_embedder(), + "view_questions_retriever": self._document_store_provider.get_retriever( + document_store=self._view_questions_store, ), "score_filter": ScoreFilter(), # TODO: add a llm filter to filter out low scoring document, in case ScoreFilter is not accurate enough "output_formatter": OutputFormatter(), } - self._configs = { - "historical_question_retrieval_similarity_threshold": historical_question_retrieval_similarity_threshold, - } - @observe(name="Historical Question") async def run(self, query: str, project_id: Optional[str] = None): logger.info("HistoricalQuestion Retrieval pipeline is running...") diff --git a/wren-ai-service/src/pipelines/retrieval/instructions.py b/wren-ai-service/src/pipelines/retrieval/instructions.py index 0cfe765a2c..b704864b6c 100644 --- a/wren-ai-service/src/pipelines/retrieval/instructions.py +++ b/wren-ai-service/src/pipelines/retrieval/instructions.py @@ -196,23 +196,28 @@ def __init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) - store = document_store_provider.get_store(dataset_name="instructions") + self._embedder_provider = embedder_provider + self._document_store_provider = document_store_provider + self._store = document_store_provider.get_store(dataset_name="instructions") self._description = description - self._components = { - "store": store, - "embedder": embedder_provider.get_text_embedder(), - "retriever": document_store_provider.get_retriever( - document_store=store, + self._components = self._update_components() + self._configs = { + "similarity_threshold": similarity_threshold, + "top_k": top_k, + } + + def _update_components(self): + return { + "store": self._store, + "embedder": self._embedder_provider.get_text_embedder(), + "retriever": self._document_store_provider.get_retriever( + document_store=self._store, ), "scope_filter": ScopeFilter(), "score_filter": ScoreFilter(), "output_formatter": OutputFormatter(), } - self._configs = { - "similarity_threshold": similarity_threshold, - "top_k": top_k, - } @observe(name="Instructions Retrieval") async def run( diff --git a/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py b/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py index 1cae34515d..7b886fe06e 100644 --- a/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py +++ b/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py @@ -101,7 +101,7 @@ def _update_configs(self): "context_window_size": self._llm_provider.context_window_size, } - def update_llm_provider(self, llm_provider: LLMProvider): + def update_components(self, llm_provider: LLMProvider): self._llm_provider = llm_provider self._configs = self._update_configs() diff --git a/wren-ai-service/src/pipelines/retrieval/sql_functions.py b/wren-ai-service/src/pipelines/retrieval/sql_functions.py index b7b5f008bc..593a54ae88 100644 --- a/wren-ai-service/src/pipelines/retrieval/sql_functions.py +++ b/wren-ai-service/src/pipelines/retrieval/sql_functions.py @@ -93,8 +93,9 @@ def __init__( ) self._description = description - self._retriever = document_store_provider.get_retriever( - document_store_provider.get_store("project_meta") + self._document_store_provider = document_store_provider + self._retriever = self._document_store_provider.get_retriever( + self._document_store_provider.get_store("project_meta") ) self._cache = TTLCache(maxsize=100, ttl=ttl) self._components = { @@ -102,6 +103,12 @@ def __init__( "ttl_cache": self._cache, } + def update_components(self, document_store_provider: DocumentStoreProvider): + self._document_store_provider = document_store_provider + self._retriever = self._document_store_provider.get_retriever( + self._document_store_provider.get_store("project_meta") + ) + @observe(name="SQL Functions Retrieval") async def run( self, diff --git a/wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py b/wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py index 57accf57e5..3424ecb9fe 100644 --- a/wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py +++ b/wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py @@ -127,23 +127,28 @@ def __init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) - store = document_store_provider.get_store(dataset_name="sql_pairs") + self._embedder_provider = embedder_provider + self._document_store_provider = document_store_provider + self._store = self._document_store_provider.get_store(dataset_name="sql_pairs") self._description = description + self._components = self._update_components() - self._components = { - "store": store, - "embedder": embedder_provider.get_text_embedder(), - "retriever": document_store_provider.get_retriever( - document_store=store, + self._configs = { + "sql_pairs_similarity_threshold": sql_pairs_similarity_threshold, + "sql_pairs_retrieval_max_size": sql_pairs_retrieval_max_size, + } + + def _update_components(self): + return { + "store": self._store, + "embedder": self._embedder_provider.get_text_embedder(), + "retriever": self._document_store_provider.get_retriever( + document_store=self._store, ), "score_filter": ScoreFilter(), # TODO: add a llm filter to filter out low scoring document, in case ScoreFilter is not accurate enough "output_formatter": OutputFormatter(), } - self._configs = { - "sql_pairs_similarity_threshold": sql_pairs_similarity_threshold, - "sql_pairs_retrieval_max_size": sql_pairs_retrieval_max_size, - } @observe(name="SqlPairs Retrieval") async def run(self, query: str, project_id: Optional[str] = None): diff --git a/wren-ai-service/src/providers/__init__.py b/wren-ai-service/src/providers/__init__.py index 393ac3614b..78ec07a2ef 100644 --- a/wren-ai-service/src/providers/__init__.py +++ b/wren-ai-service/src/providers/__init__.py @@ -378,7 +378,6 @@ def generate_components(configs: list[dict]) -> dict[str, PipelineComponent]: } for type, configs in config.providers.items() } - print(f"instantiated_providers: {instantiated_providers}") def get(type: str, components: dict, instantiated_providers: dict): identifier = components.get(type) diff --git a/wren-ai-service/src/utils.py b/wren-ai-service/src/utils.py index bc797bc0d6..91e9a05795 100644 --- a/wren-ai-service/src/utils.py +++ b/wren-ai-service/src/utils.py @@ -227,20 +227,20 @@ class Providers(BaseModel): class LLMProvider(BaseModel): model: str alias: str - api_base: Optional[str] - api_version: Optional[str] context_window_size: int timeout: float = 600 - kwargs: Optional[dict[str, Any]] + api_base: Optional[str] = None + api_version: Optional[str] = None + kwargs: Optional[dict[str, Any]] = None class EmbedderProvider(BaseModel): model: str alias: str dimension: int - api_base: Optional[str] - api_version: Optional[str] timeout: float = 600 - kwargs: Optional[dict[str, Any]] + kwargs: Optional[dict[str, Any]] = None + api_base: Optional[str] = None + api_version: Optional[str] = None llm: list[LLMProvider] = Field(default_factory=list) embedder: list[EmbedderProvider] = Field(default_factory=list) From 3ee04ea1fb0ec6e1a37b59a704c15711ac1806aa Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Thu, 18 Sep 2025 13:00:33 +0800 Subject: [PATCH 6/7] update --- wren-ai-service/src/__main__.py | 51 ++++++++++++++++--- wren-ai-service/src/core/pipeline.py | 11 ++-- .../retrieval/preprocess_sql_data.py | 2 +- .../src/pipelines/retrieval/sql_functions.py | 2 +- wren-ai-service/src/utils.py | 6 +-- 5 files changed, 55 insertions(+), 17 deletions(-) diff --git a/wren-ai-service/src/__main__.py b/wren-ai-service/src/__main__.py index 6801313c23..488d024e53 100644 --- a/wren-ai-service/src/__main__.py +++ b/wren-ai-service/src/__main__.py @@ -14,6 +14,7 @@ create_service_metadata, ) from src.providers import generate_components +from src.providers.document_store.qdrant import QdrantProvider from src.providers.embedder.litellm import LitellmEmbedderProvider from src.providers.llm.litellm import LitellmLLMProvider from src.utils import ( @@ -187,6 +188,19 @@ def update_configs(configs_request: Configs): ) for llm_provider in configs_request.providers.llm } + app.state.instantiated_providers["document_store"]["qdrant"] = QdrantProvider( + location=app.state.instantiated_providers["document_store"][ + "qdrant" + ]._location, + api_key=app.state.instantiated_providers["document_store"][ + "qdrant" + ]._api_key, + timeout=app.state.instantiated_providers["document_store"][ + "qdrant" + ]._timeout, + embedding_model_dim=configs_request.providers.embedder[0].dimension, + recreate_index=True, + ) # override current pipe_components for pipe_name, pipe_component in app.state.pipe_components.items(): @@ -194,14 +208,35 @@ def update_configs(configs_request: Configs): pipe_config = configs_request.pipelines[pipe_name] pipe_component.update(pipe_config) - # changing llm models and embedding models based on configs_request - for pipeline_name, pipe_config in configs_request.pipelines.items(): - for service in app.state.pipe_components[pipeline_name].get("services", []): - service._pipelines[pipeline_name].update_components( - llm_provider=app.state.instantiated_providers["llm"][ - f"litellm_llm.{pipe_config.llm}" - ] - ) + # updating pipelines + for pipeline_name, pipe_components in app.state.pipe_components.items(): + for service in pipe_components.get("services", []): + if pipe_config := configs_request.pipelines.get(pipeline_name): + service._pipelines[pipeline_name].update_components( + llm_provider=app.state.instantiated_providers["llm"][ + f"litellm_llm.{pipe_config.llm}" + ] + if pipe_config.llm + else None, + embedder_provider=app.state.instantiated_providers["embedder"][ + f"litellm_embedder.{pipe_config.embedder}" + ] + if pipe_config.embedder + else None, + ) + else: + if service._pipelines[pipeline_name]._document_store_provider: + service._pipelines[pipeline_name].update_components( + llm_provider=service._pipelines[ + pipeline_name + ]._llm_provider, + embedder_provider=service._pipelines[ + pipeline_name + ]._embedder_provider, + document_store_provider=app.state.instantiated_providers[ + "document_store" + ]["qdrant"], + ) except Exception as e: raise HTTPException(status_code=500, detail=f"Error updating configs: {e}") diff --git a/wren-ai-service/src/core/pipeline.py b/wren-ai-service/src/core/pipeline.py index 8bcc5ac6d9..64fa37ac49 100644 --- a/wren-ai-service/src/core/pipeline.py +++ b/wren-ai-service/src/core/pipeline.py @@ -33,13 +33,16 @@ def update_components( embedder_provider: Optional[EmbedderProvider] = None, document_store_provider: Optional[DocumentStoreProvider] = None, ): - self._llm_provider = llm_provider - self._embedder_provider = embedder_provider - self._document_store_provider = document_store_provider + if llm_provider: + self._llm_provider = llm_provider + if embedder_provider: + self._embedder_provider = embedder_provider + if document_store_provider: + self._document_store_provider = document_store_provider self._components = self._update_components() def __str__(self): - return f"BasicPipeline(llm_provider={self._llm_provider}, embedder_provider={self._embedder_provider})" + return f"BasicPipeline(llm_provider={self._llm_provider}, embedder_provider={self._embedder_provider}, document_store_provider={self._document_store_provider})" @dataclass diff --git a/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py b/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py index 7b886fe06e..51e1d56c73 100644 --- a/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py +++ b/wren-ai-service/src/pipelines/retrieval/preprocess_sql_data.py @@ -101,7 +101,7 @@ def _update_configs(self): "context_window_size": self._llm_provider.context_window_size, } - def update_components(self, llm_provider: LLMProvider): + def update_components(self, llm_provider: LLMProvider, **_): self._llm_provider = llm_provider self._configs = self._update_configs() diff --git a/wren-ai-service/src/pipelines/retrieval/sql_functions.py b/wren-ai-service/src/pipelines/retrieval/sql_functions.py index 593a54ae88..43abced26a 100644 --- a/wren-ai-service/src/pipelines/retrieval/sql_functions.py +++ b/wren-ai-service/src/pipelines/retrieval/sql_functions.py @@ -103,7 +103,7 @@ def __init__( "ttl_cache": self._cache, } - def update_components(self, document_store_provider: DocumentStoreProvider): + def update_components(self, document_store_provider: DocumentStoreProvider, **_): self._document_store_provider = document_store_provider self._retriever = self._document_store_provider.get_retriever( self._document_store_provider.get_store("project_meta") diff --git a/wren-ai-service/src/utils.py b/wren-ai-service/src/utils.py index 91e9a05795..b27821426e 100644 --- a/wren-ai-service/src/utils.py +++ b/wren-ai-service/src/utils.py @@ -8,7 +8,7 @@ import requests from dotenv import load_dotenv from langfuse.decorators import langfuse_context -from pydantic import BaseModel, Field +from pydantic import BaseModel from src.config import Settings @@ -242,8 +242,8 @@ class EmbedderProvider(BaseModel): api_base: Optional[str] = None api_version: Optional[str] = None - llm: list[LLMProvider] = Field(default_factory=list) - embedder: list[EmbedderProvider] = Field(default_factory=list) + llm: list[LLMProvider] + embedder: list[EmbedderProvider] class Pipeline(BaseModel): has_db_data_in_llm_prompt: bool From 186e3d020e0f2097860cfbbedbb5ccbc7275ae43 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Thu, 18 Sep 2025 14:49:02 +0800 Subject: [PATCH 7/7] update --- wren-ai-service/src/__main__.py | 65 +++++++++++++------ wren-ai-service/src/core/pipeline.py | 4 +- .../generation/followup_sql_generation.py | 21 +++++- .../pipelines/generation/sql_correction.py | 18 ++++- .../pipelines/generation/sql_generation.py | 18 ++++- .../src/pipelines/indexing/db_schema.py | 14 ++++ .../pipelines/indexing/historical_question.py | 16 +++++ .../src/pipelines/indexing/instructions.py | 16 +++++ .../src/pipelines/indexing/project_meta.py | 9 +++ .../src/pipelines/indexing/sql_pairs.py | 14 ++++ .../pipelines/indexing/table_description.py | 16 +++++ .../retrieval/db_schema_retrieval.py | 20 +++++- .../historical_question_retrieval.py | 16 +++++ .../src/pipelines/retrieval/instructions.py | 16 +++++ .../retrieval/sql_pairs_retrieval.py | 14 ++++ .../src/providers/document_store/qdrant.py | 3 +- 16 files changed, 250 insertions(+), 30 deletions(-) diff --git a/wren-ai-service/src/__main__.py b/wren-ai-service/src/__main__.py index 488d024e53..1323d8c5b3 100644 --- a/wren-ai-service/src/__main__.py +++ b/wren-ai-service/src/__main__.py @@ -34,10 +34,13 @@ async def lifespan(app: FastAPI): # startup events pipe_components, instantiated_providers = generate_components(settings.components) + app.state.pipe_components = pipe_components + app.state.instantiated_providers = instantiated_providers app.state.service_container = create_service_container(pipe_components, settings) - app.state.pipe_components = create_pipe_components(app.state.service_container) + app.state.pipe_service_components = create_pipe_components( + app.state.service_container + ) app.state.service_metadata = create_service_metadata(pipe_components) - app.state.instantiated_providers = instantiated_providers init_langfuse(settings) yield @@ -145,7 +148,7 @@ def get_configs(): break _configs["providers"]["embedder"] = _embedder_configs - for pipe_name, pipe_component in app.state.pipe_components.items(): + for pipe_name, pipe_component in app.state.pipe_service_components.items(): llm_model = pipe_component.get("llm", None) embedding_model = pipe_component.get("embedder", None) description = pipe_component.get("description", "") @@ -201,28 +204,48 @@ def update_configs(configs_request: Configs): embedding_model_dim=configs_request.providers.embedder[0].dimension, recreate_index=True, ) + _embedder_providers = app.state.instantiated_providers["embedder"] + _llm_providers = app.state.instantiated_providers["llm"] + _document_store_provider = app.state.instantiated_providers["document_store"][ + "qdrant" + ] # override current pipe_components - for pipe_name, pipe_component in app.state.pipe_components.items(): + for ( + pipe_name, + pipe_service_components, + ) in app.state.pipe_service_components.items(): if pipe_name in configs_request.pipelines: pipe_config = configs_request.pipelines[pipe_name] - pipe_component.update(pipe_config) + pipe_service_components.update(pipe_config) # updating pipelines - for pipeline_name, pipe_components in app.state.pipe_components.items(): - for service in pipe_components.get("services", []): + for ( + pipeline_name, + pipe_service_components, + ) in app.state.pipe_service_components.items(): + for service in pipe_service_components.get("services", []): if pipe_config := configs_request.pipelines.get(pipeline_name): service._pipelines[pipeline_name].update_components( - llm_provider=app.state.instantiated_providers["llm"][ - f"litellm_llm.{pipe_config.llm}" - ] - if pipe_config.llm - else None, - embedder_provider=app.state.instantiated_providers["embedder"][ - f"litellm_embedder.{pipe_config.embedder}" - ] - if pipe_config.embedder - else None, + llm_provider=( + _llm_providers[f"litellm_llm.{pipe_config.llm}"] + if pipe_config.llm + else None + ), + embedder_provider=( + _embedder_providers[ + f"litellm_embedder.{pipe_config.embedder}" + ] + if pipe_config.embedder + else None + ), + document_store_provider=( + _document_store_provider + if service._pipelines[ + pipeline_name + ]._document_store_provider + else None + ), ) else: if service._pipelines[pipeline_name]._document_store_provider: @@ -233,10 +256,12 @@ def update_configs(configs_request: Configs): embedder_provider=service._pipelines[ pipeline_name ]._embedder_provider, - document_store_provider=app.state.instantiated_providers[ - "document_store" - ]["qdrant"], + document_store_provider=_document_store_provider, ) + + # TODO: updating service_metadata + for pipeline_name, _ in app.state.pipe_components.items(): + pass except Exception as e: raise HTTPException(status_code=500, detail=f"Error updating configs: {e}") diff --git a/wren-ai-service/src/core/pipeline.py b/wren-ai-service/src/core/pipeline.py index 64fa37ac49..9946ca5d53 100644 --- a/wren-ai-service/src/core/pipeline.py +++ b/wren-ai-service/src/core/pipeline.py @@ -32,6 +32,7 @@ def update_components( llm_provider: Optional[LLMProvider] = None, embedder_provider: Optional[EmbedderProvider] = None, document_store_provider: Optional[DocumentStoreProvider] = None, + update_components: bool = True, ): if llm_provider: self._llm_provider = llm_provider @@ -39,7 +40,8 @@ def update_components( self._embedder_provider = embedder_provider if document_store_provider: self._document_store_provider = document_store_provider - self._components = self._update_components() + if update_components: + self._components = self._update_components() def __str__(self): return f"BasicPipeline(llm_provider={self._llm_provider}, embedder_provider={self._embedder_provider}, document_store_provider={self._document_store_provider})" diff --git a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py index 947944a663..7cd076bb8a 100644 --- a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py @@ -161,14 +161,31 @@ def __init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) - self._retriever = document_store_provider.get_retriever( - document_store_provider.get_store("project_meta") + self._document_store_provider = document_store_provider + self._retriever = self._document_store_provider.get_retriever( + self._document_store_provider.get_store("project_meta") ) self._llm_provider = llm_provider self._engine = engine self._description = description self._components = self._update_components() + def update_components( + self, + llm_provider: LLMProvider, + document_store_provider: DocumentStoreProvider, + **_, + ): + super().update_components( + llm_provider=llm_provider, + document_store_provider=document_store_provider, + update_components=False, + ) + self._retriever = self._document_store_provider.get_retriever( + self._document_store_provider.get_store("project_meta") + ) + self._components = self._update_components() + def _update_components(self): return { "generator": self._llm_provider.get_generator( diff --git a/wren-ai-service/src/pipelines/generation/sql_correction.py b/wren-ai-service/src/pipelines/generation/sql_correction.py index aa1f611056..5affc3192d 100644 --- a/wren-ai-service/src/pipelines/generation/sql_correction.py +++ b/wren-ai-service/src/pipelines/generation/sql_correction.py @@ -142,8 +142,22 @@ def __init__( self._engine = engine self._description = description self._components = self._update_components() - self._retriever = document_store_provider.get_retriever( - document_store_provider.get_store("project_meta") + self._document_store_provider = document_store_provider + self._retriever = self._document_store_provider.get_retriever( + self._document_store_provider.get_store("project_meta") + ) + + def update_components( + self, + llm_provider: LLMProvider, + document_store_provider: DocumentStoreProvider, + **_, + ): + super().update_components( + llm_provider=llm_provider, document_store_provider=document_store_provider + ) + self._retriever = self._document_store_provider.get_retriever( + self._document_store_provider.get_store("project_meta") ) def _update_components(self): diff --git a/wren-ai-service/src/pipelines/generation/sql_generation.py b/wren-ai-service/src/pipelines/generation/sql_generation.py index 1327375fc0..9073388866 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation.py @@ -158,14 +158,28 @@ def __init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) - self._retriever = document_store_provider.get_retriever( - document_store_provider.get_store("project_meta") + self._document_store_provider = document_store_provider + self._retriever = self._document_store_provider.get_retriever( + self._document_store_provider.get_store("project_meta") ) self._llm_provider = llm_provider self._engine = engine self._description = description self._components = self._update_components() + def update_components( + self, + llm_provider: LLMProvider, + document_store_provider: DocumentStoreProvider, + **_, + ): + super().update_components( + llm_provider=llm_provider, document_store_provider=document_store_provider + ) + self._retriever = self._document_store_provider.get_retriever( + self._document_store_provider.get_store("project_meta") + ) + def _update_components(self): return { "generator": self._llm_provider.get_generator( diff --git a/wren-ai-service/src/pipelines/indexing/db_schema.py b/wren-ai-service/src/pipelines/indexing/db_schema.py index e0cd12070f..f4a9218805 100644 --- a/wren-ai-service/src/pipelines/indexing/db_schema.py +++ b/wren-ai-service/src/pipelines/indexing/db_schema.py @@ -362,6 +362,20 @@ def __init__( helper.load_helpers() + def update_components( + self, + embedder_provider: EmbedderProvider, + document_store_provider: DocumentStoreProvider, + **_, + ): + super().update_components( + embedder_provider=embedder_provider, + document_store_provider=document_store_provider, + update_components=False, + ) + self._dbschema_store = self._document_store_provider.get_store() + self._components = self._update_components() + def _update_components(self): return { "cleaner": DocumentCleaner([self._dbschema_store]), diff --git a/wren-ai-service/src/pipelines/indexing/historical_question.py b/wren-ai-service/src/pipelines/indexing/historical_question.py index 69158f3527..a5888a3c76 100644 --- a/wren-ai-service/src/pipelines/indexing/historical_question.py +++ b/wren-ai-service/src/pipelines/indexing/historical_question.py @@ -156,6 +156,22 @@ def __init__( self._configs = {} self._final = "write" + def update_components( + self, + embedder_provider: EmbedderProvider, + document_store_provider: DocumentStoreProvider, + **_, + ): + super().update_components( + embedder_provider=embedder_provider, + document_store_provider=document_store_provider, + update_components=False, + ) + self._store = self._document_store_provider.get_store( + dataset_name="view_questions" + ) + self._components = self._update_components() + def _update_components(self): return { "cleaner": DocumentCleaner([self._store]), diff --git a/wren-ai-service/src/pipelines/indexing/instructions.py b/wren-ai-service/src/pipelines/indexing/instructions.py index 015b661169..337efce174 100644 --- a/wren-ai-service/src/pipelines/indexing/instructions.py +++ b/wren-ai-service/src/pipelines/indexing/instructions.py @@ -144,6 +144,22 @@ def __init__( self._description = description self._components = self._update_components() + def update_components( + self, + embedder_provider: EmbedderProvider, + document_store_provider: DocumentStoreProvider, + **_, + ): + super().update_components( + embedder_provider=embedder_provider, + document_store_provider=document_store_provider, + update_components=False, + ) + self._store = self._document_store_provider.get_store( + dataset_name="instructions" + ) + self._components = self._update_components() + def _update_components(self): return { "cleaner": InstructionsCleaner(self._store), diff --git a/wren-ai-service/src/pipelines/indexing/project_meta.py b/wren-ai-service/src/pipelines/indexing/project_meta.py index 0cb0097910..0ec516df15 100644 --- a/wren-ai-service/src/pipelines/indexing/project_meta.py +++ b/wren-ai-service/src/pipelines/indexing/project_meta.py @@ -83,6 +83,15 @@ def __init__( self._components = self._update_components() self._final = "write" + def update_components(self, document_store_provider: DocumentStoreProvider, **_): + super().update_components( + document_store_provider=document_store_provider, update_components=False + ) + self._store = self._document_store_provider.get_store( + dataset_name="project_meta" + ) + self._components = self._update_components() + def _update_components(self): return { "validator": MDLValidator(), diff --git a/wren-ai-service/src/pipelines/indexing/sql_pairs.py b/wren-ai-service/src/pipelines/indexing/sql_pairs.py index c14cc16a25..a2c6d3e7ca 100644 --- a/wren-ai-service/src/pipelines/indexing/sql_pairs.py +++ b/wren-ai-service/src/pipelines/indexing/sql_pairs.py @@ -185,6 +185,20 @@ def __init__( self._external_pairs = _load_sql_pairs(sql_pairs_path) + def update_components( + self, + embedder_provider: EmbedderProvider, + document_store_provider: DocumentStoreProvider, + **_, + ): + super().update_components( + embedder_provider=embedder_provider, + document_store_provider=document_store_provider, + update_components=False, + ) + self._store = self._document_store_provider.get_store(dataset_name="sql_pairs") + self._components = self._update_components() + def _update_components(self): return { "cleaner": SqlPairsCleaner(self._store), diff --git a/wren-ai-service/src/pipelines/indexing/table_description.py b/wren-ai-service/src/pipelines/indexing/table_description.py index 4fc20ea9a1..69c542d517 100644 --- a/wren-ai-service/src/pipelines/indexing/table_description.py +++ b/wren-ai-service/src/pipelines/indexing/table_description.py @@ -138,6 +138,22 @@ def __init__( self._configs = {} self._final = "write" + def update_components( + self, + embedder_provider: EmbedderProvider, + document_store_provider: DocumentStoreProvider, + **_, + ): + super().update_components( + embedder_provider=embedder_provider, + document_store_provider=document_store_provider, + update_components=False, + ) + self._table_description_store = self._document_store_provider.get_store( + dataset_name="table_descriptions" + ) + self._components = self._update_components() + def _update_components(self): return { "cleaner": DocumentCleaner([self._table_description_store]), diff --git a/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py b/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py index ae1508f5fe..7c881490c2 100644 --- a/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py +++ b/wren-ai-service/src/pipelines/retrieval/db_schema_retrieval.py @@ -509,11 +509,27 @@ def _update_components(self): } def update_components( - self, llm_provider: LLMProvider, embedder_provider: EmbedderProvider + self, + llm_provider: LLMProvider, + embedder_provider: EmbedderProvider, + document_store_provider: DocumentStoreProvider, + **_, ): super().update_components( - llm_provider=llm_provider, embedder_provider=embedder_provider + llm_provider=llm_provider, + embedder_provider=embedder_provider, + document_store_provider=document_store_provider, + update_components=False, ) + self._table_retriever = self._document_store_provider.get_retriever( + self._document_store_provider.get_store(dataset_name="table_descriptions"), + top_k=self._table_retrieval_size, + ) + self._dbschema_retriever = self._document_store_provider.get_retriever( + self._document_store_provider.get_store(), + top_k=self._table_column_retrieval_size, + ) + self._components = self._update_components() self._configs = self._update_configs() @observe(name="Ask Retrieval") diff --git a/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py b/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py index 2c815cb1e6..1ba78ce5f5 100644 --- a/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py +++ b/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py @@ -141,6 +141,22 @@ def __init__( "historical_question_retrieval_similarity_threshold": historical_question_retrieval_similarity_threshold, } + def update_components( + self, + embedder_provider: EmbedderProvider, + document_store_provider: DocumentStoreProvider, + **_, + ): + super().update_components( + embedder_provider=embedder_provider, + document_store_provider=document_store_provider, + update_components=False, + ) + self._view_questions_store = self._document_store_provider.get_store( + dataset_name="view_questions" + ) + self._components = self._update_components() + def _update_components(self): return { "view_questions_store": self._view_questions_store, diff --git a/wren-ai-service/src/pipelines/retrieval/instructions.py b/wren-ai-service/src/pipelines/retrieval/instructions.py index b704864b6c..8b0cca7a23 100644 --- a/wren-ai-service/src/pipelines/retrieval/instructions.py +++ b/wren-ai-service/src/pipelines/retrieval/instructions.py @@ -207,6 +207,22 @@ def __init__( "top_k": top_k, } + def update_components( + self, + embedder_provider: EmbedderProvider, + document_store_provider: DocumentStoreProvider, + **_, + ): + super().update_components( + embedder_provider=embedder_provider, + document_store_provider=document_store_provider, + update_components=False, + ) + self._store = self._document_store_provider.get_store( + dataset_name="instructions" + ) + self._components = self._update_components() + def _update_components(self): return { "store": self._store, diff --git a/wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py b/wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py index 3424ecb9fe..7bcf71061e 100644 --- a/wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py +++ b/wren-ai-service/src/pipelines/retrieval/sql_pairs_retrieval.py @@ -138,6 +138,20 @@ def __init__( "sql_pairs_retrieval_max_size": sql_pairs_retrieval_max_size, } + def update_components( + self, + embedder_provider: EmbedderProvider, + document_store_provider: DocumentStoreProvider, + **_, + ): + super().update_components( + embedder_provider=embedder_provider, + document_store_provider=document_store_provider, + update_components=False, + ) + self._store = self._document_store_provider.get_store(dataset_name="sql_pairs") + self._components = self._update_components() + def _update_components(self): return { "store": self._store, diff --git a/wren-ai-service/src/providers/document_store/qdrant.py b/wren-ai-service/src/providers/document_store/qdrant.py index b90961c456..165ade3ba1 100644 --- a/wren-ai-service/src/providers/document_store/qdrant.py +++ b/wren-ai-service/src/providers/document_store/qdrant.py @@ -389,7 +389,8 @@ def __init__( self._api_key = Secret.from_token(api_key) if api_key else None self._timeout = timeout self._embedding_model_dim = embedding_model_dim - self._reset_document_store(recreate_index) + if recreate_index: + self._reset_document_store(recreate_index) def _reset_document_store(self, recreate_index: bool): self.get_store(recreate_index=recreate_index)