diff --git a/sqlmesh/magics.py b/sqlmesh/magics.py index 0a433360df..57dd150af2 100644 --- a/sqlmesh/magics.py +++ b/sqlmesh/magics.py @@ -95,6 +95,14 @@ def wrapper(self: SQLMeshMagics, *args: t.Any, **kwargs: t.Any) -> None: return wrapper +def parse_expand(value: str) -> t.Union[bool, t.List[str]]: + if value.lower() == "true": + return True + if value.lower() == "false": + return False + return [name.strip() for name in value.split(",") if name.strip()] + + def format_arguments(func: t.Callable) -> t.Callable: """Decorator to add common format arguments to magic commands.""" func = argument( @@ -633,8 +641,8 @@ def evaluate(self, context: Context, line: str) -> None: @argument("--execution-time", type=str, help="Execution time.") @argument( "--expand", - type=t.Union[bool, t.Iterable[str]], - help="Whether or not to use expand materialized models, defaults to False. If True, all referenced models are expanded as raw queries. If a list, only referenced models are expanded as raw queries.", + type=parse_expand, + help="Whether or not to use expand materialized models, defaults to False. If 'true', all referenced models are expanded as raw queries. If a comma-separated list of model names, only those models are expanded as raw queries.", ) @argument("--dialect", type=str, help="SQL dialect to render.") @argument("--no-format", action="store_true", help="Disable fancy formatting of the query.") @@ -647,6 +655,7 @@ def render(self, context: Context, line: str) -> None: render_opts = vars(parse_argstring(self.render, line)) model = render_opts.pop("model") dialect = render_opts.pop("dialect", None) + expand = render_opts.pop("expand", False) model = context.get_model(model, raise_if_missing=True) @@ -655,7 +664,7 @@ def render(self, context: Context, line: str) -> None: start=render_opts.pop("start", None), end=render_opts.pop("end", None), execution_time=render_opts.pop("execution_time", None), - expand=render_opts.pop("expand", False), + expand=expand, ) no_format = render_opts.pop("no_format", False)