Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions snuba/clickhouse/formatter/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Argument,
Column,
CurriedFunctionCall,
DangerousRawSQL,
Expression,
ExpressionVisitor,
FunctionCall,
Expand Down Expand Up @@ -179,6 +180,14 @@ def visit_lambda(self, exp: Lambda) -> str:
ret = f"{', '.join(parameters)} -> {exp.transformation.accept(self)}"
return self._alias(ret, exp.alias)

def visit_dangerous_raw_sql(self, exp: DangerousRawSQL) -> str:
    """
    Render a DangerousRawSQL node by emitting its SQL text verbatim.

    The ``sql`` string is neither escaped nor validated here; it is only
    handed to ``self._alias`` together with the node's alias. This is
    deliberate: the node is meant to carry SQL that was validated before the
    node was built (query optimization scenarios).
    """
    raw_sql = exp.sql
    return self._alias(raw_sql, exp.alias)


class ClickhouseExpressionFormatter(ExpressionFormatterBase):
"""
Expand Down
6 changes: 6 additions & 0 deletions snuba/clickhouse/translators/snuba/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Argument,
Column,
CurriedFunctionCall,
DangerousRawSQL,
FunctionCall,
Lambda,
Literal,
Expand Down Expand Up @@ -176,6 +177,11 @@ def visit_lambda(self, exp: Lambda) -> Expression:
self.__cache[exp] = ret
return ret

def visit_dangerous_raw_sql(self, exp: DangerousRawSQL) -> Expression:
    """
    Translate a DangerousRawSQL node by returning it as-is.

    The node carries pre-formatted SQL that must not be modified, so the
    very same object is handed back to the caller.
    """
    return exp

def translate_function_strict(self, exp: FunctionCall) -> FunctionCall:
"""
Unfortunately it is not possible to avoid this assertion.
Expand Down
16 changes: 7 additions & 9 deletions snuba/query/dsl_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Argument,
Column,
CurriedFunctionCall,
DangerousRawSQL,
Expression,
ExpressionVisitor,
FunctionCall,
Expand Down Expand Up @@ -175,16 +176,18 @@ def visit_curried_function_call(self, exp: CurriedFunctionCall) -> str:
if len(exp.parameters) == 1:
raw_parameters += ","
parameters = f", ({raw_parameters})"
return (
f"CurriedFunctionCall({repr(exp.alias)}, {internal_function}{parameters})"
)
return f"CurriedFunctionCall({repr(exp.alias)}, {internal_function}{parameters})"

def visit_argument(self, exp: Argument) -> str:
    """Return the ``repr`` of the Argument node as its DSL representation."""
    return f"{exp!r}"

def visit_lambda(self, exp: Lambda) -> str:
    """Return the ``repr`` of the Lambda node as its DSL representation."""
    return f"{exp!r}"

def visit_dangerous_raw_sql(self, exp: DangerousRawSQL) -> str:
    """
    Emit the DSL constructor call that rebuilds a DangerousRawSQL node.

    The output has the shape ``DangerousRawSQL(<alias>, <sql>)`` where both
    arguments are rendered with ``repr``, so an absent alias appears as
    ``None`` and the SQL text appears as a quoted string literal.
    """
    # The alias is the first constructor argument and must not be preceded
    # by a separator; a leading ", " would produce the syntactically invalid
    # "DangerousRawSQL(, None, ...)".
    return f"DangerousRawSQL({repr(exp.alias)}, {repr(exp.sql)})"

def visit_selected_expression(self, exp: SelectedExpression) -> str:
    """
    Emit the DSL constructor call for a SelectedExpression.

    The name is rendered with ``repr``; the wrapped expression is rendered
    by letting it accept this visitor.
    """
    name_repr = repr(exp.name)
    expression_repr = exp.expression.accept(self)
    return f"SelectedExpression({name_repr}, {expression_repr})"

Expand All @@ -197,12 +200,7 @@ def visit_limitby(self, exp: LimitBy) -> str:


def ast_repr(
exp: (
Expression
| LimitBy
| Sequence[Expression | SelectedExpression | OrderBy]
| None
),
exp: Expression | LimitBy | Sequence[Expression | SelectedExpression | OrderBy] | None,
visitor: DSLMapperVisitor,
) -> str:
if not exp:
Expand Down
55 changes: 45 additions & 10 deletions snuba/query/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ def visit_argument(self, exp: Argument) -> TVisited:
def visit_lambda(self, exp: Lambda) -> TVisited:
raise NotImplementedError

@abstractmethod
def visit_dangerous_raw_sql(self, exp: DangerousRawSQL) -> TVisited:
    """Visit a DangerousRawSQL node; concrete visitors must override this."""
    raise NotImplementedError()


class NoopVisitor(ExpressionVisitor[None]):
"""A noop visitor that will traverse every node but will not
Expand Down Expand Up @@ -159,6 +163,9 @@ def visit_argument(self, exp: Argument) -> None:
def visit_lambda(self, exp: Lambda) -> None:
    """Traverse into the lambda's transformation with this same visitor."""
    body = exp.transformation
    return body.accept(self)

def visit_dangerous_raw_sql(self, exp: DangerousRawSQL) -> None:
    """Do nothing for a DangerousRawSQL node; it is not traversed further."""
    return None


class StringifyVisitor(ExpressionVisitor[str]):
"""Visitor implementation to turn an expression into a string format
Expand Down Expand Up @@ -204,9 +211,7 @@ def visit_literal(self, exp: Literal) -> str:

def visit_column(self, exp: Column) -> str:
column_str = (
f"{exp.table_name}.{exp.column_name}"
if exp.table_name
else f"{exp.column_name}"
f"{exp.table_name}.{exp.column_name}" if exp.table_name else f"{exp.column_name}"
)
return f"{self._get_line_prefix()}{column_str}{self._get_alias_str(exp)}"

Expand Down Expand Up @@ -256,6 +261,10 @@ def visit_lambda(self, exp: Lambda) -> str:
self.__level -= 1
return f"{self._get_line_prefix()}({params_str}) ->\n{transformation_str}\n{self._get_line_prefix()}{self._get_alias_str(exp)}"

def visit_dangerous_raw_sql(self, exp: DangerousRawSQL) -> str:
    """
    Stringify a DangerousRawSQL node on a single line.

    The result is the current line prefix, then ``DangerousRawSQL(...)``
    wrapping the ``repr`` of the SQL text, then the alias suffix produced by
    ``self._get_alias_str``.
    """
    quoted_sql = repr(exp.sql)
    prefix = self._get_line_prefix()
    alias_suffix = self._get_alias_str(exp)
    return f"{prefix}DangerousRawSQL({quoted_sql}){alias_suffix}"


class ColumnVisitor(ExpressionVisitor[set[str]]):
def __init__(self) -> None:
Expand Down Expand Up @@ -287,6 +296,9 @@ def visit_argument(self, exp: Argument) -> set[str]:
def visit_lambda(self, exp: Lambda) -> set[str]:
    """Return whatever visiting the lambda's transformation yields."""
    body = exp.transformation
    return body.accept(self)

def visit_dangerous_raw_sql(self, exp: DangerousRawSQL) -> set[str]:
    """
    Return the column set collected so far, without adding anything.

    NOTE(review): the raw SQL text is not inspected, so any column it
    references is not reported by this visitor -- confirm callers do not
    depend on that.
    """
    collected = self.columns
    return collected


OptionalScalarType = Union[None, bool, str, float, int, date, datetime]

Expand Down Expand Up @@ -335,10 +347,7 @@ def accept(self, visitor: ExpressionVisitor[TVisited]) -> TVisited:
def functional_eq(self, other: Expression) -> bool:
if not isinstance(other, self.__class__):
return False
return (
self.table_name == other.table_name
and self.column_name == other.column_name
)
return self.table_name == other.table_name and self.column_name == other.column_name


@dataclass(frozen=True, repr=_AUTO_REPR)
Expand Down Expand Up @@ -382,9 +391,7 @@ def __iter__(self) -> Iterator[Expression]:
def functional_eq(self, other: Expression) -> bool:
if not isinstance(other, self.__class__):
return False
return self.column.functional_eq(other.column) and self.key.functional_eq(
other.key
)
return self.column.functional_eq(other.column) and self.key.functional_eq(other.key)


@dataclass(frozen=True, repr=_AUTO_REPR)
Expand Down Expand Up @@ -572,3 +579,31 @@ def functional_eq(self, other: Expression) -> bool:
if not self.transformation.functional_eq(other.transformation):
return False
return True


@dataclass(frozen=True, repr=_AUTO_REPR)
class DangerousRawSQL(Expression):
"""
Represents raw SQL that should be passed through directly to ClickHouse
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RawSQL / DangerousSQL / UnescapedSQL might be better to make usage of this stick out in bad situations

without any escaping or validation. This is intended for query optimization
scenarios where the SQL is generated programmatically and already safe.

WARNING: This expression type bypasses all safety checks. Only use when
the SQL content is guaranteed to be safe and properly formatted.
"""

sql: str

def transform(self, func: Callable[[Expression], Expression]) -> Expression:
    """Apply ``func`` to this node and return its result."""
    transformed = func(self)
    return transformed

def __iter__(self) -> Iterator[Expression]:
    """Iterate over this expression, which yields only the node itself."""
    yield from (self,)

def accept(self, visitor: ExpressionVisitor[TVisited]) -> TVisited:
    """Dispatch to the visitor's ``visit_dangerous_raw_sql`` with this node."""
    handler = visitor.visit_dangerous_raw_sql
    return handler(self)

def functional_eq(self, other: Expression) -> bool:
    """
    Report whether ``other`` is the same kind of node with identical SQL.

    Only the ``sql`` text is compared; a node of a different class is never
    functionally equal.
    """
    return isinstance(other, self.__class__) and self.sql == other.sql
28 changes: 10 additions & 18 deletions snuba/query/joins/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Argument,
Column,
CurriedFunctionCall,
DangerousRawSQL,
Expression,
ExpressionVisitor,
FunctionCall,
Expand Down Expand Up @@ -127,9 +128,7 @@ class UnclassifiedExpression(SubExpression):
"""

def cut_branch(self, alias_generator: AliasGenerator) -> MainQueryExpression:
return MainQueryExpression(
main_expression=self.main_expression, cut_branches={}
)
return MainQueryExpression(main_expression=self.main_expression, cut_branches={})


def _merge_subexpressions(
Expand Down Expand Up @@ -158,9 +157,7 @@ def _merge_subexpressions(
if not subqueries:
# All parameters are not classified. This function is also
# not classified.
return UnclassifiedExpression(
builder([v.main_expression for v in sub_expressions])
)
return UnclassifiedExpression(builder([v.main_expression for v in sub_expressions]))
else:
# All parameters are either not classified or in a single
# subquery. This function is also referencing that subquery
Expand Down Expand Up @@ -228,16 +225,10 @@ def visit_literal(self, exp: Literal) -> SubExpression:
return UnclassifiedExpression(exp)

def visit_column(self, exp: Column) -> SubExpression:
assert (
exp.table_name
), f"Invalid column expression in join: {exp}. Missing table alias"
return SubqueryExpression(
Column(exp.alias, None, exp.column_name), exp.table_name
)
assert exp.table_name, f"Invalid column expression in join: {exp}. Missing table alias"
return SubqueryExpression(Column(exp.alias, None, exp.column_name), exp.table_name)

def visit_subscriptable_reference(
self, exp: SubscriptableReference
) -> SubExpression:
def visit_subscriptable_reference(self, exp: SubscriptableReference) -> SubExpression:
assert (
exp.column.table_name
), f"Invalid column expression in join: {exp}. Missing table alias"
Expand Down Expand Up @@ -303,11 +294,12 @@ def visit_lambda(self, exp: Lambda) -> SubExpression:
transformed = exp.transformation.accept(self)
return replace(
transformed,
main_expression=Lambda(
exp.alias, exp.parameters, transformed.main_expression
),
main_expression=Lambda(exp.alias, exp.parameters, transformed.main_expression),
)

def visit_dangerous_raw_sql(self, exp: DangerousRawSQL) -> SubExpression:
    """Wrap the raw SQL node, unchanged, in an UnclassifiedExpression."""
    unclassified = UnclassifiedExpression(exp)
    return unclassified


class AggregateBranchCutter(BranchCutter):
"""
Expand Down
33 changes: 10 additions & 23 deletions snuba/query/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Argument,
Column,
CurriedFunctionCall,
DangerousRawSQL,
Expression,
ExpressionVisitor,
FunctionCall,
Expand Down Expand Up @@ -36,10 +37,7 @@ def validate_aliases(query: Union[CompositeQuery[LogicalDataSource], Query]) ->
# happening.
metrics.increment("empty_alias")

if (
exp.alias in all_declared_aliases
and exp != all_declared_aliases[exp.alias]
):
if exp.alias in all_declared_aliases and exp != all_declared_aliases[exp.alias]:
raise AliasShadowingException(
(
f"Shadowing aliases detected for alias: {exp.alias}. "
Expand All @@ -51,9 +49,7 @@ def validate_aliases(query: Union[CompositeQuery[LogicalDataSource], Query]) ->
all_declared_aliases[exp.alias] = exp


def parse_subscriptables(
query: Union[CompositeQuery[LogicalDataSource], Query]
) -> None:
def parse_subscriptables(query: Union[CompositeQuery[LogicalDataSource], Query]) -> None:
"""
Turns columns formatted as tags[asd] into SubscriptableReference.
"""
Expand All @@ -77,9 +73,7 @@ def transform(exp: Expression) -> Expression:
query.transform_expressions(transform)


def apply_column_aliases(
query: Union[CompositeQuery[LogicalDataSource], Query]
) -> None:
def apply_column_aliases(query: Union[CompositeQuery[LogicalDataSource], Query]) -> None:
"""
Applies an alias to all the columns in the query equal to the column
name unless a column already has one or the alias is already defined.
Expand All @@ -92,11 +86,7 @@ def apply_column_aliases(
current_aliases = {exp.alias for exp in query.get_all_expressions() if exp.alias}

def apply_aliases(exp: Expression) -> Expression:
if (
not isinstance(exp, Column)
or exp.alias
or exp.column_name in current_aliases
):
if not isinstance(exp, Column) or exp.alias or exp.column_name in current_aliases:
return exp
else:
return replace(exp, alias=exp.column_name)
Expand All @@ -119,9 +109,7 @@ def expand_aliases(query: Union[CompositeQuery[LogicalDataSource], Query]) -> No
exp.alias: exp for exp in query.get_all_expressions() if exp.alias is not None
}
fully_resolved_aliases = {
alias: exp.accept(
AliasExpanderVisitor(aliased_expressions, [], expand_nested=True)
)
alias: exp.accept(AliasExpanderVisitor(aliased_expressions, [], expand_nested=True))
for alias, exp in aliased_expressions.items()
}

Expand Down Expand Up @@ -194,11 +182,7 @@ def visit_column(self, exp: Column) -> Expression:
return self.__alias_lookup_table[name]

def __append_alias(self, alias: Optional[str]) -> Sequence[str]:
return (
[*self.__visited_stack, alias]
if alias is not None
else self.__visited_stack
)
return [*self.__visited_stack, alias] if alias is not None else self.__visited_stack

def visit_subscriptable_reference(self, exp: SubscriptableReference) -> Expression:
expanded_column = exp.column.accept(
Expand Down Expand Up @@ -267,3 +251,6 @@ def visit_lambda(self, exp: Lambda) -> Expression:
)
),
)

def visit_dangerous_raw_sql(self, exp: DangerousRawSQL) -> Expression:
    """Return the raw SQL node as-is; nothing inside it is expanded."""
    return exp
Loading
Loading