Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions deepnote_toolkit/sql/jinjasql_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ def render_jinja_sql_template(template, param_style=None):
str: The rendered SQL query.
"""

escaped_template = _escape_jinja_template(template)

# Default to pyformat for backwards compatibility
# Note: Some databases like Trino require "qmark" or "format" style
jinja_sql = JinjaSql(
param_style=param_style if param_style is not None else "pyformat"
)
effective_param_style = param_style if param_style is not None else "pyformat"

escaped_template = _escape_jinja_template(template, effective_param_style)

jinja_sql = JinjaSql(param_style=effective_param_style)
parsed_content = jinja_sql.env.parse(escaped_template)
required_variables = meta.find_undeclared_variables(parsed_content)
jinja_sql_data = {
Expand All @@ -40,9 +40,14 @@ def _get_variable_value(variable_name):
return getattr(__main__, variable_name)


def _escape_jinja_template(template):
def _escape_jinja_template(template, param_style: str = "pyformat"):
# see https://github.com/sripathikrishnan/jinjasql/issues/28 and https://stackoverflow.com/q/8657508/2761695
# we have to replace % by %% in the SQL query due to how SQL alchemy interprets %
# but only if the { is not preceded by { or followed by }, because those are jinja blocks
# we use lookbehind ?<= and lookahead ?= regex matchers to capture the { and } symbols
return re.sub(r"(?<=[^{])%(?=[^}])", "%%", template)
# but ONLY for param styles that use % (format and pyformat)
# For other param styles (qmark), % has no special meaning
# and should not be escaped (e.g., in date format strings like '%m-%d-%Y')
if param_style in ("format", "pyformat"):
# Only escape % if it's not part of a jinja block (not preceded by { or followed by })
# we use lookbehind ?<= and lookahead ?= regex matchers to capture the { and } symbols
return re.sub(r"(?<=[^{])%(?=[^}])", "%%", template)
return template
9 changes: 5 additions & 4 deletions deepnote_toolkit/sql/sql_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ class ExecuteSqlError(Exception):
}
param_style = dialect_param_styles.get(url_obj.drivername)

if requires_duckdb and param_style is None:
# DuckDB uses the DB-API qmark style (`?` placeholders)
param_style = "qmark"

skip_template_render = re.search(
"^snowflake.*host=.*.proxy.cloud.getdbt.com", sql_alchemy_dict["url"]
)
Expand Down Expand Up @@ -294,10 +298,7 @@ def _execute_sql_with_caching(
):
# duckdb SQL is not cached, so we can skip the logic below for duckdb
if requires_duckdb:
# duckdb requires % to be unescaped, but other dialects require it to be escaped as %%
# https://docs.sqlalchemy.org/en/14/faq/sqlexpressions.html#why-are-percent-signs-being-doubled-up-when-stringifying-sql-statements
query_unescaped = query % () if query else query
dataframe = execute_duckdb_sql(query_unescaped, bind_params)
dataframe = execute_duckdb_sql(query, bind_params)
# for Chained SQL we return the dataframe with the SQL source attached as DeepnoteQueryPreview object
if return_variable_type == "query_preview":
return _convert_dataframe_to_query_preview(dataframe, query_preview_source)
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/test_jinjasql_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,32 @@ def test_qmark_format(self):
self.assertEqual(query.strip(), "SELECT * FROM users WHERE id = ?")
self.assertEqual(bind_params, ["test"])

def test_qmark_escaping(self):
    # With the qmark style, literal % tokens (here a date format string)
    # are expected to come back exactly as written, with no bind params.
    sql = "SELECT date_format(TIMESTAMP '2022-10-20 05:10:00', '%m-%d-%Y %H')"

    rendered = render_jinja_sql_template(sql, param_style="qmark")
    rendered_query, rendered_params = rendered

    self.assertEqual(rendered_query, sql)
    self.assertEqual(rendered_params, [])

def test_pyformat_escaping(self):
    # With the pyformat style, a literal % is expected to be doubled and
    # the bind params to be an empty dict.
    sql = "SELECT '% character'"

    rendered = render_jinja_sql_template(sql, param_style="pyformat")
    rendered_query, rendered_params = rendered

    self.assertEqual(rendered_query, "SELECT '%% character'")
    self.assertEqual(rendered_params, {})

def test_format_escaping(self):
    # With the format style, a literal % is expected to be doubled and
    # the bind params to be an empty list.
    sql = "SELECT '% character'"

    rendered = render_jinja_sql_template(sql, param_style="format")
    rendered_query, rendered_params = rendered

    self.assertEqual(rendered_query, "SELECT '%% character'")
    self.assertEqual(rendered_params, [])


# Run the unittest runner when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()