Skip to content

Commit 8fd1468

Browse files
committed
feat(oso_dagster): asset caching for sqlmesh related assets
1 parent 18cd4f4 commit 8fd1468

File tree

6 files changed

+331
-84
lines changed

6 files changed

+331
-84
lines changed

warehouse/oso_dagster/assets/sqlmesh.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
11
import copy
22

33
import dagster as dg
4-
from dagster import AssetExecutionContext, AssetSelection, RunConfig, define_asset_job
4+
from dagster import (
5+
AssetExecutionContext,
6+
AssetSelection,
7+
ResourceParam,
8+
RunConfig,
9+
define_asset_job,
10+
)
511
from dagster_sqlmesh import (
12+
DagsterSQLMeshCacheOptions,
613
PlanOptions,
714
SQLMeshContextConfig,
815
SQLMeshDagsterTranslator,
916
SQLMeshResource,
1017
sqlmesh_assets,
1118
)
19+
from oso_dagster.config import DagsterConfig
1220
from oso_dagster.factories.common import AssetFactoryResponse
1321
from oso_dagster.resources.trino import TrinoResource
1422
from oso_dagster.utils.asynctools import multiple_async_contexts
@@ -60,29 +68,38 @@ class SQLMeshRunConfig(dg.Config):
6068
@early_resources_asset_factory()
6169
def sqlmesh_factory(
6270
sqlmesh_infra_config: dict,
63-
sqlmesh_config: SQLMeshContextConfig,
71+
sqlmesh_context_config: SQLMeshContextConfig,
6472
sqlmesh_translator: SQLMeshDagsterTranslator,
73+
sqlmesh_cache_options: DagsterSQLMeshCacheOptions,
6574
):
6675
dev_environment = sqlmesh_infra_config["dev_environment"]
6776
environment = sqlmesh_infra_config["environment"]
6877

78+
print(sqlmesh_cache_options)
79+
6980
@sqlmesh_assets(
70-
config=sqlmesh_config,
81+
config=sqlmesh_context_config,
7182
environment=environment,
7283
dagster_sqlmesh_translator=sqlmesh_translator,
7384
enabled_subsetting=False,
7485
op_tags=op_tags,
86+
cache_options=sqlmesh_cache_options,
7587
)
7688
async def sqlmesh_project(
7789
context: AssetExecutionContext,
90+
global_config: ResourceParam[DagsterConfig],
7891
sqlmesh: SQLMeshResource,
7992
trino: TrinoResource,
8093
config: SQLMeshRunConfig,
8194
):
8295
restate_models = config.restate_models[:] if config.restate_models else []
8396

84-
async with multiple_async_contexts(
85-
trino=trino.ensure_available(log_override=context.log),
97+
# We use this helper function so we can run sqlmesh both locally and in
98+
# a k8s environment
99+
def run_sqlmesh(
100+
context: AssetExecutionContext,
101+
sqlmesh: SQLMeshResource,
102+
config: SQLMeshRunConfig,
86103
):
87104
# If restate_by_entity_category is True, dynamically identify models based on entity categories
88105
if config.restate_by_entity_category:
@@ -133,6 +150,18 @@ async def sqlmesh_project(
133150
):
134151
yield result
135152

153+
# Trino can either be `local-trino` or `trino`
154+
if "trino" in global_config.sqlmesh_gateway:
155+
async with multiple_async_contexts(
156+
trino=trino.ensure_available(log_override=context.log),
157+
):
158+
for result in run_sqlmesh(context, sqlmesh, config):
159+
yield result
160+
else:
161+
# If we are not running trino we are using duckdb
162+
for result in run_sqlmesh(context, sqlmesh, config):
163+
yield result
164+
136165
all_assets_selection = AssetSelection.assets(sqlmesh_project)
137166

138167
return AssetFactoryResponse(
Lines changed: 84 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,110 @@
1-
import logging
21
import typing as t
2+
from pathlib import Path
3+
from traceback import print_exception
34

4-
from dagster import AssetKey, ResourceParam
5-
from dagster_sqlmesh import DagsterSQLMeshController, SQLMeshContextConfig
5+
import structlog
6+
from dagster import AssetKey, AssetsDefinition, ResourceParam
7+
from dagster_sqlmesh import (
8+
DagsterSQLMeshCacheOptions,
9+
DagsterSQLMeshController,
10+
SQLMeshContextConfig,
11+
)
612
from dagster_sqlmesh.controller.base import DEFAULT_CONTEXT_FACTORY
13+
from oso_dagster.resources.sqlmesh import SQLMeshExportedAssetDefinition
714
from sqlmesh.core.model import Model
815

916
from ..factories import AssetFactoryResponse, early_resources_asset_factory
1017
from ..resources import PrefixedSQLMeshTranslator, SQLMeshExporter
1118

12-
logger = logging.getLogger(__name__)
19+
logger = structlog.get_logger(__name__)
1320

1421

1522
@early_resources_asset_factory()
1623
def sqlmesh_export_factory(
1724
sqlmesh_infra_config: dict,
18-
sqlmesh_config: SQLMeshContextConfig,
25+
sqlmesh_context_config: SQLMeshContextConfig,
1926
sqlmesh_translator: PrefixedSQLMeshTranslator,
2027
sqlmesh_exporters: ResourceParam[t.List[SQLMeshExporter]],
28+
sqlmesh_cache_options: DagsterSQLMeshCacheOptions,
2129
):
2230
environment = sqlmesh_infra_config["environment"]
2331

2432
controller = DagsterSQLMeshController.setup_with_config(
25-
config=sqlmesh_config,
33+
config=sqlmesh_context_config,
2634
context_factory=DEFAULT_CONTEXT_FACTORY,
2735
)
28-
assets = []
29-
30-
with controller.instance(environment) as mesh:
31-
models = mesh.models()
32-
models_to_export: t.List[t.Tuple[Model, AssetKey]] = []
33-
for name, model in models.items():
34-
if "export" not in model.tags:
35-
continue
36-
models_to_export.append(
37-
(
38-
model,
39-
sqlmesh_translator.get_asset_key(mesh.context, model.fqn),
40-
)
41-
)
36+
assets: list[AssetsDefinition] = []
37+
38+
missing_exporters = set()
39+
40+
# Hack for now to cache all of the sqlmesh export assets
41+
if sqlmesh_cache_options.enabled:
42+
logger.info("SQLMesh export cache is enabled, caching assets")
43+
44+
cache_dir = Path(sqlmesh_cache_options.cache_dir)
45+
exporter_assets_cache_dir = cache_dir.joinpath("sqlmesh_export_assets")
46+
# Ensure the cache directory exists
47+
exporter_assets_cache_dir.mkdir(parents=True, exist_ok=True)
4248

43-
# Create a export assets for this
4449
for exporter in sqlmesh_exporters:
45-
asset_def = exporter.create_export_asset(
46-
mesh,
47-
sqlmesh_translator,
48-
to_export=models_to_export,
50+
# Load the definition from the exporter
51+
cache_file = exporter_assets_cache_dir / f"{exporter.name()}.json"
52+
if cache_file.exists():
53+
logger.debug(f"Loading cached asset definition for {exporter.name()}")
54+
exporter_asset_def = SQLMeshExportedAssetDefinition.model_validate_json(
55+
cache_file.read_text()
56+
)
57+
assets.append(exporter.asset_from_definition(exporter_asset_def))
58+
else:
59+
missing_exporters.add(exporter.name())
60+
else:
61+
missing_exporters = set(exporter.name() for exporter in sqlmesh_exporters)
62+
63+
if len(missing_exporters) > 0:
64+
try:
65+
with controller.instance(environment) as mesh:
66+
models = mesh.models()
67+
models_to_export: t.List[t.Tuple[Model, AssetKey]] = []
68+
for name, model in models.items():
69+
if "export" not in model.tags:
70+
continue
71+
models_to_export.append(
72+
(
73+
model,
74+
sqlmesh_translator.get_asset_key(mesh.context, model.fqn),
75+
)
76+
)
77+
78+
# Create a export assets for this
79+
for exporter in sqlmesh_exporters:
80+
asset_def = exporter.create_export_asset(
81+
mesh,
82+
sqlmesh_translator,
83+
to_export=models_to_export,
84+
)
85+
if sqlmesh_cache_options.enabled:
86+
cache_dir = Path(sqlmesh_cache_options.cache_dir)
87+
exporter_assets_cache_dir = cache_dir.joinpath(
88+
"sqlmesh_export_assets"
89+
)
90+
91+
# Save the asset definition to the cache
92+
cache_file = (
93+
exporter_assets_cache_dir / f"{exporter.name()}.json"
94+
)
95+
logger.debug(
96+
f"Caching asset definition for {exporter.name()} at {cache_file}"
97+
)
98+
cache_file.write_text(asset_def.model_dump_json())
99+
100+
assets.append(exporter.asset_from_definition(asset_def))
101+
logger.debug(f"exporting for {exporter.name()}")
102+
except Exception as e:
103+
logger.exception(
104+
"Failed to create SQLMesh export assets",
105+
error=str(e),
49106
)
50-
assets.append(asset_def)
51-
logger.debug(f"exporting for {exporter.__class__.__name__}")
107+
print_exception(e)
108+
raise e
52109

53110
return AssetFactoryResponse(assets=assets)

warehouse/oso_dagster/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ class DagsterConfig(BaseSettings):
6060
sqlmesh_catalog: str = "iceberg"
6161
sqlmesh_schema: str = "oso"
6262
sqlmesh_bq_export_dataset_id: str = "oso"
63+
sqlmesh_dagster_asset_cache_enabled: bool = False
64+
sqlmesh_dagster_asset_cache_dir: str = ""
65+
sqlmesh_dagster_asset_enable_ttl: bool = True
66+
sqlmesh_dagster_asset_ttl_seconds: int = 60 * 5
6367

6468
enable_k8s_executor: bool = False
6569

warehouse/oso_dagster/definitions/legacy.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
import logging
22
import typing as t
33

4+
from dotenv import load_dotenv
45
from oso_core.logging import setup_module_logging
56
from oso_core.logging.decorators import time_function
67

8+
load_dotenv()
9+
710
logger = logging.getLogger(__name__)
811

912
setup_module_logging("oso_dagster")
13+
setup_module_logging("dagster_sqlmesh")
1014

1115

1216
@time_function(logger, override_name="legacy_main")
@@ -68,7 +72,7 @@ def load_definitions(
6872
dlt_warehouse_destination: Destination,
6973
dlt: DagsterDltResource,
7074
alert_manager: AlertManager,
71-
sqlmesh_config: SQLMeshContextConfig,
75+
sqlmesh_context_config: SQLMeshContextConfig,
7276
sqlmesh_infra_config: t.Dict[str, str],
7377
sqlmesh: SQLMeshResource,
7478
k8s: K8sResource,
@@ -118,7 +122,7 @@ def load_definitions(
118122
"dlt_warehouse_destination": dlt_warehouse_destination,
119123
"project_id": global_config.project_id,
120124
"alert_manager": alert_manager,
121-
"sqlmesh_config": sqlmesh_config,
125+
"sqlmesh_config": sqlmesh_context_config,
122126
"sqlmesh_infra_config": sqlmesh_infra_config,
123127
"k8s": k8s,
124128
"trino": trino,

warehouse/oso_dagster/definitions/resources.py

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1+
import os
12
import typing as t
23

34
import structlog
45
from dagster import ConfigurableIOManagerFactory
56
from dagster_dlt import DagsterDltResource
67
from dagster_gcp import BigQueryResource, GCSResource
7-
from dagster_sqlmesh import SQLMeshContextConfig, SQLMeshResource
8+
from dagster_sqlmesh import (
9+
DagsterSQLMeshCacheOptions,
10+
SQLMeshContextConfig,
11+
SQLMeshResource,
12+
)
813
from dlt.common.destination import Destination
914
from oso_core.logging.decorators import time_function
1015
from oso_dagster.cbt.cbt import CBTResource
@@ -41,6 +46,7 @@
4146
LogAlertManager,
4247
)
4348
from oso_dagster.utils.secrets import SecretResolver
49+
from sqlmesh.core.config.connection import DuckDBConnectionConfig, TrinoConnectionConfig
4450

4551
from ..config import DagsterConfig
4652
from ..utils import GCPSecretResolver, LocalSecretResolver
@@ -103,15 +109,55 @@ def gcs_resource_factory(global_config: DagsterConfig) -> GCSResource:
103109
return GCSResource(project=global_config.project_id)
104110

105111

106-
@resource_factory("sqlmesh_translator")
112+
@resource_factory("sqlmesh_cache_options")
107113
@time_function(logger)
108-
def sqlmesh_translator_factory():
109-
return PrefixedSQLMeshTranslator("sqlmesh")
114+
def sqlmesh_cache_options_factory(
115+
global_config: DagsterConfig,
116+
) -> DagsterSQLMeshCacheOptions:
117+
"""Factory function to create SQLMesh cache options."""
118+
119+
return DagsterSQLMeshCacheOptions(
120+
enabled=global_config.sqlmesh_dagster_asset_cache_enabled,
121+
cache_dir=global_config.sqlmesh_dagster_asset_cache_dir,
122+
enable_ttl=global_config.sqlmesh_dagster_asset_enable_ttl,
123+
ttl_seconds=global_config.sqlmesh_dagster_asset_ttl_seconds,
124+
)
110125

111126

112-
@resource_factory("sqlmesh_config")
127+
@resource_factory("sqlmesh_translator")
128+
@time_function(logger)
129+
def sqlmesh_translator_factory(
130+
global_config: DagsterConfig, sqlmesh_context_config: SQLMeshContextConfig
131+
) -> PrefixedSQLMeshTranslator:
132+
sqlmesh_config = sqlmesh_context_config.sqlmesh_config
133+
if sqlmesh_config is None:
134+
raise ValueError("SQLMesh configuration is not set in the context config.")
135+
gateway = sqlmesh_config.gateways[global_config.sqlmesh_gateway]
136+
connection = gateway.connection
137+
if connection is None:
138+
raise ValueError("SQLMesh gateway connection is not set in the context config.")
139+
default_catalog = ""
140+
match connection:
141+
case TrinoConnectionConfig(catalog=catalog):
142+
default_catalog = catalog
143+
case DuckDBConnectionConfig(database=database):
144+
if not database:
145+
raise ValueError(
146+
"DuckDB database path is not set in the context config."
147+
)
148+
default_catalog = os.path.basename(database).split(".")[0]
149+
case _:
150+
raise ValueError(
151+
f"Unsupported SQLMesh connection type: {type(connection).__name__}"
152+
)
153+
return PrefixedSQLMeshTranslator("sqlmesh", default_catalog)
154+
155+
156+
@resource_factory("sqlmesh_context_config")
113157
@time_function(logger)
114-
def sqlmesh_config_factory(global_config: DagsterConfig) -> SQLMeshContextConfig:
158+
def sqlmesh_context_config_factory(
159+
global_config: DagsterConfig,
160+
) -> SQLMeshContextConfig:
115161
return SQLMeshContextConfig(
116162
path=global_config.sqlmesh_dir,
117163
gateway=global_config.sqlmesh_gateway,
@@ -145,11 +191,11 @@ def trino_resource_factory(
145191

146192
@resource_factory("sqlmesh")
147193
def sqlmesh_resource_factory(
148-
sqlmesh_config: SQLMeshContextConfig,
194+
sqlmesh_context_config: SQLMeshContextConfig,
149195
) -> SQLMeshResource:
150196
"""Factory function to create a SQLMesh resource."""
151197
return SQLMeshResource(
152-
config=sqlmesh_config,
198+
config=sqlmesh_context_config,
153199
)
154200

155201

@@ -347,7 +393,8 @@ def default_resource_registry():
347393
registry.add(gcs_resource_factory)
348394
registry.add(sqlmesh_resource_factory)
349395
registry.add(sqlmesh_translator_factory)
350-
registry.add(sqlmesh_config_factory)
396+
registry.add(sqlmesh_context_config_factory)
397+
registry.add(sqlmesh_cache_options_factory)
351398
registry.add(k8s_resource_factory)
352399
registry.add(trino_resource_factory)
353400
registry.add(sqlmesh_exporter_factory)

0 commit comments

Comments
 (0)