Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 112 additions & 69 deletions dagster_sqlmesh/console.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import inspect
import logging
import textwrap
import typing as t
import unittest
import uuid
Expand Down Expand Up @@ -148,9 +147,10 @@ class Plan(BaseConsoleEvent):
@dataclass(kw_only=True)
class LogTestResults(BaseConsoleEvent):
result: unittest.result.TestResult
output: str | None
output: str | None = None
target_dialect: str


@dataclass(kw_only=True)
class ShowSQL(BaseConsoleEvent):
sql: str
Expand Down Expand Up @@ -221,7 +221,7 @@ class ShowTableDiffSummary(BaseConsoleEvent):

@dataclass(kw_only=True)
class PlanBuilt(BaseConsoleEvent):
plan: SQLMeshPlan
plan: SQLMeshPlan

ConsoleEvent = (
StartPlanEvaluation
Expand Down Expand Up @@ -277,6 +277,8 @@ class PlanBuilt(BaseConsoleEvent):
]

T = t.TypeVar("T")
EventType = t.TypeVar("EventType", bound=BaseConsoleEvent)


def get_console_event_by_name(
event_name: str,
Expand All @@ -297,13 +299,14 @@ class IntrospectingConsole(Console):

def __init_subclass__(cls):
super().__init_subclass__()
# Store method info for later creation in __init__
cls._method_info = []

known_events_classes = cls.events
known_events: list[str] = []
for known_event in known_events_classes:
assert inspect.isclass(known_event), "event must be a class"
known_events.append(known_event.__name__)


# Iterate through all the available abstract methods in console
for method_name in Console.__abstractmethods__:
Expand All @@ -314,96 +317,50 @@ def __init_subclass__(cls):
continue
logger.debug(f"Checking {method_name}")

# if the method doesn't exist we automatically create a method by
# inspecting the method's arguments. Anything that matches "known"
# events has its values checked. The dataclass should define the
# required fields and everything else should be sent to a catchall
# argument in the dataclass for the event

# Convert method name from snake_case to camel case
camel_case_method_name = "".join(
word.capitalize()
for i, word in enumerate(method_name.split("_"))
)

signature = inspect.signature(getattr(Console, method_name))

if camel_case_method_name in known_events:
logger.debug(f"Creating {method_name} for {camel_case_method_name}")
signature = inspect.signature(getattr(Console, method_name))
handler = cls.create_event_handler(method_name, camel_case_method_name, signature)
setattr(cls, method_name, handler)
logger.debug(f"Storing {method_name} for {camel_case_method_name}")
event_cls = get_console_event_by_name(camel_case_method_name)
assert event_cls is not None, f"Event {camel_case_method_name} not found"
cls._method_info.append(('known', method_name, event_cls, signature))
else:
logger.debug(f"Creating {method_name} for unknown event")
signature = inspect.signature(getattr(Console, method_name))
handler = cls.create_unknown_event_handler(method_name, signature)
setattr(cls, method_name, handler)

@classmethod
def create_event_handler(cls, method_name: str, event_name: str, signature: inspect.Signature):
    """Generate a console method that forwards its arguments as a known event.

    The generated function is named ``method_name``, takes the parameters
    described by ``signature`` and calls
    ``self.publish_known_event(event_name, ...)`` with every non-``self``
    parameter passed by keyword.

    Args:
        method_name: Name given to the generated function.
        event_name: Event name passed as the first argument to
            ``publish_known_event``.
        signature: Signature used to build the generated parameter list.

    Returns:
        The generated function object.
    """
    func_signature, call_params = cls.create_signatures_and_params(signature)

    event_handler_str = textwrap.dedent(f"""
        def {method_name}({", ".join(func_signature)}):
            self.publish_known_event('{event_name}', {", ".join(call_params)})
    """)
    # Execute into an explicit namespace instead of reading the result back
    # through ``locals()``: whether names created by ``exec`` show up in a
    # later ``locals()`` call is an interpreter implementation detail.
    namespace: dict[str, t.Any] = {}
    exec(event_handler_str, globals(), namespace)
    return t.cast(t.Callable[[t.Any], t.Any], namespace[method_name])

@classmethod
def create_signatures_and_params(cls, signature: inspect.Signature):
func_signature: list[str] = []
call_params: list[str] = []
for param_name, param in signature.parameters.items():
if param_name == "self":
func_signature.append("self")
continue

if param.default is inspect._empty:
param_type_name = param.annotation
if not isinstance(param_type_name, str):
param_type_name = param_type_name.__name__
func_signature.append(f"{param_name}: '{param_type_name}'")
else:
default_value = param.default
param_type_name = param.annotation
if not isinstance(param_type_name, str):
param_type_name = param_type_name.__name__
if isinstance(param.default, str):
default_value = f"'{param.default}'"
func_signature.append(f"{param_name}: '{param_type_name}' = {default_value}")
call_params.append(f"{param_name}={param_name}")
return (func_signature, call_params)

@classmethod
def create_unknown_event_handler(cls, method_name: str, signature: inspect.Signature):
    """Generate a console method that forwards its arguments as an unknown event.

    The generated function is named ``method_name``, takes the parameters
    described by ``signature`` and calls
    ``self.publish_unknown_event(method_name, ...)`` with every
    non-``self`` parameter passed by keyword.

    Args:
        method_name: Name given to the generated function and passed as
            the first argument to ``publish_unknown_event``.
        signature: Signature used to build the generated parameter list.

    Returns:
        The generated function object.
    """
    func_signature, call_params = cls.create_signatures_and_params(signature)

    event_handler_str = textwrap.dedent(f"""
        def {method_name}({", ".join(func_signature)}):
            self.publish_unknown_event('{method_name}', {", ".join(call_params)})
    """)
    # Execute into an explicit namespace instead of reading the result back
    # through ``locals()``: whether names created by ``exec`` show up in a
    # later ``locals()`` call is an interpreter implementation detail.
    namespace: dict[str, t.Any] = {}
    exec(event_handler_str, globals(), namespace)
    return t.cast(t.Callable[[t.Any], t.Any], namespace[method_name])
logger.debug(f"Storing {method_name} for unknown event")
cls._method_info.append(('unknown', method_name, None, signature))

def __init__(self, log_override: logging.Logger | None = None) -> None:
self._handlers: dict[str, ConsoleEventHandler] = {}
self.logger = log_override or logger
self.id = str(uuid.uuid4())
self.logger.debug(f"EventConsole[{self.id}]: created")
self.categorizer = None

# Create methods now that we have self
for method_type, method_name, event_cls, signature in self._method_info:
if method_type == 'known':
handler = GeneratedCallable(self, event_cls, signature, method_name)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a bit old school this way, but I don't like to do things in constructors that might fail if I can really help it. Can we instead put the logic to handle the methods back? I think perhaps telling you not to wrap in that handler made it impossible to use __init_subclass__? If so, let's just keep that wrapped handler then.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I tried to eliminate the wrapper handler, it made __init_subclass__ impossible to use because we needed self.

else:
handler = UnknownEventCallable(self, method_name, signature)
setattr(self, method_name, handler)

def publish_known_event(self, event_name: str, **kwargs: t.Any) -> None:
    """Build the event class registered under ``event_name`` and publish it.

    Keyword arguments that match a dataclass field of the event class are
    passed to its constructor; all remaining ones are collected into the
    ``unknown_args`` constructor argument.
    """
    console_event = get_console_event_by_name(event_name)
    assert console_event is not None, f"Event {event_name} not found"

    declared_fields = console_event.__dataclass_fields__
    # Split the incoming keywords by whether the event class declares them.
    expected_kwargs: dict[str, t.Any] = {
        name: value for name, value in kwargs.items() if name in declared_fields
    }
    unknown_args: dict[str, t.Any] = {
        name: value for name, value in kwargs.items() if name not in declared_fields
    }

    self.publish(console_event(**expected_kwargs, unknown_args=unknown_args))
Expand Down Expand Up @@ -446,6 +403,92 @@ def capture_built_plan(self, plan: SQLMeshPlan) -> None:
"""Capture the built plan and publish a PlanBuilt event."""
self.publish(PlanBuilt(plan=plan))


class GeneratedCallable(t.Generic[EventType]):
    """Callable that turns a console method invocation into a published event.

    An instance stands in for one console method. When called, it binds the
    received arguments against the signature of the method it replaces,
    builds an ``event_cls`` instance from them and publishes it through the
    owning console.
    """

    def __init__(
        self,
        console: IntrospectingConsole,
        event_cls: type[EventType],
        original_signature: inspect.Signature,
        method_name: str,
    ):
        """Remember the console, event class and signature used on each call.

        Args:
            console: Console whose ``publish`` and ``logger`` are used.
            event_cls: Dataclass instantiated on every call.
            original_signature: Signature of the console method being
                replaced; it may include a leading ``self`` parameter.
            method_name: Name of the replaced method, used in log messages.
        """
        self.console = console
        self.event_cls = event_cls
        self.original_signature = original_signature
        self.method_name = method_name

    def __call__(self, *args: t.Any, **kwargs: t.Any) -> None:
        """Bind the call arguments and publish the matching event.

        If the arguments cannot be bound to the original signature, a
        warning is logged and every argument is forwarded as an unknown
        argument instead (positional ones keyed by their index as a string).
        """
        takes_self = "self" in self.original_signature.parameters
        # Tolerate callers that pass the console explicitly as first argument.
        if takes_self and args and args[0] is self.console:
            args = args[1:]
        # This object is not a descriptor, so the console is never passed
        # implicitly. Supply it here so the remaining arguments line up with
        # the parameters that follow ``self`` instead of being shifted by one.
        bind_args = (self.console, *args) if takes_self else args

        try:
            bound = self.original_signature.bind(*bind_args, **kwargs)
            bound.apply_defaults()
        except TypeError as e:
            # If binding fails, collect all args/kwargs as unknown
            self.console.logger.warning(f"Failed to bind arguments for {self.method_name}: {e}")
            unknown_args: dict[str, t.Any] = {str(i): arg for i, arg in enumerate(args)}
            unknown_args.update(kwargs)
            self._create_and_publish_event({}, unknown_args)
            return

        bound_args = dict(bound.arguments)
        bound_args.pop("self", None)  # the console is not event data

        self._create_and_publish_event(bound_args, {})

    def _create_and_publish_event(self, bound_args: dict[str, t.Any], extra_unknown: dict[str, t.Any]) -> None:
        """Instantiate ``event_cls`` from the given arguments and publish it.

        Args:
            bound_args: Arguments keyed by parameter name. Entries matching a
                dataclass field of ``event_cls`` become constructor keywords;
                the rest are added to the unknown arguments.
            extra_unknown: Arguments that are always treated as unknown. On a
                key clash the entry from ``bound_args`` wins.
        """
        expected_fields = self.event_cls.__dataclass_fields__
        expected_kwargs: dict[str, t.Any] = {
            key: value for key, value in bound_args.items() if key in expected_fields
        }

        # Copy so the caller's dict is never mutated.
        unknown_args: dict[str, t.Any] = dict(extra_unknown)
        unknown_args.update(
            (key, value) for key, value in bound_args.items() if key not in expected_fields
        )

        event = self.event_cls(**expected_kwargs, unknown_args=unknown_args)
        self.console.publish(t.cast(ConsoleEvent, event))


class UnknownEventCallable:
    """Callable that forwards a console method invocation as an unknown event.

    An instance stands in for one console method. When called, it resolves
    the received arguments to parameter names using the signature of the
    method it replaces and hands them to the console's
    ``publish_unknown_event``.
    """

    def __init__(
        self,
        console: "IntrospectingConsole",
        method_name: str,
        original_signature: inspect.Signature,
    ):
        """Remember the console, method name and signature used on each call.

        Args:
            console: Console whose ``publish_unknown_event`` is invoked.
            method_name: Name of the replaced method; passed as the event
                name.
            original_signature: Signature of the console method being
                replaced; it may include a leading ``self`` parameter.
        """
        self.console = console
        self.method_name = method_name
        self.original_signature = original_signature

    def __call__(self, *args: t.Any, **kwargs: t.Any) -> None:
        """Publish the call as an unknown event with named arguments.

        If the arguments cannot be bound to the original signature, the
        positional ones are forwarded keyed by their index as a string,
        together with the keyword arguments as given.
        """
        takes_self = "self" in self.original_signature.parameters
        # Tolerate callers that pass the console explicitly as first argument.
        if takes_self and args and args[0] is self.console:
            args = args[1:]
        # This object is not a descriptor, so the console is never passed
        # implicitly. Supply it here so the remaining arguments line up with
        # the parameters that follow ``self`` instead of being shifted by one.
        bind_args = (self.console, *args) if takes_self else args

        try:
            bound = self.original_signature.bind(*bind_args, **kwargs)
            bound.apply_defaults()
            bound_args = dict(bound.arguments)
            bound_args.pop("self", None)  # the console is not event data
        except TypeError:
            # If binding fails, collect all args/kwargs
            bound_args = {str(i): arg for i, arg in enumerate(args)}
            bound_args.update(kwargs)

        self.console.publish_unknown_event(self.method_name, **bound_args)


class EventConsole(IntrospectingConsole):
"""
A console implementation that manages and publishes events related to
Expand Down
2 changes: 1 addition & 1 deletion dagster_sqlmesh/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ def _get_selected_models_from_context(
) -> tuple[set[str], dict[str, Model], list[str] | None]:
models_map = models.copy()
try:
selected_output_names = set(context.selected_output_names)
selected_output_names = set(context.op_execution_context.selected_output_names)
except (DagsterInvalidPropertyError, AttributeError) as e:
# Special case for direct execution context when testing. This is related to:
# https://github.com/dagster-io/dagster/issues/23633
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ readme = "README.md"
requires-python = ">=3.11,<3.13"
dependencies = [
"dagster>=1.7.8",
"sqlmesh<0.188",
"sqlmesh>=0.188",
"pytest>=8.3.2",
"pyarrow>=18.0.0",
"pydantic>=2.11.5",
Expand Down Expand Up @@ -41,7 +41,7 @@ exclude = [
"**/.github",
"**/.vscode",
"**/.idea",
"**/.pytest_cache",
"**/.pytest_cache",
]
pythonVersion = "3.11"
reportUnknownParameterType = true
Expand Down
Loading