Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions dagster_sqlmesh/controller/dagster.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
)
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
from dagster_sqlmesh.types import (
ConvertibleToAssetDep,
ConvertibleToAssetKey,
ConvertibleToAssetOut,
SQLMeshModelDep,
SQLMeshMultiAssetOptions,
Expand All @@ -24,11 +24,14 @@ def to_asset_outs(
environment: str,
translator: SQLMeshDagsterTranslator,
) -> SQLMeshMultiAssetOptions:
"""Loads all the asset outs of the current sqlmesh environment. If a
cache is provided, it will be tried first to load the asset outs."""
"""Loads all the asset outs of the current sqlmesh environment.

If a cache is provided, it will be tried first to load the asset outs.
External dependencies use IntermediateAssetDep objects that convert to AssetKey.
"""

internal_asset_deps_map: dict[str, set[str]] = {}
deps_map: dict[str, ConvertibleToAssetDep] = {}
deps_map: dict[str, ConvertibleToAssetKey] = {}
asset_outs: dict[str, ConvertibleToAssetOut] = {}

with self.instance(environment, "to_asset_outs") as instance:
Expand All @@ -52,14 +55,15 @@ def to_asset_outs(

internal_asset_deps.add(dep_asset_key_str)
else:
# External dependency - create IntermediateAssetDep
table = translator.get_asset_key_str(dep.fqn)
key = translator.get_asset_key(
context, dep.fqn
).to_user_string()
internal_asset_deps.add(key)
key = translator.get_asset_key(context, dep.fqn)
internal_asset_deps.add(key.to_user_string())

# create an external dep
deps_map[table] = translator.create_asset_dep(key=key)
# Create lazy intermediate representation for caching
deps_map[table] = translator.create_asset_dep(
key=key.to_user_string()
)

model_key = translator.get_asset_key_str(model.fqn)
asset_outs[model_key] = translator.create_asset_out(
Expand Down
60 changes: 38 additions & 22 deletions dagster_sqlmesh/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@
ContextFactory,
)
from dagster_sqlmesh.controller.dagster import DagsterSQLMeshController

if t.TYPE_CHECKING:
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -101,9 +99,9 @@ def from_dagster_metadata(
# convert metadata values
converted: dict[str, dg.MetadataValue] = {}
for key, value in metadata.items():
assert isinstance(
value, dg.MetadataValue
), f"Expected MetadataValue for {key}, got {type(value)}"
assert isinstance(value, dg.MetadataValue), (
f"Expected MetadataValue for {key}, got {type(value)}"
)
converted[key] = value

return cls.model_validate(
Expand Down Expand Up @@ -224,7 +222,7 @@ def stop_promotion(self) -> None:

def plan(self, batches: dict[Snapshot, int]) -> None:
self._batches = batches
self._count: dict[Snapshot, int] = {}
self._count = {}

for snapshot, _ in self._batches.items():
self._count[snapshot] = 0
Expand Down Expand Up @@ -331,7 +329,7 @@ def __init__(
models_map: dict[str, Model],
dag: DAG[t.Any],
prefix: str,
translator: "SQLMeshDagsterTranslator",
translator: SQLMeshDagsterTranslator,
is_testing: bool = False,
materializations_enabled: bool = True,
) -> None:
Expand Down Expand Up @@ -423,9 +421,9 @@ def create_materialize_result(
)
last_materialization_status = None
else:
assert (
last_materialization.asset_materialization is not None
), "Expected asset materialization to be present."
assert last_materialization.asset_materialization is not None, (
"Expected asset materialization to be present."
)
try:
last_materialization_status = (
ModelMaterializationStatus.from_dagster_metadata(
Expand Down Expand Up @@ -496,15 +494,19 @@ def report_event(self, event: console.ConsoleEvent) -> None:
log_context.info(
"Snapshot progress complete",
{
"asset_key": self._translator.get_asset_key_str(snapshot.model.name),
"asset_key": self._translator.get_asset_key_str(
snapshot.model.name
),
},
)
self._tracker.update_run(snapshot)
else:
log_context.info(
"Snapshot progress update",
{
"asset_key": self._translator.get_asset_key_str(snapshot.model.name),
"asset_key": self._translator.get_asset_key_str(
snapshot.model.name
),
"progress": f"{done}/{expected}",
"duration_ms": duration_ms,
},
Expand Down Expand Up @@ -580,6 +582,15 @@ def errors(self) -> list[Exception]:


class SQLMeshResource(dg.ConfigurableResource):
"""Dagster resource for executing SQLMesh plan and run operations.

The translator is obtained from `config.get_translator()` to ensure
consistency between asset definition loading and runtime execution.

Attributes:
is_testing: Whether the resource is being used in a testing context.
"""

is_testing: bool = False

def run(
Expand All @@ -599,16 +610,14 @@ def run(
run_options: RunOptions | None = None,
materializations_enabled: bool = True,
) -> t.Iterable[dg.MaterializeResult[t.Any]]:
"""Execute SQLMesh based on the configuration given"""
"""Execute SQLMesh plan and run, yielding MaterializeResult for each model."""
plan_options = plan_options or {}
run_options = run_options or {}

logger = context.log

controller = self.get_controller(
config=config,
context_factory=context_factory,
log_override=logger
config=config, context_factory=context_factory, log_override=logger
)

with controller.instance(environment) as mesh:
Expand All @@ -620,7 +629,9 @@ def run(
[model.fqn for model, _ in mesh.non_external_models_dag()]
)
selected_models_set, models_map, select_models = (
self._get_selected_models_from_context(context=context, config=config, models=models)
self._get_selected_models_from_context(
context=context, config=config, models=models
)
)

if all_available_models == selected_models_set or select_models is None:
Expand Down Expand Up @@ -696,6 +707,7 @@ def create_event_handler(
is_testing: bool,
materializations_enabled: bool,
) -> DagsterSQLMeshEventHandler:
"""Create an event handler for processing SQLMesh console events."""
translator = config.get_translator()
return DagsterSQLMeshEventHandler(
context=context,
Expand All @@ -708,14 +720,17 @@ def create_event_handler(
)

def _get_selected_models_from_context(
self,
context: dg.AssetExecutionContext,
self,
context: dg.AssetExecutionContext,
config: SQLMeshContextConfig,
models: MappingProxyType[str, Model]
models: MappingProxyType[str, Model],
) -> tuple[set[str], dict[str, Model], list[str] | None]:
"""Get the selected models from the execution context."""
models_map = models.copy()
try:
selected_output_names = set(context.op_execution_context.selected_output_names)
selected_output_names = set(
context.op_execution_context.selected_output_names
)
except (DagsterInvalidPropertyError, AttributeError) as e:
# Special case for direct execution context when testing. This is related to:
# https://github.com/dagster-io/dagster/issues/23633
Expand Down Expand Up @@ -744,6 +759,7 @@ def get_controller(
context_factory: ContextFactory[ContextCls],
log_override: logging.Logger | None = None,
) -> DagsterSQLMeshController[ContextCls]:
"""Get a SQLMesh controller for executing operations."""
return DagsterSQLMeshController.setup_with_config(
config=config,
context_factory=context_factory,
Expand Down
Loading