Skip to content

Commit 54b0de1

Browse files
volokluev and claude authored
feat(ast): add ArbitrarySQL AST node to allow for subquery optimization (#7636)
For cross-item queries to work, we need to be able to put a subquery into an AST node. Rather than modifying the whole query pipeline to support this, take a shortcut and add an `ArbitrarySQL` node (named `DangerousRawSQL` in the code of this commit) that is not transformed in any way by the query pipeline but can be inserted into a manually constructed EAP query (as we do in EAP). The query I want to be able to build is: ``` SELECT * FROM eap_items_1_dist WHERE trace_id IN (SELECT trace_id FROM eap_items_1_dist WHERE...) ``` --------- Co-authored-by: Claude Sonnet 4.5 <[email protected]>
1 parent 9a5cfa8 commit 54b0de1

File tree

12 files changed

+499
-95
lines changed

12 files changed

+499
-95
lines changed

snuba/clickhouse/formatter/expression.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
Argument,
1414
Column,
1515
CurriedFunctionCall,
16+
DangerousRawSQL,
1617
Expression,
1718
ExpressionVisitor,
1819
FunctionCall,
@@ -179,6 +180,14 @@ def visit_lambda(self, exp: Lambda) -> str:
179180
ret = f"{', '.join(parameters)} -> {exp.transformation.accept(self)}"
180181
return self._alias(ret, exp.alias)
181182

183+
def visit_dangerous_raw_sql(self, exp: DangerousRawSQL) -> str:
184+
"""
185+
Format DangerousRawSQL by passing through the SQL content directly without
186+
any escaping or validation. This is intentional as DangerousRawSQL is meant
187+
for pre-validated SQL in query optimization scenarios.
188+
"""
189+
return self._alias(exp.sql, exp.alias)
190+
182191

183192
class ClickhouseExpressionFormatter(ExpressionFormatterBase):
184193
"""

snuba/clickhouse/translators/snuba/mapping.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
Argument,
2929
Column,
3030
CurriedFunctionCall,
31+
DangerousRawSQL,
3132
FunctionCall,
3233
Lambda,
3334
Literal,
@@ -176,6 +177,11 @@ def visit_lambda(self, exp: Lambda) -> Expression:
176177
self.__cache[exp] = ret
177178
return ret
178179

180+
def visit_dangerous_raw_sql(self, exp: DangerousRawSQL) -> Expression:
181+
# DangerousRawSQL is passed through unchanged during translation
182+
# since it contains pre-formatted SQL that should not be modified
183+
return exp
184+
179185
def translate_function_strict(self, exp: FunctionCall) -> FunctionCall:
180186
"""
181187
Unfortunately it is not possible to avoid this assertion.

snuba/query/dsl_mapper.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
Argument,
99
Column,
1010
CurriedFunctionCall,
11+
DangerousRawSQL,
1112
Expression,
1213
ExpressionVisitor,
1314
FunctionCall,
@@ -175,16 +176,18 @@ def visit_curried_function_call(self, exp: CurriedFunctionCall) -> str:
175176
if len(exp.parameters) == 1:
176177
raw_parameters += ","
177178
parameters = f", ({raw_parameters})"
178-
return (
179-
f"CurriedFunctionCall({repr(exp.alias)}, {internal_function}{parameters})"
180-
)
179+
return f"CurriedFunctionCall({repr(exp.alias)}, {internal_function}{parameters})"
181180

182181
def visit_argument(self, exp: Argument) -> str:
183182
return repr(exp)
184183

185184
def visit_lambda(self, exp: Lambda) -> str:
186185
return repr(exp)
187186

187+
def visit_dangerous_raw_sql(self, exp: DangerousRawSQL) -> str:
    """
    Render a DangerousRawSQL node as a DSL constructor-style string.

    The alias is emitted first and the raw SQL string second, mirroring how
    the sibling visitors render their nodes, e.g.
    ``DangerousRawSQL(None, 'SELECT 1')``. A falsy alias (``None`` or an
    empty string) is rendered as ``None``.
    """
    # The ", " separator belongs only in the final f-string. Previously the
    # alias fragment carried its own leading ", " as well, which produced the
    # malformed "DangerousRawSQL(, None, ...)".
    alias_str = repr(exp.alias) if exp.alias else "None"
    return f"DangerousRawSQL({alias_str}, {repr(exp.sql)})"
190+
188191
def visit_selected_expression(self, exp: SelectedExpression) -> str:
189192
return f"SelectedExpression({repr(exp.name)}, {exp.expression.accept(self)})"
190193

@@ -197,12 +200,7 @@ def visit_limitby(self, exp: LimitBy) -> str:
197200

198201

199202
def ast_repr(
200-
exp: (
201-
Expression
202-
| LimitBy
203-
| Sequence[Expression | SelectedExpression | OrderBy]
204-
| None
205-
),
203+
exp: Expression | LimitBy | Sequence[Expression | SelectedExpression | OrderBy] | None,
206204
visitor: DSLMapperVisitor,
207205
) -> str:
208206
if not exp:

snuba/query/expressions.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,10 @@ def visit_argument(self, exp: Argument) -> TVisited:
128128
def visit_lambda(self, exp: Lambda) -> TVisited:
129129
raise NotImplementedError
130130

131+
@abstractmethod
132+
def visit_dangerous_raw_sql(self, exp: DangerousRawSQL) -> TVisited:
133+
raise NotImplementedError
134+
131135

132136
class NoopVisitor(ExpressionVisitor[None]):
133137
"""A noop visitor that will traverse every node but will not
@@ -159,6 +163,9 @@ def visit_argument(self, exp: Argument) -> None:
159163
def visit_lambda(self, exp: Lambda) -> None:
160164
return exp.transformation.accept(self)
161165

166+
def visit_dangerous_raw_sql(self, exp: DangerousRawSQL) -> None:
167+
return None
168+
162169

163170
class StringifyVisitor(ExpressionVisitor[str]):
164171
"""Visitor implementation to turn an expression into a string format
@@ -204,9 +211,7 @@ def visit_literal(self, exp: Literal) -> str:
204211

205212
def visit_column(self, exp: Column) -> str:
206213
column_str = (
207-
f"{exp.table_name}.{exp.column_name}"
208-
if exp.table_name
209-
else f"{exp.column_name}"
214+
f"{exp.table_name}.{exp.column_name}" if exp.table_name else f"{exp.column_name}"
210215
)
211216
return f"{self._get_line_prefix()}{column_str}{self._get_alias_str(exp)}"
212217

@@ -256,6 +261,10 @@ def visit_lambda(self, exp: Lambda) -> str:
256261
self.__level -= 1
257262
return f"{self._get_line_prefix()}({params_str}) ->\n{transformation_str}\n{self._get_line_prefix()}{self._get_alias_str(exp)}"
258263

264+
def visit_dangerous_raw_sql(self, exp: DangerousRawSQL) -> str:
265+
sql_repr = repr(exp.sql)
266+
return f"{self._get_line_prefix()}DangerousRawSQL({sql_repr}){self._get_alias_str(exp)}"
267+
259268

260269
class ColumnVisitor(ExpressionVisitor[set[str]]):
261270
def __init__(self) -> None:
@@ -287,6 +296,9 @@ def visit_argument(self, exp: Argument) -> set[str]:
287296
def visit_lambda(self, exp: Lambda) -> set[str]:
288297
return exp.transformation.accept(self)
289298

299+
def visit_dangerous_raw_sql(self, exp: DangerousRawSQL) -> set[str]:
300+
return self.columns
301+
290302

291303
OptionalScalarType = Union[None, bool, str, float, int, date, datetime]
292304

@@ -335,10 +347,7 @@ def accept(self, visitor: ExpressionVisitor[TVisited]) -> TVisited:
335347
def functional_eq(self, other: Expression) -> bool:
336348
if not isinstance(other, self.__class__):
337349
return False
338-
return (
339-
self.table_name == other.table_name
340-
and self.column_name == other.column_name
341-
)
350+
return self.table_name == other.table_name and self.column_name == other.column_name
342351

343352

344353
@dataclass(frozen=True, repr=_AUTO_REPR)
@@ -382,9 +391,7 @@ def __iter__(self) -> Iterator[Expression]:
382391
def functional_eq(self, other: Expression) -> bool:
383392
if not isinstance(other, self.__class__):
384393
return False
385-
return self.column.functional_eq(other.column) and self.key.functional_eq(
386-
other.key
387-
)
394+
return self.column.functional_eq(other.column) and self.key.functional_eq(other.key)
388395

389396

390397
@dataclass(frozen=True, repr=_AUTO_REPR)
@@ -572,3 +579,31 @@ def functional_eq(self, other: Expression) -> bool:
572579
if not self.transformation.functional_eq(other.transformation):
573580
return False
574581
return True
582+
583+
584+
@dataclass(frozen=True, repr=_AUTO_REPR)
585+
class DangerousRawSQL(Expression):
586+
"""
587+
Represents raw SQL that should be passed through directly to ClickHouse
588+
without any escaping or validation. This is intended for query optimization
589+
scenarios where the SQL is generated programmatically and already safe.
590+
591+
WARNING: This expression type bypasses all safety checks. Only use when
592+
the SQL content is guaranteed to be safe and properly formatted.
593+
"""
594+
595+
sql: str
596+
597+
def transform(self, func: Callable[[Expression], Expression]) -> Expression:
598+
return func(self)
599+
600+
def __iter__(self) -> Iterator[Expression]:
601+
yield self
602+
603+
def accept(self, visitor: ExpressionVisitor[TVisited]) -> TVisited:
604+
return visitor.visit_dangerous_raw_sql(self)
605+
606+
def functional_eq(self, other: Expression) -> bool:
607+
if not isinstance(other, self.__class__):
608+
return False
609+
return self.sql == other.sql

snuba/query/joins/classifier.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Argument,
1919
Column,
2020
CurriedFunctionCall,
21+
DangerousRawSQL,
2122
Expression,
2223
ExpressionVisitor,
2324
FunctionCall,
@@ -127,9 +128,7 @@ class UnclassifiedExpression(SubExpression):
127128
"""
128129

129130
def cut_branch(self, alias_generator: AliasGenerator) -> MainQueryExpression:
130-
return MainQueryExpression(
131-
main_expression=self.main_expression, cut_branches={}
132-
)
131+
return MainQueryExpression(main_expression=self.main_expression, cut_branches={})
133132

134133

135134
def _merge_subexpressions(
@@ -158,9 +157,7 @@ def _merge_subexpressions(
158157
if not subqueries:
159158
# All parameters are not classified. This function is also
160159
# not classified.
161-
return UnclassifiedExpression(
162-
builder([v.main_expression for v in sub_expressions])
163-
)
160+
return UnclassifiedExpression(builder([v.main_expression for v in sub_expressions]))
164161
else:
165162
# All parameters are either not classified or in a single
166163
# subquery. This function is also referencing that subquery
@@ -228,16 +225,10 @@ def visit_literal(self, exp: Literal) -> SubExpression:
228225
return UnclassifiedExpression(exp)
229226

230227
def visit_column(self, exp: Column) -> SubExpression:
231-
assert (
232-
exp.table_name
233-
), f"Invalid column expression in join: {exp}. Missing table alias"
234-
return SubqueryExpression(
235-
Column(exp.alias, None, exp.column_name), exp.table_name
236-
)
228+
assert exp.table_name, f"Invalid column expression in join: {exp}. Missing table alias"
229+
return SubqueryExpression(Column(exp.alias, None, exp.column_name), exp.table_name)
237230

238-
def visit_subscriptable_reference(
239-
self, exp: SubscriptableReference
240-
) -> SubExpression:
231+
def visit_subscriptable_reference(self, exp: SubscriptableReference) -> SubExpression:
241232
assert (
242233
exp.column.table_name
243234
), f"Invalid column expression in join: {exp}. Missing table alias"
@@ -303,11 +294,12 @@ def visit_lambda(self, exp: Lambda) -> SubExpression:
303294
transformed = exp.transformation.accept(self)
304295
return replace(
305296
transformed,
306-
main_expression=Lambda(
307-
exp.alias, exp.parameters, transformed.main_expression
308-
),
297+
main_expression=Lambda(exp.alias, exp.parameters, transformed.main_expression),
309298
)
310299

300+
def visit_dangerous_raw_sql(self, exp: DangerousRawSQL) -> SubExpression:
301+
return UnclassifiedExpression(exp)
302+
311303

312304
class AggregateBranchCutter(BranchCutter):
313305
"""

snuba/query/parser/__init__.py

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
Argument,
99
Column,
1010
CurriedFunctionCall,
11+
DangerousRawSQL,
1112
Expression,
1213
ExpressionVisitor,
1314
FunctionCall,
@@ -36,10 +37,7 @@ def validate_aliases(query: Union[CompositeQuery[LogicalDataSource], Query]) ->
3637
# happening.
3738
metrics.increment("empty_alias")
3839

39-
if (
40-
exp.alias in all_declared_aliases
41-
and exp != all_declared_aliases[exp.alias]
42-
):
40+
if exp.alias in all_declared_aliases and exp != all_declared_aliases[exp.alias]:
4341
raise AliasShadowingException(
4442
(
4543
f"Shadowing aliases detected for alias: {exp.alias}. "
@@ -51,9 +49,7 @@ def validate_aliases(query: Union[CompositeQuery[LogicalDataSource], Query]) ->
5149
all_declared_aliases[exp.alias] = exp
5250

5351

54-
def parse_subscriptables(
55-
query: Union[CompositeQuery[LogicalDataSource], Query]
56-
) -> None:
52+
def parse_subscriptables(query: Union[CompositeQuery[LogicalDataSource], Query]) -> None:
5753
"""
5854
Turns columns formatted as tags[asd] into SubscriptableReference.
5955
"""
@@ -77,9 +73,7 @@ def transform(exp: Expression) -> Expression:
7773
query.transform_expressions(transform)
7874

7975

80-
def apply_column_aliases(
81-
query: Union[CompositeQuery[LogicalDataSource], Query]
82-
) -> None:
76+
def apply_column_aliases(query: Union[CompositeQuery[LogicalDataSource], Query]) -> None:
8377
"""
8478
Applies an alias to all the columns in the query equal to the column
8579
name unless a column already has one or the alias is already defined.
@@ -92,11 +86,7 @@ def apply_column_aliases(
9286
current_aliases = {exp.alias for exp in query.get_all_expressions() if exp.alias}
9387

9488
def apply_aliases(exp: Expression) -> Expression:
95-
if (
96-
not isinstance(exp, Column)
97-
or exp.alias
98-
or exp.column_name in current_aliases
99-
):
89+
if not isinstance(exp, Column) or exp.alias or exp.column_name in current_aliases:
10090
return exp
10191
else:
10292
return replace(exp, alias=exp.column_name)
@@ -119,9 +109,7 @@ def expand_aliases(query: Union[CompositeQuery[LogicalDataSource], Query]) -> No
119109
exp.alias: exp for exp in query.get_all_expressions() if exp.alias is not None
120110
}
121111
fully_resolved_aliases = {
122-
alias: exp.accept(
123-
AliasExpanderVisitor(aliased_expressions, [], expand_nested=True)
124-
)
112+
alias: exp.accept(AliasExpanderVisitor(aliased_expressions, [], expand_nested=True))
125113
for alias, exp in aliased_expressions.items()
126114
}
127115

@@ -194,11 +182,7 @@ def visit_column(self, exp: Column) -> Expression:
194182
return self.__alias_lookup_table[name]
195183

196184
def __append_alias(self, alias: Optional[str]) -> Sequence[str]:
197-
return (
198-
[*self.__visited_stack, alias]
199-
if alias is not None
200-
else self.__visited_stack
201-
)
185+
return [*self.__visited_stack, alias] if alias is not None else self.__visited_stack
202186

203187
def visit_subscriptable_reference(self, exp: SubscriptableReference) -> Expression:
204188
expanded_column = exp.column.accept(
@@ -267,3 +251,6 @@ def visit_lambda(self, exp: Lambda) -> Expression:
267251
)
268252
),
269253
)
254+
255+
def visit_dangerous_raw_sql(self, exp: DangerousRawSQL) -> Expression:
256+
return exp

0 commit comments

Comments
 (0)