diff --git a/dagster_sqlmesh/__init__.py b/dagster_sqlmesh/__init__.py index 0cdaaea..286d39d 100644 --- a/dagster_sqlmesh/__init__.py +++ b/dagster_sqlmesh/__init__.py @@ -2,4 +2,6 @@ from .asset import * from .config import * +from .controller import * from .resource import * +from .translator import * diff --git a/dagster_sqlmesh/asset.py b/dagster_sqlmesh/asset.py index c697b1b..cee12e9 100644 --- a/dagster_sqlmesh/asset.py +++ b/dagster_sqlmesh/asset.py @@ -11,9 +11,61 @@ DagsterSQLMeshController, ) from dagster_sqlmesh.translator import SQLMeshDagsterTranslator +from dagster_sqlmesh.types import SQLMeshMultiAssetOptions logger = logging.getLogger(__name__) +def sqlmesh_to_multi_asset_options( + *, + environment: str, + config: SQLMeshContextConfig, + context_factory: ContextFactory[ContextCls] = lambda **kwargs: Context(**kwargs), + dagster_sqlmesh_translator: SQLMeshDagsterTranslator | None = None, +) -> SQLMeshMultiAssetOptions: + """Converts sqlmesh project into a SQLMeshMultiAssetOptions object which is + an intermediate representation of the SQLMesh project that can be used to + create a dagster multi_asset definition.""" + controller = DagsterSQLMeshController.setup_with_config( + config=config, context_factory=context_factory + ) + if not dagster_sqlmesh_translator: + dagster_sqlmesh_translator = SQLMeshDagsterTranslator() + + conversion = controller.to_asset_outs( + environment, + translator=dagster_sqlmesh_translator, + ) + return conversion + +def sqlmesh_asset_from_multi_asset_options( + *, + sqlmesh_multi_asset_options: SQLMeshMultiAssetOptions, + name: str | None = None, + compute_kind: str = "sqlmesh", + op_tags: t.Mapping[str, t.Any] | None = None, + required_resource_keys: set[str] | None = None, + retry_policy: RetryPolicy | None = None, + enabled_subsetting: bool = False, +) -> t.Callable[[t.Callable[..., t.Any]], AssetsDefinition]: + """Creates a dagster multi_asset definition from a SQLMeshMultiAssetOptions object.""" + kwargs: dict[str, 
t.Any] = {} + if enabled_subsetting: + kwargs["can_subset"] = True + + #asset_deps = sqlmesh_multi_asset_options.to_asset_deps() + #print("Asset deps boop:", asset_deps) # Debugging line + + return multi_asset( + outs=sqlmesh_multi_asset_options.to_asset_outs(), + deps=sqlmesh_multi_asset_options.to_asset_deps(), + internal_asset_deps=sqlmesh_multi_asset_options.to_internal_asset_deps(), + name=name, + compute_kind=compute_kind, + op_tags=op_tags, + required_resource_keys=required_resource_keys, + retry_policy=retry_policy, + **kwargs, + ) # Define a SQLMesh Asset def sqlmesh_assets( @@ -30,19 +82,19 @@ def sqlmesh_assets( # For now we don't set this by default enabled_subsetting: bool = False, ) -> t.Callable[[t.Callable[..., t.Any]], AssetsDefinition]: - controller = DagsterSQLMeshController.setup_with_config(config=config, context_factory=context_factory) - if not dagster_sqlmesh_translator: - dagster_sqlmesh_translator = SQLMeshDagsterTranslator() - conversion = controller.to_asset_outs(environment, translator=dagster_sqlmesh_translator) - - return multi_asset( + conversion = sqlmesh_to_multi_asset_options( + environment=environment, + config=config, + context_factory=context_factory, + dagster_sqlmesh_translator=dagster_sqlmesh_translator, + ) + + return sqlmesh_asset_from_multi_asset_options( + sqlmesh_multi_asset_options=conversion, name=name, - outs=conversion.outs, - deps=conversion.deps, - internal_asset_deps=conversion.internal_asset_deps, - op_tags=op_tags, compute_kind=compute_kind, - retry_policy=retry_policy, - can_subset=enabled_subsetting, + op_tags=op_tags, required_resource_keys=required_resource_keys, + retry_policy=retry_policy, + enabled_subsetting=enabled_subsetting, ) diff --git a/dagster_sqlmesh/config.py b/dagster_sqlmesh/config.py index 61431b3..7d60b2d 100644 --- a/dagster_sqlmesh/config.py +++ b/dagster_sqlmesh/config.py @@ -1,9 +1,11 @@ from dataclasses import dataclass +from pathlib import Path from typing import Any from dagster 
import Config from pydantic import Field from sqlmesh.core.config import Config as MeshConfig +from sqlmesh.core.config.loader import load_configs @dataclass @@ -27,7 +29,11 @@ class SQLMeshContextConfig(Config): config_override: dict[str, Any] | None = Field(default_factory=lambda: None) @property - def sqlmesh_config(self) -> MeshConfig | None: + def sqlmesh_config(self) -> MeshConfig: if self.config_override: return MeshConfig.parse_obj(self.config_override) - return None \ No newline at end of file + sqlmesh_path = Path(self.path) + configs = load_configs(None, MeshConfig, [sqlmesh_path]) + if sqlmesh_path not in configs: + raise ValueError(f"SQLMesh configuration not found at {sqlmesh_path}") + return configs[sqlmesh_path] \ No newline at end of file diff --git a/dagster_sqlmesh/controller/dagster.py b/dagster_sqlmesh/controller/dagster.py index bc858fc..dffff4d 100644 --- a/dagster_sqlmesh/controller/dagster.py +++ b/dagster_sqlmesh/controller/dagster.py @@ -1,13 +1,17 @@ # pyright: reportPrivateImportUsage=false import logging -from inspect import signature -from dagster import AssetDep, AssetKey, AssetOut -from dagster._core.definitions.asset_dep import CoercibleToAssetDep - -from dagster_sqlmesh.controller.base import ContextCls, SQLMeshController +from dagster_sqlmesh.controller.base import ( + ContextCls, + SQLMeshController, +) from dagster_sqlmesh.translator import SQLMeshDagsterTranslator -from dagster_sqlmesh.types import SQLMeshModelDep, SQLMeshMultiAssetOptions +from dagster_sqlmesh.types import ( + ConvertibleToAssetDep, + ConvertibleToAssetOut, + SQLMeshModelDep, + SQLMeshMultiAssetOptions, +) from dagster_sqlmesh.utils import get_asset_key_str logger = logging.getLogger(__name__) @@ -17,47 +21,65 @@ class DagsterSQLMeshController(SQLMeshController[ContextCls]): """An extension of the sqlmesh controller specifically for dagster use""" def to_asset_outs( - self, environment: str, translator: SQLMeshDagsterTranslator, + self, + environment: str, + 
+ translator: SQLMeshDagsterTranslator, ) -> SQLMeshMultiAssetOptions: + """Loads all the asset outs of the current sqlmesh environment and + returns them as a SQLMeshMultiAssetOptions intermediate representation.""" + + internal_asset_deps_map: dict[str, set[str]] = {} + deps_map: dict[str, ConvertibleToAssetDep] = {} + asset_outs: dict[str, ConvertibleToAssetOut] = {} + with self.instance(environment, "to_asset_outs") as instance: context = instance.context - output = SQLMeshMultiAssetOptions() - depsMap: dict[str, CoercibleToAssetDep] = {} for model, deps in instance.non_external_models_dag(): asset_key = translator.get_asset_key(context=context, fqn=model.fqn) + asset_key_str = asset_key.to_user_string() model_deps = [ SQLMeshModelDep(fqn=dep, model=context.get_model(dep)) for dep in deps ] - internal_asset_deps: set[AssetKey] = set() + internal_asset_deps: set[str] = set() asset_tags = translator.get_tags(context, model) for dep in model_deps: if dep.model: - internal_asset_deps.add( - translator.get_asset_key(context, dep.model.fqn) - ) + dep_asset_key_str = translator.get_asset_key( + context, dep.model.fqn + ).to_user_string() + + internal_asset_deps.add(dep_asset_key_str) else: table = get_asset_key_str(dep.fqn) - key = translator.get_asset_key(context, dep.fqn) + key = translator.get_asset_key( + context, dep.fqn + ).to_user_string() internal_asset_deps.add(key) + # create an external dep - depsMap[table] = AssetDep(key) + deps_map[table] = translator.create_asset_dep(key=key) + model_key = get_asset_key_str(model.fqn) - # If current Dagster supports "kinds", add labels for Dagster UI - if "kinds" in signature(AssetOut).parameters: - output.outs[model_key] = AssetOut( - key=asset_key, tags=asset_tags, is_required=False, - group_name=translator.get_group_name(context, model), - kinds={"sqlmesh", translator._get_context_dialect(context).lower()} - ) - else: - output.outs[model_key] = AssetOut( - key=asset_key, tags=asset_tags, is_required=False, - 
group_name=translator.get_group_name(context, model) - ) - output.internal_asset_deps[model_key] = internal_asset_deps - - output.deps = list(depsMap.values()) - return output + asset_outs[model_key] = translator.create_asset_out( + model_key=model_key, + asset_key=asset_key_str, + tags=asset_tags, + is_required=False, + group_name=translator.get_group_name(context, model), + kinds={ + "sqlmesh", + translator.get_context_dialect(context).lower(), + }, + ) + internal_asset_deps_map[model_key] = internal_asset_deps + + deps = list(deps_map.values()) + + return SQLMeshMultiAssetOptions( + outs=asset_outs, + deps=deps, + internal_asset_deps=internal_asset_deps_map, + ) diff --git a/dagster_sqlmesh/translator.py b/dagster_sqlmesh/translator.py index 2da090d..648dbbf 100644 --- a/dagster_sqlmesh/translator.py +++ b/dagster_sqlmesh/translator.py @@ -1,10 +1,48 @@ +import typing as t from collections.abc import Sequence +from inspect import signature -from dagster import AssetKey +from dagster import AssetDep, AssetKey, AssetOut +from pydantic import BaseModel, Field from sqlglot import exp from sqlmesh.core.context import Context from sqlmesh.core.model import Model +from .types import ConvertibleToAssetDep, ConvertibleToAssetOut + + +class IntermediateAssetOut(BaseModel): + model_key: str + asset_key: str + tags: t.Mapping[str, str] | None = None + is_required: bool = True + group_name: str | None = None + kinds: set[str] | None = None + kwargs: dict[str, t.Any] = Field(default_factory=dict) + + def to_asset_out(self) -> AssetOut: + asset_key = AssetKey.from_user_string(self.asset_key) + + if "kinds" not in signature(AssetOut).parameters: + self.kinds = None + + return AssetOut( + key=asset_key, + tags=self.tags, + is_required=self.is_required, + group_name=self.group_name, + kinds=self.kinds, + **self.kwargs, + ) + + +class IntermediateAssetDep(BaseModel): + key: str + kwargs: dict[str, t.Any] = Field(default_factory=dict) + + def to_asset_dep(self) -> AssetDep: + 
return AssetDep(AssetKey.from_user_string(self.key)) + class SQLMeshDagsterTranslator: """Translates sqlmesh objects for dagster""" @@ -19,14 +57,40 @@ def get_asset_key_name(self, fqn: str) -> Sequence[str]: asset_key_name = [table.catalog, table.db, table.name] return asset_key_name - + def get_group_name(self, context: Context, model: Model) -> str: path = self.get_asset_key_name(model.fqn) return path[-2] - def _get_context_dialect(self, context: Context) -> str: + def get_context_dialect(self, context: Context) -> str: return context.engine_adapter.dialect + def create_asset_dep(self, *, key: str, **kwargs: t.Any) -> ConvertibleToAssetDep: + """Create an object that resolves to an AssetDep + + Most users of this library will not need to use this method, it is + primarily the way we enable cacheable assets from dagster-sqlmesh. + """ + return IntermediateAssetDep(key=key, kwargs=kwargs) + + def create_asset_out( + self, *, model_key: str, asset_key: str, **kwargs: t.Any + ) -> ConvertibleToAssetOut: + """Create an object that resolves to an AssetOut + + Most users of this library will not need to use this method, it is + primarily the way we enable cacheable assets from dagster-sqlmesh. 
+ """ + return IntermediateAssetOut( + model_key=model_key, + asset_key=asset_key, + kinds=kwargs.pop("kinds", None), + tags=kwargs.pop("tags", None), + group_name=kwargs.pop("group_name", None), + is_required=kwargs.pop("is_required", False), + kwargs=kwargs, + ) + def get_tags(self, context: Context, model: Model) -> dict[str, str]: """Given the sqlmesh context and a model return the tags for that model""" return {k: "true" for k in model.tags} diff --git a/dagster_sqlmesh/types.py b/dagster_sqlmesh/types.py index b567df8..50187b5 100644 --- a/dagster_sqlmesh/types.py +++ b/dagster_sqlmesh/types.py @@ -1,8 +1,7 @@ import typing as t from dataclasses import dataclass, field -from dagster import AssetCheckResult, AssetKey, AssetMaterialization, AssetOut -from dagster._core.definitions.asset_dep import CoercibleToAssetDep +from dagster import AssetCheckResult, AssetDep, AssetKey, AssetMaterialization, AssetOut from sqlmesh.core.model import Model MultiAssetResponse = t.Iterable[AssetCheckResult | AssetMaterialization] @@ -30,10 +29,44 @@ class SQLMeshModelDep: def parse_fqn(self) -> SQLMeshParsedFQN: return SQLMeshParsedFQN.parse(self.fqn) + +class ConvertibleToAssetOut(t.Protocol): + def to_asset_out(self) -> AssetOut: + """Convert to an AssetOut object.""" + ... +class ConvertibleToAssetDep(t.Protocol): + def to_asset_dep(self) -> AssetDep: + """Convert to an AssetDep object.""" + ... + +class ConvertibleToAssetKey(t.Protocol): + def to_asset_key(self) -> AssetKey: + ... @dataclass(kw_only=True) class SQLMeshMultiAssetOptions: - outs: dict[str, AssetOut] = field(default_factory=lambda: {}) - deps: t.Iterable[CoercibleToAssetDep] = field(default_factory=lambda: {}) - internal_asset_deps: dict[str, set[AssetKey]] = field(default_factory=lambda: {}) + """Generic class for returning dagster multi asset options from SQLMesh, the + types used are intentionally generic so to allow for potentially using an + intermediate representation of the dagster asset objects. 
This is most + useful for caching purposes and is done to allow for users of this library to + manipulate the dagster asset creation process as they see fit.""" + + outs: t.Mapping[str, ConvertibleToAssetOut] = field(default_factory=lambda: {}) + deps: t.Iterable[ConvertibleToAssetDep] = field(default_factory=lambda: []) + internal_asset_deps: t.Mapping[str, set[str]] = field(default_factory=lambda: {}) + + def to_asset_outs(self) -> t.Mapping[str, AssetOut]: + """Convert to a mapping of output names to AssetOut objects.""" + return {key: out.to_asset_out() for key, out in self.outs.items()} + + def to_asset_deps(self) -> t.Iterable[AssetDep]: + """Convert to an iterable of AssetDep objects.""" + return [dep.to_asset_dep() for dep in self.deps] + + def to_internal_asset_deps(self) -> dict[str, set[AssetKey]]: + """Convert to a dictionary of internal asset dependencies.""" + return { + key: {AssetKey.from_user_string(dep) for dep in deps} + for key, deps in self.internal_asset_deps.items() + } \ No newline at end of file diff --git a/sample/dagster_project/definitions.py b/sample/dagster_project/definitions.py index 8719a9e..8510830 100644 --- a/sample/dagster_project/definitions.py +++ b/sample/dagster_project/definitions.py @@ -5,23 +5,50 @@ import polars as pl from dagster import ( AssetExecutionContext, + AssetKey, Definitions, MaterializeResult, asset, define_asset_job, ) from dagster_duckdb_polars import DuckDBPolarsIOManager +from sqlglot import exp -from dagster_sqlmesh import SQLMeshContextConfig, SQLMeshResource, sqlmesh_assets +from dagster_sqlmesh import ( + Context, + SQLMeshContextConfig, + SQLMeshDagsterTranslator, + SQLMeshResource, + sqlmesh_assets, +) CURR_DIR = os.path.dirname(__file__) SQLMESH_PROJECT_PATH = os.path.abspath(os.path.join(CURR_DIR, "../sqlmesh_project")) +SQLMESH_CACHE_PATH = os.path.join(SQLMESH_PROJECT_PATH, ".cache") DUCKDB_PATH = os.path.join(CURR_DIR, "../../db.db") sqlmesh_config = SQLMeshContextConfig(path=SQLMESH_PROJECT_PATH, 
gateway="local") -@asset(key=["db", "sources", "reset_asset"]) +class RewrittenSQLMeshTranslator(SQLMeshDagsterTranslator): + """A contrived SQLMeshDagsterTranslator that flattens the catalog of the + sqlmesh project and only uses the table db and name + + We include this as a test of the translator functionality. + """ + + def get_asset_key(self, context: Context, fqn: str) -> AssetKey: + table = exp.to_table(fqn) # Ensure fqn is a valid table expression + if table.db == "sqlmesh_example": + # For the sqlmesh_example project, we use a custom key + return AssetKey(["sqlmesh", table.name]) + return AssetKey([table.db, table.name]) + + def get_group_name(self, context, model): + return "sqlmesh" + + +@asset(key=["sources", "reset_asset"]) def reset_asset() -> MaterializeResult: """An asset used for testing this entire workflow. If the duckdb database is found, this will delete it. This allows us to continously test this dag if @@ -34,7 +61,7 @@ def reset_asset() -> MaterializeResult: return MaterializeResult(metadata={"deleted": deleted}) -@asset(deps=[reset_asset], key=["db", "sources", "test_source"]) +@asset(deps=[reset_asset], key=["sources", "test_source"]) def test_source() -> pl.DataFrame: """Sets up the `test_source` table in duckdb that one of the sample sqlmesh models depends on""" @@ -52,15 +79,40 @@ def test_source() -> pl.DataFrame: ) -@sqlmesh_assets(environment="dev", config=sqlmesh_config, enabled_subsetting=True) -def sqlmesh_project(context: AssetExecutionContext, sqlmesh: SQLMeshResource) -> t.Iterator[MaterializeResult]: +@asset(deps=[AssetKey(["sqlmesh", "full_model"])]) +def post_full_model() -> pl.DataFrame: + """An asset that depends on the `full_model` asset from the sqlmesh project. + This is used to test that the sqlmesh assets are correctly materialized and + can be used in other assets. 
+ """ + import duckdb + + conn = duckdb.connect(DUCKDB_PATH) + df = conn.query( + """ + SELECT * FROM sqlmesh_example__dev.full_model + """ + ).to_df() + conn.close() + return pl.from_dataframe(df) + + +@sqlmesh_assets( + environment="dev", + config=sqlmesh_config, + enabled_subsetting=True, + dagster_sqlmesh_translator=RewrittenSQLMeshTranslator(), +) +def sqlmesh_project( + context: AssetExecutionContext, sqlmesh: SQLMeshResource +) -> t.Iterator[MaterializeResult]: yield from sqlmesh.run(context) all_assets_job = define_asset_job(name="all_assets_job") defs = Definitions( - assets=[sqlmesh_project, test_source, reset_asset], + assets=[sqlmesh_project, test_source, reset_asset, post_full_model], resources={ "sqlmesh": SQLMeshResource(config=sqlmesh_config), "io_manager": DuckDBPolarsIOManager( diff --git a/uv.lock b/uv.lock index 9a12b51..3b69fc2 100644 --- a/uv.lock +++ b/uv.lock @@ -291,7 +291,7 @@ wheels = [ [[package]] name = "dagster-sqlmesh" -version = "0.17.0" +version = "0.18.0" source = { editable = "." } dependencies = [ { name = "dagster" },