Skip to content
Merged
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ from dagster import (
AssetExecutionContext,
Definitions,
)
from dagster_sqlmesh import sqlmesh_assets, SQLMeshContextConfig, SQLMeshResource
from dagster_sqlmesh import sqlmesh_assets, SQLMeshContextConfig, SQLMeshResource, SQLMeshDagsterTranslator

sqlmesh_config = SQLMeshContextConfig(path="/home/foo/sqlmesh_project", gateway="name-of-your-gateway")

@sqlmesh_assets(environment="dev", config=sqlmesh_config)
@sqlmesh_assets(environment="dev", config=sqlmesh_config, translator=SQLMeshDagsterTranslator())
def sqlmesh_project(context: AssetExecutionContext, sqlmesh: SQLMeshResource):
yield from sqlmesh.run(context)

Expand Down
10 changes: 4 additions & 6 deletions dagster_sqlmesh/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@
from dagster import AssetsDefinition, RetryPolicy, multi_asset
from sqlmesh import Context

from dagster_sqlmesh.config import SQLMeshContextConfig
from dagster_sqlmesh.controller import (
ContextCls,
ContextFactory,
DagsterSQLMeshController,
)
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator

from .config import SQLMeshContextConfig

logger = logging.getLogger(__name__)


Expand All @@ -23,7 +22,7 @@ def sqlmesh_assets(
config: SQLMeshContextConfig,
context_factory: ContextFactory[ContextCls] = lambda **kwargs: Context(**kwargs),
name: str | None = None,
dagster_sqlmesh_translator: SQLMeshDagsterTranslator | None = None,
dagster_sqlmesh_translator: SQLMeshDagsterTranslator = SQLMeshDagsterTranslator(),
compute_kind: str = "sqlmesh",
op_tags: t.Mapping[str, t.Any] | None = None,
required_resource_keys: set[str] | None = None,
Expand All @@ -32,9 +31,8 @@ def sqlmesh_assets(
enabled_subsetting: bool = False,
) -> t.Callable[[t.Callable[..., t.Any]], AssetsDefinition]:
controller = DagsterSQLMeshController.setup_with_config(config=config, context_factory=context_factory)
if not dagster_sqlmesh_translator:
dagster_sqlmesh_translator = SQLMeshDagsterTranslator()
conversion = controller.to_asset_outs(environment, dagster_sqlmesh_translator)

conversion = controller.to_asset_outs(environment, translator=dagster_sqlmesh_translator)

return multi_asset(
name=name,
Expand Down
2 changes: 1 addition & 1 deletion dagster_sqlmesh/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ class SQLMeshContextConfig(Config):
def sqlmesh_config(self) -> MeshConfig | None:
if self.config_override:
return MeshConfig.parse_obj(self.config_override)
return None
return None
6 changes: 3 additions & 3 deletions dagster_sqlmesh/controller/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@
from sqlmesh.utils.dag import DAG
from sqlmesh.utils.date import TimeLike

from ..config import SQLMeshContextConfig
from ..console import (
from dagster_sqlmesh.config import SQLMeshContextConfig
from dagster_sqlmesh.console import (
ConsoleEvent,
ConsoleEventHandler,
ConsoleException,
EventConsole,
Plan,
SnapshotCategorizer,
)
from ..events import ConsoleGenerator
from dagster_sqlmesh.events import ConsoleGenerator

logger = logging.getLogger(__name__)

Expand Down
41 changes: 24 additions & 17 deletions dagster_sqlmesh/controller/dagster.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# pyright: reportPrivateImportUsage=false
import logging
from inspect import signature

from dagster import AssetDep, AssetKey, AssetOut
from dagster._core.definitions.asset_dep import CoercibleToAssetDep

from ..translator import SQLMeshDagsterTranslator
from ..types import SQLMeshModelDep, SQLMeshMultiAssetOptions
from ..utils import sqlmesh_model_name_to_key
from .base import ContextCls, SQLMeshController
from dagster_sqlmesh.controller.base import ContextCls, SQLMeshController
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
from dagster_sqlmesh.types import SQLMeshModelDep, SQLMeshMultiAssetOptions
from dagster_sqlmesh.utils import get_asset_key_str

logger = logging.getLogger(__name__)

Expand All @@ -16,18 +17,15 @@ class DagsterSQLMeshController(SQLMeshController[ContextCls]):
"""An extension of the sqlmesh controller specifically for dagster use"""

def to_asset_outs(
self, environment: str, translator: SQLMeshDagsterTranslator
self, environment: str, translator: SQLMeshDagsterTranslator,
) -> SQLMeshMultiAssetOptions:
with self.instance(environment, "to_asset_outs") as instance:
context = instance.context
output = SQLMeshMultiAssetOptions()
depsMap: dict[str, CoercibleToAssetDep] = {}

for model, deps in instance.non_external_models_dag():
asset_key = translator.get_asset_key_from_model(
context,
model,
)
asset_key = translator.get_asset_key(context=context, fqn=model.fqn)
model_deps = [
SQLMeshModelDep(fqn=dep, model=context.get_model(dep))
for dep in deps
Expand All @@ -38,18 +36,27 @@ def to_asset_outs(
for dep in model_deps:
if dep.model:
internal_asset_deps.add(
translator.get_asset_key_from_model(context, dep.model)
translator.get_asset_key(context, dep.model.fqn)
)
else:
table = translator.get_fqn_to_table(context, dep.fqn)
key = translator.get_asset_key_fqn(context, dep.fqn)
table = get_asset_key_str(dep.fqn)
key = translator.get_asset_key(context, dep.fqn)
internal_asset_deps.add(key)
# create an external dep
depsMap[table.name] = AssetDep(key)
model_key = sqlmesh_model_name_to_key(model.name)
output.outs[model_key] = AssetOut(
key=asset_key, tags=asset_tags, is_required=False
)
depsMap[table] = AssetDep(key)
model_key = get_asset_key_str(model.fqn)
# If current Dagster supports "kinds", add labels for Dagster UI
if "kinds" in signature(AssetOut).parameters:
output.outs[model_key] = AssetOut(
key=asset_key, tags=asset_tags, is_required=False,
group_name=translator.get_group_name(context, model),
kinds={"sqlmesh", translator._get_context_dialect(context).lower()}
)
else:
output.outs[model_key] = AssetOut(
key=asset_key, tags=asset_tags, is_required=False,
group_name=translator.get_group_name(context, model)
)
output.internal_asset_deps[model_key] = internal_asset_deps

output.deps = list(depsMap.values())
Expand Down
35 changes: 21 additions & 14 deletions dagster_sqlmesh/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,16 @@
from sqlmesh.utils.date import TimeLike
from sqlmesh.utils.errors import SQLMeshError

from dagster_sqlmesh import console
from dagster_sqlmesh.config import SQLMeshContextConfig
from dagster_sqlmesh.controller import PlanOptions, RunOptions
from dagster_sqlmesh.controller.base import (
DEFAULT_CONTEXT_FACTORY,
ContextCls,
ContextFactory,
)

from . import console
from .config import SQLMeshContextConfig
from .controller import PlanOptions, RunOptions
from .controller.dagster import DagsterSQLMeshController
from .utils import sqlmesh_model_name_to_key
from dagster_sqlmesh.controller.dagster import DagsterSQLMeshController
from dagster_sqlmesh.utils import get_asset_key_str


class MaterializationTracker:
Expand Down Expand Up @@ -147,7 +146,7 @@ def __init__(
self._prefix = prefix
self._context = context
self._logger = context.log
self._tracker = MaterializationTracker(dag.sorted[:], self._logger)
self._tracker = MaterializationTracker(sorted_dag=dag.sorted[:], logger=self._logger)
self._stage = "plan"
self._errors: list[Exception] = []
self._is_testing = is_testing
Expand All @@ -173,7 +172,8 @@ def notify_success(
# We allow selecting models. That value is mapped to models_map.
# If the model is not in models_map, we can skip any notification
if model:
output_key = sqlmesh_model_name_to_key(model.name)
# Passing model.fqn to get internal unique asset key
output_key = get_asset_key_str(model.fqn)
if not self._is_testing:
# Stupidly dagster when testing cannot use the following
# method so we must specifically skip this when testing
Expand Down Expand Up @@ -227,7 +227,7 @@ def report_event(self, event: console.ConsoleEvent) -> None:
log_context.info(
"Snapshot progress update",
{
"asset_key": sqlmesh_model_name_to_key(snapshot.model.name),
"asset_key": get_asset_key_str(snapshot.model.name),
"progress": f"{done}/{expected}",
"duration_ms": duration_ms,
},
Expand Down Expand Up @@ -327,7 +327,10 @@ def run(

logger = context.log

controller = self.get_controller(context_factory, logger)
controller = self.get_controller(
context_factory=context_factory,
log_override=logger
)

with controller.instance(environment) as mesh:
dag = mesh.models_dag()
Expand All @@ -338,7 +341,10 @@ def run(
[model.fqn for model, _ in mesh.non_external_models_dag()]
)
selected_models_set, models_map, select_models = (
self._get_selected_models_from_context(context, models)
self._get_selected_models_from_context(
context=context,
models=models
)
)

if all_available_models == selected_models_set or select_models is None:
Expand All @@ -351,7 +357,8 @@ def run(
logger.info(f"selected models: {select_models}")

event_handler = DagsterSQLMeshEventHandler(
context, models_map, dag, "sqlmesh: ", is_testing=self.is_testing
context=context, models_map=models_map, dag=dag,
prefix="sqlmesh: ", is_testing=self.is_testing
)

try:
Expand Down Expand Up @@ -397,7 +404,7 @@ def _get_selected_models_from_context(
select_models: list[str] = []
models_map = {}
for key, model in models.items():
if sqlmesh_model_name_to_key(model.name) in selected_output_names:
if get_asset_key_str(model.fqn) in selected_output_names:
models_map[key] = model
select_models.append(model.name)
return (
Expand All @@ -414,5 +421,5 @@ def get_controller(
return DagsterSQLMeshController.setup_with_config(
config=self.config,
context_factory=context_factory,
log_override=log_override,
log_override=log_override
)
2 changes: 1 addition & 1 deletion dagster_sqlmesh/test_sqlmesh_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import polars

from .testing import SQLMeshTestContext
from dagster_sqlmesh.testing import SQLMeshTestContext

logger = logging.getLogger(__name__)

Expand Down
27 changes: 14 additions & 13 deletions dagster_sqlmesh/translator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Sequence

import sqlglot
from dagster import AssetKey
from sqlglot import exp
from sqlmesh.core.context import Context
Expand All @@ -9,19 +9,20 @@
class SQLMeshDagsterTranslator:
"""Translates sqlmesh objects for dagster"""

def get_asset_key_from_model(self, context: Context, model: Model) -> AssetKey:
def get_asset_key(self, context: Context, fqn: str) -> AssetKey:
"""Given the sqlmesh context and a model return the asset key"""
return AssetKey(model.view_name)

def get_asset_key_fqn(self, context: Context, fqn: str) -> AssetKey:
"""Given the sqlmesh context and a fqn of a model return an asset key"""
table = self.get_fqn_to_table(context, fqn)
return AssetKey(table.name)

def get_fqn_to_table(self, context: Context, fqn: str) -> exp.Table:
"""Given the sqlmesh context and a fqn return the table"""
dialect = self._get_context_dialect(context)
return sqlglot.to_table(fqn, dialect=dialect)
path = self.get_asset_key_name(fqn)
return AssetKey(path)

def get_asset_key_name(self, fqn: str) -> Sequence[str]:
table = exp.to_table(fqn)
asset_key_name = [table.catalog, table.db, table.name]

return asset_key_name

def get_group_name(self, context: Context, model: Model) -> str:
path = self.get_asset_key_name(model.fqn)
return path[-2]

def _get_context_dialect(self, context: Context) -> str:
return context.engine_adapter.dialect
Expand Down
16 changes: 9 additions & 7 deletions dagster_sqlmesh/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from sqlglot import exp
from sqlmesh.core.snapshot import SnapshotId


def sqlmesh_model_name_to_key(name: str) -> str:
return name.replace(".", "_dot__")


def key_to_sqlmesh_model_name(key: str) -> str:
return key.replace("_dot__", ".")

def get_asset_key_str(fqn: str) -> str:
# This is an internal identifier used to map outputs and dependencies
# it will not affect the existing AssetKeys
# Only alphanumeric characters and underscores
table = exp.to_table(fqn)
asset_key_name = [table.catalog, table.db, table.name]

return "sqlmesh__" + "_".join(asset_key_name)

def snapshot_id_to_model_name(snapshot_id: SnapshotId) -> str:
"""Convert a SnapshotId to its model name.
Expand Down
4 changes: 2 additions & 2 deletions sample/dagster_project/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
sqlmesh_config = SQLMeshContextConfig(path=SQLMESH_PROJECT_PATH, gateway="local")


@asset
@asset(key=["db", "sources", "reset_asset"])
def reset_asset() -> MaterializeResult:
"""An asset used for testing this entire workflow. If the duckdb database is
found, this will delete it. This allows us to continously test this dag if
Expand All @@ -34,7 +34,7 @@ def reset_asset() -> MaterializeResult:
return MaterializeResult(metadata={"deleted": deleted})


@asset(deps=[reset_asset])
@asset(deps=[reset_asset], key=["db", "sources", "test_source"])
def test_source() -> pl.DataFrame:
"""Sets up the `test_source` table in duckdb that one of the sample sqlmesh
models depends on"""
Expand Down