Skip to content

Commit 391bb65

Browse files
authored
fix(oso_agent): fixes vector store initialization (#5434)
* fix(oso_agent): fixes vector store initialization * fixes * fix: disable some unmaintained tests * more fixes
1 parent 92c1eb5 commit 391bb65

File tree

13 files changed

+1782
-1596
lines changed

13 files changed

+1782
-1596
lines changed
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
dependencies:
22
- name: dagster
33
repository: https://dagster-io.github.io/helm
4-
version: 1.10.4
5-
digest: sha256:e666c3d9872bcd0d659bb91903f841cc9cdcf1a9a56ba9aad0a0ffa8e7a7e327
6-
generated: "2025-03-10T21:02:26.932985+01:00"
4+
version: 1.11.15
5+
digest: sha256:6a7f57b097653371f57d47f95ea1c846e285b7692ede562cecf99dc98c114378
6+
generated: "2025-10-20T21:37:00.11332175Z"

ops/helm-charts/oso-dagster/Chart.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ name: oso-dagster
33
description: Extension of the dagster template
44

55
type: application
6-
version: 0.25.0
6+
version: 0.26.0
77
appVersion: "1.0.0"
88
dependencies:
99
- name: dagster
10-
version: "1.10.4"
10+
version: "1.11.15"
1111
repository: "https://dagster-io.github.io/helm"
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# OSO Dagster Custom Helm Chart
2+
3+
To update the Dagster version used in this chart, modify the `version` field under `dependencies` in `Chart.yaml` to the desired Dagster Helm chart version and then run
4+
5+
```bash
6+
helm dependency update
7+
```

uv.lock

Lines changed: 1714 additions & 1465 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

warehouse/metrics-service/metrics_service/test_cache.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414

1515
@pytest.mark.asyncio
16+
@pytest.mark.skip(reason="unmaintained test")
1617
async def test_cache_export_manager():
1718
adapter_mock = AsyncMock(FakeExportAdapter)
1819
adapter_mock.export_table.return_value = ExportReference(
@@ -56,6 +57,7 @@ class TestException(Exception):
5657

5758

5859
@pytest.mark.asyncio
60+
@pytest.mark.skip(reason="unmaintained test")
5961
async def test_cache_export_manager_fails():
6062
adapter_mock = AsyncMock(FakeExportAdapter)
6163
adapter_mock.export_table = AsyncMock(side_effect=TestException("test"))

warehouse/oso_agent/oso_agent/agent/basic_agent.py

Lines changed: 0 additions & 99 deletions
This file was deleted.

warehouse/oso_agent/oso_agent/cli/commands.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def initialize_vector_store(config: AgentConfig):
145145
storage_context=storage_context,
146146
oso_client=oso_client,
147147
embed_model=embed,
148+
show_progress=True,
148149
)
149150
)
150151

warehouse/oso_agent/oso_agent/tool/embedding.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
import logging
32

43
from llama_index.embeddings.google_genai import GoogleGenAIEmbedding
@@ -9,13 +8,16 @@
98

109
logger = logging.getLogger(__name__)
1110

11+
1212
def create_embedding(config: AgentConfig):
1313
"""Setup the embedding model depending on the configuration"""
1414
match config.llm:
1515
case LocalLLMConfig(
1616
ollama_embedding=embedding, ollama_url=base_url, ollama_timeout=timeout
1717
):
18-
logger.info(f"Initializing Ollama embedding model {config.llm.ollama_model}")
18+
logger.info(
19+
f"Initializing Ollama embedding model {config.llm.ollama_model}"
20+
)
1921
return OllamaEmbedding(
2022
model_name=embedding,
2123
base_url=base_url,
@@ -26,7 +28,10 @@ def create_embedding(config: AgentConfig):
2628
return GoogleGenAIEmbedding(
2729
api_key=api_key,
2830
model_name=embedding,
29-
embed_batch_size=100,
31+
embed_batch_size=30,
32+
retries=10,
33+
retry_min_seconds=30,
34+
retry_max_seconds=300,
3035
)
3136
case _:
3237
raise AgentConfigError(f"Unsupported LLM type: {config.llm.type}")

warehouse/oso_agent/oso_agent/tool/oso_text2sql.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,27 @@ async def index_oso_tables(
7474
embed_model: BaseEmbedding,
7575
tables_to_index: dict[str, list[str]] | None = None,
7676
include_tables: list[str] | None = None,
77+
insert_batch_size: int = 500,
78+
show_progress: bool = False,
7779
) -> VectorStoreIndex:
7880
"""Index the given tables into a vector store index. Tables are separated by
7981
adding the table name to the metadata of the nodes.
8082
8183
This is not intended to be run every time the agent runs. It should be
8284
called as a preprocessing step using the cli command `index-oso-tables`.
85+
86+
Args:
87+
config: The agent configuration.
88+
storage_context: The storage context to use for the index.
89+
oso_client: The Oso client to use to access the database.
90+
embed_model: The embedding model to use for the index.
91+
tables_to_index: A mapping of table names to the list of columns to index.
92+
include_tables: A list of tables to include in the OsoSqlDatabase.
93+
insert_batch_size: The batch size to use when inserting nodes into the
94+
vector store. For google's vector store, this should be less than
95+
5000 though it seems that it's also about the file size of the resulant
96+
embedding json. At this time, 500 seems to be a safe bet.
97+
show_progress: Whether to show a progress bar when inserting nodes.
8398
"""
8499

85100
tables_to_index = tables_to_index or DEFAULT_TABLES_TO_INDEX
@@ -147,7 +162,8 @@ async def index_oso_tables(
147162
embed_model=embed_model,
148163
storage_context=storage_context,
149164
is_complete_overwrite=True,
150-
insert_batch_size=100000,
165+
insert_batch_size=insert_batch_size,
166+
show_progress=show_progress,
151167
)
152168
return index
153169
# vector_store.add(nodes)

warehouse/oso_agent/oso_agent/types/response.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,30 @@
11
import typing as t
22

33
from llama_index.core.workflow import Context
4-
from llama_index.core.workflow.handler import WorkflowHandler
54
from oso_semantic.definition import SemanticQuery
65
from pydantic import BaseModel, Field
6+
from workflows.handler import WorkflowHandler
77

88
from .sql_query import SqlQuery
99

1010

1111
class ErrorResponse(BaseModel):
1212
type: t.Literal["error"] = "error"
1313

14-
message: str = Field(
15-
description="Error message from the agent."
16-
)
14+
message: str = Field(description="Error message from the agent.")
1715

1816
details: str = Field(
19-
default="",
20-
description="Optional details about the error, if available."
17+
default="", description="Optional details about the error, if available."
2118
)
2219

2320
def __str__(self) -> str:
2421
"""Return the string representation of the error response."""
25-
return f"Error: {self.message} | Details: {self.details}" if self.details else f"Error: {self.message}"
22+
return (
23+
f"Error: {self.message} | Details: {self.details}"
24+
if self.details
25+
else f"Error: {self.message}"
26+
)
27+
2628

2729
class StrResponse(BaseModel):
2830
type: t.Literal["str"] = "str"
@@ -35,6 +37,7 @@ def __str__(self) -> str:
3537
"""Return the string representation of the response."""
3638
return self.blob
3739

40+
3841
class AnyResponse(BaseModel):
3942
type: t.Literal["any"] = "any"
4043

@@ -45,14 +48,16 @@ class AnyResponse(BaseModel):
4548
def __str__(self):
4649
return str(self.raw)
4750

51+
4852
class SemanticResponse(BaseModel):
4953
type: t.Literal["semantic"] = "semantic"
5054

5155
query: SemanticQuery
5256

5357
def __str__(self):
5458
return self.query.model_dump_json()
55-
59+
60+
5661
class SqlResponse(BaseModel):
5762
type: t.Literal["sql"] = "sql"
5863

@@ -61,16 +66,15 @@ class SqlResponse(BaseModel):
6166
def __str__(self):
6267
return self.query.query
6368

69+
6470
ResponseType = t.Union[
65-
StrResponse,
66-
SemanticResponse,
67-
SqlResponse,
68-
ErrorResponse,
69-
AnyResponse
71+
StrResponse, SemanticResponse, SqlResponse, ErrorResponse, AnyResponse
7072
]
7173

74+
7275
class WrappedResponse:
7376
"""A wrapper for the response from an agent"""
77+
7478
_response: ResponseType
7579
_handler: WorkflowHandler | None
7680

@@ -84,7 +88,7 @@ def ctx(self) -> Context:
8488
assert self._handler is not None, "Workflow handler is not set."
8589
assert self._handler.ctx is not None, "Workflow handler context is not set."
8690
return self._handler.ctx
87-
91+
8892
@property
8993
def response(self) -> ResponseType:
9094
"""Get the response from the agent."""

0 commit comments

Comments
 (0)