Skip to content

Commit 111ab2b

Browse files
authored
Improve sqlmesh as lib tooling (#4)
* Upgrade to latest sqlmesh
* WIP
* Refactor sqlmesh controller
* bump
* remove signals
* fix
* update
* remove type keyword
* Remove unused context.py
1 parent bfd24c6 commit 111ab2b

File tree

14 files changed

+2180
-2105
lines changed

14 files changed

+2180
-2105
lines changed

dagster_sqlmesh/asset.py

Lines changed: 10 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -1,115 +1,31 @@
1-
from dataclasses import dataclass, field
2-
from typing import (
3-
Union,
4-
Iterable,
5-
Dict,
6-
Mapping,
7-
Set,
8-
Any,
9-
Optional,
10-
)
1+
import typing as t
112
import logging
123

13-
import sqlglot
14-
from sqlglot import exp
15-
from sqlmesh.core.context import Context
16-
from sqlmesh.core.console import Console
17-
from sqlmesh.core.model import Model
184
from dagster import (
19-
AssetDep,
205
multi_asset,
21-
AssetCheckResult,
22-
AssetMaterialization,
23-
AssetOut,
24-
AssetKey,
256
RetryPolicy,
267
)
27-
from dagster._core.definitions.asset_dep import CoercibleToAssetDep
8+
9+
from dagster_sqlmesh.controller.dagster import DagsterSQLMeshController
10+
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
2811

2912
from .config import SQLMeshContextConfig
30-
from .console import EventConsole, ConsoleEventHandler, DebugEventConsole
31-
from .utils import sqlmesh_model_name_to_key
32-
from .context import DagsterSQLMeshContext
3313

3414
logger = logging.getLogger(__name__)
3515

3616

37-
MultiAssetResponse = Iterable[Union[AssetCheckResult, AssetMaterialization]]
38-
39-
40-
@dataclass(kw_only=True)
41-
class SQLMeshParsedFQN:
42-
catalog: str
43-
schema: str
44-
view_name: str
45-
46-
47-
def parse_fqn(fqn: str):
48-
split_fqn = fqn.split(".")
49-
50-
# Remove any quotes
51-
split_fqn = list(map(lambda a: a.strip("'\""), split_fqn))
52-
return SQLMeshParsedFQN(
53-
catalog=split_fqn[0], schema=split_fqn[1], view_name=split_fqn[2]
54-
)
55-
56-
57-
@dataclass(kw_only=True)
58-
class SQLMeshModelDep:
59-
fqn: str
60-
model: Optional[Model] = None
61-
62-
def parse_fqn(self):
63-
return parse_fqn(self.fqn)
64-
65-
66-
@dataclass(kw_only=True)
67-
class SQLMeshMultiAssetOptions:
68-
outs: Dict[str, AssetOut] = field(default_factory=lambda: {})
69-
deps: Iterable[CoercibleToAssetDep] = field(default_factory=lambda: {})
70-
internal_asset_deps: Dict[str, Set[AssetKey]] = field(default_factory=lambda: {})
71-
72-
73-
class SQLMeshDagsterTranslator:
74-
def get_asset_key_from_model(self, context: Context, model: Model) -> AssetKey:
75-
return AssetKey(model.view_name)
76-
77-
def get_asset_key_fqn(self, context: Context, fqn: str) -> AssetKey:
78-
table = self.get_fqn_to_table(context, fqn)
79-
return AssetKey(table.name)
80-
81-
def get_fqn_to_table(self, context: Context, fqn: str) -> exp.Table:
82-
dialect = self.get_context_dialect(context)
83-
return sqlglot.to_table(fqn, dialect=dialect)
84-
85-
def get_context_dialect(self, context: Context) -> str:
86-
return context.engine_adapter.dialect
87-
88-
# def get_asset_deps(
89-
# self, context: Context, model: Model, deps: List[SQLMeshModelDep]
90-
# ) -> List[AssetKey]:
91-
# asset_keys: List[AssetKey] = []
92-
# for dep in deps:
93-
# if dep.model:
94-
# asset_keys.append(AssetKey(dep.model.view_name))
95-
# else:
96-
# parsed_fqn = dep.parse_fqn()
97-
# asset_keys.append(AssetKey([parsed_fqn.view_name]))
98-
# return asset_keys
99-
100-
10117
# Define a SQLMesh Asset
10218
def sqlmesh_assets(
10319
*,
10420
config: SQLMeshContextConfig,
105-
name: Optional[str] = None,
106-
dagster_sqlmesh_translator: Optional[SQLMeshDagsterTranslator] = None,
21+
name: t.Optional[str] = None,
22+
dagster_sqlmesh_translator: t.Optional[SQLMeshDagsterTranslator] = None,
10723
compute_kind: str = "sqlmesh",
108-
op_tags: Optional[Mapping[str, Any]] = None,
109-
required_resource_keys: Optional[Set[str]] = None,
110-
retry_policy: Optional[RetryPolicy] = None,
24+
op_tags: t.Optional[t.Mapping[str, t.Any]] = None,
25+
required_resource_keys: t.Optional[t.Set[str]] = None,
26+
retry_policy: t.Optional[RetryPolicy] = None,
11127
):
112-
controller = setup_sqlmesh_controller(config)
28+
controller = DagsterSQLMeshController.setup(config)
11329
if not dagster_sqlmesh_translator:
11430
dagster_sqlmesh_translator = SQLMeshDagsterTranslator()
11531
conversion = controller.to_asset_outs(dagster_sqlmesh_translator)
@@ -123,79 +39,4 @@ def sqlmesh_assets(
12339
compute_kind=compute_kind,
12440
retry_policy=retry_policy,
12541
required_resource_keys=required_resource_keys,
126-
# can_subset=True,
127-
)
128-
129-
130-
@dataclass
131-
class SQLMeshController:
132-
console: EventConsole
133-
context: DagsterSQLMeshContext
134-
135-
def add_event_handler(self, handler: ConsoleEventHandler):
136-
return self.console.add_handler(handler)
137-
138-
def remove_event_handler(self, handler_id: str):
139-
return self.console.remove_handler(handler_id)
140-
141-
def to_asset_outs(
142-
self, translator: SQLMeshDagsterTranslator
143-
) -> SQLMeshMultiAssetOptions:
144-
context = self.context
145-
dag = context.dag
146-
output = SQLMeshMultiAssetOptions()
147-
depsMap: Dict[str, CoercibleToAssetDep] = {}
148-
149-
for model_fqn, deps in dag.graph.items():
150-
logger.debug(f"model found: {model_fqn}")
151-
model = context.get_model(model_fqn)
152-
if not model:
153-
# If no model is returned this seems to be an asset dependency
154-
continue
155-
asset_out = translator.get_asset_key_from_model(
156-
context,
157-
model,
158-
)
159-
model_deps = [
160-
SQLMeshModelDep(fqn=dep, model=context.get_model(dep)) for dep in deps
161-
]
162-
internal_asset_deps: Set[AssetKey] = set()
163-
for dep in model_deps:
164-
if dep.model:
165-
internal_asset_deps.add(
166-
translator.get_asset_key_from_model(context, dep.model)
167-
)
168-
else:
169-
table = translator.get_fqn_to_table(context, dep.fqn)
170-
key = translator.get_asset_key_fqn(context, dep.fqn)
171-
internal_asset_deps.add(key)
172-
# create an external dep
173-
depsMap[table.name] = AssetDep(key)
174-
model_key = sqlmesh_model_name_to_key(model.name)
175-
output.outs[model_key] = AssetOut(key=asset_out, is_required=False)
176-
output.internal_asset_deps[model_key] = internal_asset_deps
177-
178-
output.deps = list(depsMap.values())
179-
return output
180-
181-
182-
def setup_sqlmesh_controller(
183-
config: SQLMeshContextConfig,
184-
debug_console: Optional[Console] = None,
185-
log_override: Optional[logging.Logger] = None,
186-
):
187-
console = EventConsole(log_override=log_override)
188-
if debug_console:
189-
console = DebugEventConsole(debug_console)
190-
options: Dict[str, Any] = dict(
191-
paths=config.path,
192-
gateway=config.gateway,
193-
console=console,
194-
)
195-
if config.sqlmesh_config:
196-
options["config"] = config.sqlmesh_config
197-
context = DagsterSQLMeshContext(**options)
198-
return SQLMeshController(
199-
console=console,
200-
context=context,
20142
)

dagster_sqlmesh/conftest.py

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@
44
import shutil
55
import os
66
from dataclasses import dataclass
7-
from typing import cast, List, Optional, Any, Dict
7+
import typing as t
88

99
import pytest
1010
import duckdb
1111
import polars
1212
from sqlmesh.utils.date import TimeLike
13-
from sqlmesh.core.plan.builder import PlanBuilder
1413
from sqlmesh.core.console import get_console
1514
from sqlmesh.core.config import (
1615
Config as SQLMeshConfig,
@@ -20,8 +19,9 @@
2019
)
2120

2221
from dagster_sqlmesh.config import SQLMeshContextConfig
23-
from dagster_sqlmesh.events import ConsoleRecorder, show_plan_summary
24-
from dagster_sqlmesh.asset import setup_sqlmesh_controller
22+
from dagster_sqlmesh.events import ConsoleRecorder
23+
from dagster_sqlmesh.controller.base import PlanOptions, RunOptions
24+
from dagster_sqlmesh.controller.dagster import DagsterSQLMeshController
2525

2626
logger = logging.getLogger(__name__)
2727

@@ -58,7 +58,9 @@ def create_controller(self, enable_debug_console: bool = False):
5858
console = None
5959
if enable_debug_console:
6060
console = get_console()
61-
return setup_sqlmesh_controller(self.context_config, debug_console=console)
61+
return DagsterSQLMeshController.setup(
62+
self.context_config, debug_console=console
63+
)
6264

6365
def query(self, *args, **kwargs):
6466
conn = duckdb.connect(self.db_path)
@@ -99,21 +101,19 @@ def run(
99101
*,
100102
environment: str,
101103
apply: bool = False,
102-
execution_time: Optional[TimeLike] = None,
104+
execution_time: t.Optional[TimeLike] = None,
103105
enable_debug_console: bool = False,
104-
start: Optional[TimeLike] = None,
105-
end: Optional[TimeLike] = None,
106-
restate_models: Optional[List[str]] = None,
106+
start: t.Optional[TimeLike] = None,
107+
end: t.Optional[TimeLike] = None,
108+
restate_models: t.Optional[t.List[str]] = None,
107109
):
108110
controller = self.create_controller(enable_debug_console=enable_debug_console)
109-
controller.add_event_handler(ConsoleRecorder())
110-
plan_options: Dict[str, Any] = dict(
111-
environment=environment,
111+
recorder = ConsoleRecorder()
112+
# controller.add_event_handler(ConsoleRecorder())
113+
plan_options = PlanOptions(
112114
enable_preview=True,
113115
)
114-
run_options: Dict[str, Any] = dict(
115-
environment=environment,
116-
)
116+
run_options = RunOptions()
117117
if execution_time:
118118
plan_options["execution_time"] = execution_time
119119
run_options["execution_time"] = execution_time
@@ -126,19 +126,12 @@ def run(
126126
plan_options["end"] = end
127127
run_options["end"] = end
128128

129-
builder = cast(
130-
PlanBuilder,
131-
controller.context.plan_builder(**plan_options),
132-
)
133-
if apply:
134-
logger.debug("making plan")
135-
plan = builder.build()
136-
show_plan_summary(logger, plan, lambda x: x.is_model)
137-
logger.debug("applying plan")
138-
controller.context.apply(plan)
139-
logger.debug("running through the scheduler")
140-
controller.context.run(**run_options)
141-
controller.context.close()
129+
for _context, event in controller.plan_and_run(
130+
environment,
131+
plan_options=plan_options,
132+
run_options=run_options,
133+
):
134+
recorder(event)
142135

143136

144137
@pytest.fixture

0 commit comments

Comments
 (0)