Skip to content

Commit 745db0f

Browse files
authored
Fix: treat all instances of macro variables as case-insensitive (#5352)
1 parent 7ad241a commit 745db0f

File tree

3 files changed

+78
-11
lines changed

3 files changed

+78
-11
lines changed

sqlmesh/core/macros.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,17 @@ def _macro_str_replace(text: str) -> str:
128128
return f"self.template({text}, locals())"
129129

130130

131+
class CaseInsensitiveMapping(t.Dict[str, t.Any]):
132+
def __init__(self, data: t.Dict[str, t.Any]) -> None:
133+
super().__init__(data)
134+
135+
def __getitem__(self, key: str) -> t.Any:
136+
return super().__getitem__(key.lower())
137+
138+
def get(self, key: str, default: t.Any = None, /) -> t.Any:
139+
return super().get(key.lower(), default)
140+
141+
131142
class MacroDialect(Python):
132143
class Generator(Python.Generator):
133144
TRANSFORMS = {
@@ -256,14 +267,18 @@ def evaluate_macros(
256267
changed = True
257268
variables = self.variables
258269

259-
if node.name not in self.locals and node.name.lower() not in variables:
270+
# This makes all variables case-insensitive, e.g. @X is the same as @x. We do this
271+
# for consistency, since `variables` and `blueprint_variables` are normalized.
272+
var_name = node.name.lower()
273+
274+
if var_name not in self.locals and var_name not in variables:
260275
if not isinstance(node.parent, StagedFilePath):
261276
raise SQLMeshError(f"Macro variable '{node.name}' is undefined.")
262277

263278
return node
264279

265280
# Precedence order is locals (e.g. @DEF) > blueprint variables > config variables
266-
value = self.locals.get(node.name, variables.get(node.name.lower()))
281+
value = self.locals.get(var_name, variables.get(var_name))
267282
if isinstance(value, list):
268283
return exp.convert(
269284
tuple(
@@ -313,11 +328,11 @@ def template(self, text: t.Any, local_variables: t.Dict[str, t.Any]) -> str:
313328
"""
314329
# We try to convert all variables into sqlglot expressions because they're going to be converted
315330
# into strings; in sql we don't convert strings because that would result in adding quotes
316-
mapping = {
317-
k: convert_sql(v, self.dialect)
331+
base_mapping = {
332+
k.lower(): convert_sql(v, self.dialect)
318333
for k, v in chain(self.variables.items(), self.locals.items(), local_variables.items())
319334
}
320-
return MacroStrTemplate(str(text)).safe_substitute(mapping)
335+
return MacroStrTemplate(str(text)).safe_substitute(CaseInsensitiveMapping(base_mapping))
321336

322337
def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] | None:
323338
if isinstance(node, MacroDef):
@@ -327,7 +342,9 @@ def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] |
327342
args[0] if len(args) == 1 else exp.Tuple(expressions=list(args))
328343
)
329344
else:
330-
self.locals[node.name] = self.transform(node.expression)
345+
# Make variables defined through `@DEF` case-insensitive
346+
self.locals[node.name.lower()] = self.transform(node.expression)
347+
331348
return node
332349

333350
if isinstance(node, (MacroSQL, MacroStrReplace)):
@@ -630,7 +647,7 @@ def substitute(
630647
) -> exp.Expression | t.List[exp.Expression] | None:
631648
if isinstance(node, (exp.Identifier, exp.Var)):
632649
if not isinstance(node.parent, exp.Column):
633-
name = node.name
650+
name = node.name.lower()
634651
if name in args:
635652
return args[name].copy()
636653
if name in evaluator.locals:
@@ -663,7 +680,7 @@ def substitute(
663680
return expressions, lambda args: func.this.transform(
664681
substitute,
665682
{
666-
expression.name: arg
683+
expression.name.lower(): arg
667684
for expression, arg in zip(
668685
func.expressions, args.expressions if isinstance(args, exp.Tuple) else [args]
669686
)

tests/core/test_macros.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,16 @@ def test_ast_correctness(macro_evaluator):
292292
"SELECT 'a' + a_z + 'c' + c_a, 'b' + b_z + 'c' + c_b",
293293
{"y": "c"},
294294
),
295+
(
296+
"""select @each(['a'], x -> @X)""",
297+
"SELECT 'a'",
298+
{},
299+
),
300+
(
301+
"""select @each(['a'], X -> @x)""",
302+
"SELECT 'a'",
303+
{},
304+
),
295305
(
296306
'"is_@{x}"',
297307
'"is_b"',
@@ -1112,7 +1122,9 @@ def test_macro_with_spaces():
11121122

11131123
for sql, expected in (
11141124
("@x", '"a b"'),
1125+
("@X", '"a b"'),
11151126
("@{x}", '"a b"'),
1127+
("@{X}", '"a b"'),
11161128
("a_@x", '"a_a b"'),
11171129
("a.@x", 'a."a b"'),
11181130
("@y", "'a b'"),
@@ -1121,6 +1133,7 @@ def test_macro_with_spaces():
11211133
("a.@{y}", 'a."a b"'),
11221134
("@z", 'a."b c"'),
11231135
("d.@z", 'd.a."b c"'),
1136+
("@'test_@{X}_suffix'", "'test_a b_suffix'"),
11241137
):
11251138
assert evaluator.transform(parse_one(sql)).sql() == expected
11261139

tests/core/test_model.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9377,9 +9377,9 @@ def test_model_blueprinting(tmp_path: Path) -> None:
93779377
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
93789378
)
93799379

9380-
blueprint_sql = tmp_path / "macros" / "identity_macro.py"
9381-
blueprint_sql.parent.mkdir(parents=True, exist_ok=True)
9382-
blueprint_sql.write_text(
9380+
identity_macro = tmp_path / "macros" / "identity_macro.py"
9381+
identity_macro.parent.mkdir(parents=True, exist_ok=True)
9382+
identity_macro.write_text(
93839383
"""from sqlmesh import macro
93849384
93859385
@macro()
@@ -11623,3 +11623,40 @@ def test_use_original_sql():
1162311623
assert model.query_.sql == "SELECT 1 AS one, 2 AS two"
1162411624
assert model.pre_statements_[0].sql == "CREATE TABLE pre (a INT)"
1162511625
assert model.post_statements_[0].sql == "CREATE TABLE post (b INT)"
11626+
11627+
11628+
def test_case_sensitive_macro_locals(tmp_path: Path) -> None:
11629+
init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY)
11630+
11631+
db_path = str(tmp_path / "db.db")
11632+
db_connection = DuckDBConnectionConfig(database=db_path)
11633+
11634+
config = Config(
11635+
gateways={"gw": GatewayConfig(connection=db_connection)},
11636+
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
11637+
)
11638+
11639+
macro_file = tmp_path / "macros" / "some_macro_with_globals.py"
11640+
macro_file.parent.mkdir(parents=True, exist_ok=True)
11641+
macro_file.write_text(
11642+
"""from sqlmesh import macro
11643+
11644+
x = 1
11645+
X = 2
11646+
11647+
@macro()
11648+
def my_macro(evaluator):
11649+
assert evaluator.locals.get("x") == 1
11650+
assert evaluator.locals.get("X") == 2
11651+
11652+
return x + X
11653+
"""
11654+
)
11655+
test_model = tmp_path / "models" / "test_model.sql"
11656+
test_model.parent.mkdir(parents=True, exist_ok=True)
11657+
test_model.write_text("MODEL (name test_model, kind FULL); SELECT @my_macro() AS c")
11658+
11659+
context = Context(paths=tmp_path, config=config)
11660+
model = context.get_model("test_model", raise_if_missing=True)
11661+
11662+
assert model.render_query_or_raise().sql() == 'SELECT 3 AS "c"'

0 commit comments

Comments
 (0)