diff --git a/.gitignore b/.gitignore index 64d547b..8f035c1 100644 --- a/.gitignore +++ b/.gitignore @@ -133,6 +133,7 @@ celerybeat.pid # Environments .env +.env.* .venv env/ venv/ diff --git a/deepnote_toolkit/sql/jinjasql_utils.py b/deepnote_toolkit/sql/jinjasql_utils.py index 7879f05..df34521 100644 --- a/deepnote_toolkit/sql/jinjasql_utils.py +++ b/deepnote_toolkit/sql/jinjasql_utils.py @@ -14,6 +14,7 @@ def render_jinja_sql_template(template, param_style=None): Args: template (str): The Jinja SQL template to render. param_style (str, optional): The parameter style to use. Defaults to "pyformat". + Common styles: "qmark" (?), "format" (%s), "pyformat" (%(name)s) Returns: str: The rendered SQL query. @@ -21,6 +22,8 @@ def render_jinja_sql_template(template, param_style=None): escaped_template = _escape_jinja_template(template) + # Default to pyformat for backwards compatibility + # Note: Some databases like Trino require "qmark" or "format" style jinja_sql = JinjaSql( param_style=param_style if param_style is not None else "pyformat" ) diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 7f51e3e..940fded 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -150,6 +150,16 @@ class ExecuteSqlError(Exception): del sql_alchemy_dict["params"]["snowflake_private_key_passphrase"] param_style = sql_alchemy_dict.get("param_style") + + # Auto-detect param_style for databases that don't support pyformat default + if param_style is None: + url_obj = make_url(sql_alchemy_dict["url"]) + # Mapping of SQLAlchemy dialect names to their required param_style + dialect_param_styles = { + "trino": "qmark", # Trino requires ? placeholders with list/tuple params + } + param_style = dialect_param_styles.get(url_obj.drivername) + skip_template_render = re.search( "^snowflake.*host=.*.proxy.cloud.getdbt.com", sql_alchemy_dict["url"] ) @@ -425,10 +435,15 @@ def _execute_sql_on_engine(engine, query, bind_params): connection.connection if needs_raw_connection else connection ) + # pandas.read_sql_query expects params as tuple (not list) for qmark/format style + params_for_pandas = ( + tuple(bind_params) if isinstance(bind_params, list) else bind_params + ) + return pd.read_sql_query( query, con=connection_for_pandas, - params=bind_params, + params=params_for_pandas, coerce_float=coerce_float, ) except ResourceClosedError: diff --git a/poetry.lock b/poetry.lock index 4c79c52..eb7c788 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.0 and should not be changed by hand. [[package]] name = "alembic" @@ -4875,6 +4875,21 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "python-dotenv" +version = "1.2.1" +description = "Read key-value pairs from a .env file and set them as environment variables" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "python_dotenv-1.2.1-py3-none-any.whl", hash = "sha256:b81ee9561e9ca4004139c6cbba3a238c32b03e4894671e181b671e8cb8425d61"}, + {file = "python_dotenv-1.2.1.tar.gz", hash = "sha256:42667e897e16ab0d66954af0e60a9caa94f0fd4ecf3aaf6d2d260eec1aa36ad6"}, +] + +[package.extras] +cli = ["click (>=5.0)"] + [[package]] name = "python-json-logger" version = "2.0.7" diff --git a/pyproject.toml b/pyproject.toml index 82f743d..91c848a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -176,7 +176,8 @@ dev = [ "poetry-dynamic-versioning>=1.4.0,<2.0.0", "twine>=6.1.0,<7.0.0", "codespell>=2.3.0,<3.0.0", - "pytest-subtests>=0.15.0,<0.16.0" + "pytest-subtests>=0.15.0,<0.16.0", + "python-dotenv>=1.2.1,<2.0.0" ] license-check = [ # Dependencies needed for license checking that aren't in main production dependencies diff --git a/tests/integration/test_trino.py b/tests/integration/test_trino.py new file mode 100644 index 0000000..695d6d7 --- /dev/null +++ b/tests/integration/test_trino.py @@ -0,0 +1,227 @@ +import json +import os +from contextlib import contextmanager +from pathlib import Path +from unittest import mock +from urllib.parse import quote + +import pandas as pd +import pytest +from dotenv import load_dotenv +from trino import dbapi +from trino.auth import BasicAuthentication + +from deepnote_toolkit import env as dnenv +from deepnote_toolkit.sql.sql_execution import execute_sql + + +@contextmanager +def use_trino_sql_connection(connection_json, env_var_name="TEST_TRINO_CONNECTION"): + dnenv.set_env(env_var_name, connection_json) + try: + yield env_var_name + finally: + dnenv.unset_env(env_var_name) + + +@pytest.fixture(scope="module") +def trino_credentials(): + env_path = Path(__file__).parent.parent.parent / ".env" + + if env_path.exists(): + load_dotenv(env_path) + + host = os.getenv("TRINO_HOST") + port = os.getenv("TRINO_PORT", "8080") + user = os.getenv("TRINO_USER") + password = os.getenv("TRINO_PASSWORD") + catalog = os.getenv("TRINO_CATALOG", "system") + schema = os.getenv("TRINO_SCHEMA", "runtime") + http_scheme = os.getenv("TRINO_HTTP_SCHEME", "https") + + if not host or not user: + pytest.skip( + "Trino credentials not found. " + "Please set TRINO_HOST and TRINO_USER in .env file" + ) + + return { + "host": host, + "port": int(port), + "user": user, + "password": password, + "catalog": catalog, + "schema": schema, + "http_scheme": http_scheme, + } + + +@pytest.fixture(scope="module") +def trino_connection(trino_credentials): + auth = None + + if trino_credentials["password"]: + auth = BasicAuthentication( + trino_credentials["user"], trino_credentials["password"] + ) + + conn = dbapi.connect( + host=trino_credentials["host"], + port=trino_credentials["port"], + user=trino_credentials["user"], + auth=auth, + http_scheme=trino_credentials["http_scheme"], + catalog=trino_credentials["catalog"], + schema=trino_credentials["schema"], + ) + + try: + yield conn + finally: + conn.close() + + +class TestTrinoConnection: + """Test Trino database connection.""" + + def test_connection_established(self, trino_connection): + """Test that connection to Trino is established.""" + cursor = trino_connection.cursor() + cursor.execute("SELECT 1") + result = cursor.fetchone() + + assert result is not None + assert result[0] == 1 + + cursor.close() + + def test_show_catalogs(self, trino_connection): + """Test listing available catalogs.""" + cursor = trino_connection.cursor() + cursor.execute("SHOW CATALOGS") + catalogs = cursor.fetchall() + + assert len(catalogs) > 0 + assert any("system" in str(catalog) for catalog in catalogs) + + cursor.close() + + +@pytest.fixture +def trino_toolkit_connection(trino_credentials): + """Create a Trino connection JSON for deepnote toolkit.""" + username = quote(trino_credentials["user"], safe="") + password_part = ( + f":{quote(trino_credentials['password'], safe='')}" + if trino_credentials["password"] + else "" + ) + connection_url = ( + f"trino://{username}{password_part}" + f"@{trino_credentials['host']}:{trino_credentials['port']}" + f"/{trino_credentials['catalog']}/{trino_credentials['schema']}" + ) + + # Trino uses `qmark` paramstyle (`?` placeholders with list/tuple params), not pyformat, which is the default + connection_json = json.dumps( + { + "url": connection_url, + "params": {}, + "param_style": "qmark", + } + ) + + with use_trino_sql_connection(connection_json) as env_var_name: + yield env_var_name + + +class TestTrinoWithDeepnoteToolkit: + """Test Trino connection using Toolkit's SQL execution.""" + + def test_execute_sql_simple_query(self, trino_toolkit_connection): + result = execute_sql( + template="SELECT 1 as test_value", + sql_alchemy_json_env_var=trino_toolkit_connection, + ) + + assert isinstance(result, pd.DataFrame) + assert len(result) == 1 + assert "test_value" in result.columns + assert result["test_value"].iloc[0] == 1 + + def test_execute_sql_with_jinja_template(self, trino_toolkit_connection): + test_string = "test string" + test_number = 123 + + def mock_get_variable_value(variable_name): + variables = { + "test_string_var": test_string, + "test_number_var": test_number, + } + return variables[variable_name] + + with mock.patch( + "deepnote_toolkit.sql.jinjasql_utils._get_variable_value", + side_effect=mock_get_variable_value, + ): + result = execute_sql( + template="SELECT {{test_string_var}} as message, {{test_number_var}} as number", + sql_alchemy_json_env_var=trino_toolkit_connection, + ) + + assert isinstance(result, pd.DataFrame) + assert len(result) == 1 + assert "message" in result.columns + assert "number" in result.columns + assert result["message"].iloc[0] == test_string + assert result["number"].iloc[0] == test_number + + def test_execute_sql_with_autodetection(self, trino_credentials): + """ + Test execute_sql with auto-detection of param_style + (regression reported in BLU-5135) + + This simulates the real-world scenario where the backend provides a connection + JSON without explicit param_style, and Toolkit must auto-detect it. + """ + + username = quote(trino_credentials["user"], safe="") + password_part = ( + f":{quote(trino_credentials['password'], safe='')}" + if trino_credentials["password"] + else "" + ) + connection_url = ( + f"trino://{username}{password_part}" + f"@{trino_credentials['host']}:{trino_credentials['port']}" + f"/{trino_credentials['catalog']}/{trino_credentials['schema']}" + ) + + connection_json = json.dumps( + { + "url": connection_url, + "params": {}, + # NO param_style - should auto-detect to `qmark` for Trino + } + ) + + test_value = "test value" + + with ( + use_trino_sql_connection( + connection_json, "TEST_TRINO_AUTODETECT" + ) as env_var_name, + mock.patch( + "deepnote_toolkit.sql.jinjasql_utils._get_variable_value", + return_value=test_value, + ), + ): + result = execute_sql( + template="SELECT {{test_var}} as detected", + sql_alchemy_json_env_var=env_var_name, + ) + + assert isinstance(result, pd.DataFrame) + assert len(result) == 1 + assert "detected" in result.columns + assert result["detected"].iloc[0] == test_value diff --git a/tests/unit/test_sql_execution.py b/tests/unit/test_sql_execution.py index 3285613..f17543e 100644 --- a/tests/unit/test_sql_execution.py +++ b/tests/unit/test_sql_execution.py @@ -278,6 +278,236 @@ def test_execute_sql_with_connection_json_with_snowflake_encrypted_private_key( ) +class TestTrinoParamStyleAutoDetection(TestCase): + """Tests for auto-detection of param_style for Trino connections""" + + @mock.patch("deepnote_toolkit.sql.sql_execution.compile_sql_query") + @mock.patch("deepnote_toolkit.sql.sql_execution._query_data_source") + def test_trino_url_auto_detects_qmark_param_style( + self, mocked_query_data_source, mocked_compile_sql_query + ): + """Test that Trino URLs automatically get 'qmark' param_style when not specified""" + mock_df = pd.DataFrame({"col1": [1, 2, 3]}) + mocked_query_data_source.return_value = mock_df + mocked_compile_sql_query.return_value = ( + "SELECT * FROM test_table", + {}, + "SELECT * FROM test_table", + ) + + sql_alchemy_json = json.dumps( + { + "url": "trino://user@localhost:8080/catalog", + "params": {}, + "integration_id": "test_integration", + } + ) + + execute_sql_with_connection_json("SELECT * FROM test_table", sql_alchemy_json) + + # Verify compile_sql_query was called with 'qmark' param_style + mocked_compile_sql_query.assert_called_once() + call_args = mocked_compile_sql_query.call_args[0] + self.assertEqual(call_args[2], "qmark") + + @mock.patch("deepnote_toolkit.sql.sql_execution.compile_sql_query") + @mock.patch("deepnote_toolkit.sql.sql_execution._query_data_source") + def test_non_trino_url_param_style_remains_none( + self, mocked_query_data_source, mocked_compile_sql_query + ): + """Test that non-Trino databases don't get auto-detected param_style""" + mock_df = pd.DataFrame({"col1": [1, 2, 3]}) + mocked_query_data_source.return_value = mock_df + mocked_compile_sql_query.return_value = ( + "SELECT * FROM test_table", + {}, + "SELECT * FROM test_table", + ) + + sql_alchemy_json = json.dumps( + { + "url": "postgresql://user:pass@localhost:5432/mydb", + "params": {}, + "integration_id": "test_integration", + } + ) + + execute_sql_with_connection_json("SELECT * FROM test_table", sql_alchemy_json) + + # Verify compile_sql_query was called with None param_style + mocked_compile_sql_query.assert_called_once() + call_args = mocked_compile_sql_query.call_args[0] + self.assertIsNone(call_args[2]) + + @mock.patch("deepnote_toolkit.sql.sql_execution.compile_sql_query") + @mock.patch("deepnote_toolkit.sql.sql_execution._query_data_source") + def test_explicit_param_style_not_overridden( + self, mocked_query_data_source, mocked_compile_sql_query + ): + """Test that explicitly set param_style is preserved and not auto-detected""" + mock_df = pd.DataFrame({"col1": [1, 2, 3]}) + mocked_query_data_source.return_value = mock_df + mocked_compile_sql_query.return_value = ( + "SELECT * FROM test_table", + {}, + "SELECT * FROM test_table", + ) + + # Trino URL with explicit pyformat - should NOT be changed to qmark + sql_alchemy_json = json.dumps( + { + "url": "trino://user@localhost:8080/catalog", + "params": {}, + "param_style": "pyformat", + "integration_id": "test_integration", + } + ) + + execute_sql_with_connection_json("SELECT * FROM test_table", sql_alchemy_json) + + # Verify compile_sql_query was called with 'pyformat', NOT 'qmark' + mocked_compile_sql_query.assert_called_once() + call_args = mocked_compile_sql_query.call_args[0] + self.assertEqual(call_args[2], "pyformat") + + @mock.patch("deepnote_toolkit.sql.sql_execution.compile_sql_query") + @mock.patch("deepnote_toolkit.sql.sql_execution._query_data_source") + def test_trino_url_with_protocol_suffix_not_matched( + self, mocked_query_data_source, mocked_compile_sql_query + ): + """Test that Trino URL variants like trino+rest:// don't match (drivername must be exactly 'trino')""" + mock_df = pd.DataFrame({"col1": [1, 2, 3]}) + mocked_query_data_source.return_value = mock_df + mocked_compile_sql_query.return_value = ( + "SELECT * FROM test_table", + {}, + "SELECT * FROM test_table", + ) + + sql_alchemy_json = json.dumps( + { + "url": "trino+rest://user@localhost:8080/catalog", + "params": {}, + "integration_id": "test_integration", + } + ) + + execute_sql_with_connection_json("SELECT * FROM test_table", sql_alchemy_json) + + # Verify compile_sql_query was called with None param_style + # because "trino+rest" doesn't match "trino" in the dictionary + mocked_compile_sql_query.assert_called_once() + call_args = mocked_compile_sql_query.call_args[0] + self.assertIsNone(call_args[2]) + + @mock.patch("deepnote_toolkit.sql.sql_execution.render_jinja_sql_template") + @mock.patch("deepnote_toolkit.sql.sql_execution._query_data_source") + def test_trino_with_jinja_templates_uses_qmark( + self, mocked_query_data_source, mocked_render_jinja + ): + """Test that Trino queries with Jinja templates correctly use qmark style""" + mock_df = pd.DataFrame({"col1": [1, 2, 3]}) + mocked_query_data_source.return_value = mock_df + mocked_render_jinja.return_value = ( + "SELECT * FROM test_table WHERE id = ?", + [123], + ) + + sql_alchemy_json = json.dumps( + { + "url": "trino://user@localhost:8080/catalog", + "params": {}, + "integration_id": "test_integration", + } + ) + + execute_sql_with_connection_json( + "SELECT * FROM test_table WHERE id = {{ user_id }}", + sql_alchemy_json, + ) + + # Verify render_jinja_sql_template was called with 'qmark' param_style + mocked_render_jinja.assert_called() + call_args = mocked_render_jinja.call_args[0] + self.assertEqual(call_args[1], "qmark") + + # Verify bind_params is a list (qmark style) not dict (pyformat style) + call_args = mocked_query_data_source.call_args[0] + bind_params = call_args[1] + self.assertIsInstance(bind_params, list) + self.assertEqual(bind_params, [123]) + + @mock.patch("pandas.read_sql_query") + def test_list_bind_params_converted_to_tuple_for_pandas(self, mocked_read_sql): + """Test that list bind_params are converted to tuple for pandas.read_sql_query""" + from deepnote_toolkit.sql.sql_execution import _execute_sql_on_engine + + mock_df = pd.DataFrame({"col1": [1, 2, 3]}) + mocked_read_sql.return_value = mock_df + + # Mock engine and connection + mock_engine = mock.Mock() + mock_connection = mock.Mock() + mock_engine.begin.return_value.__enter__ = mock.Mock( + return_value=mock_connection + ) + mock_engine.begin.return_value.__exit__ = mock.Mock(return_value=None) + + # Test with list bind_params (qmark style for Trino) + list_params = [123, "test"] + _execute_sql_on_engine( + mock_engine, + "SELECT * FROM test_table WHERE id = ? AND name = ?", + list_params, + ) + + # Verify pandas.read_sql_query was called + self.assertTrue(mocked_read_sql.called) + + # Get the params argument passed to pandas.read_sql_query + call_kwargs = mocked_read_sql.call_args[1] + params_arg = call_kwargs.get("params") + + # Verify that list was converted to tuple + self.assertIsInstance(params_arg, tuple) + self.assertEqual(params_arg, (123, "test")) + + @mock.patch("pandas.read_sql_query") + def test_dict_bind_params_not_converted_for_pandas(self, mocked_read_sql): + """Test that dict bind_params remain as dict for pandas.read_sql_query""" + from deepnote_toolkit.sql.sql_execution import _execute_sql_on_engine + + mock_df = pd.DataFrame({"col1": [1, 2, 3]}) + mocked_read_sql.return_value = mock_df + + # Mock engine and connection + mock_engine = mock.Mock() + mock_connection = mock.Mock() + mock_engine.begin.return_value.__enter__ = mock.Mock( + return_value=mock_connection + ) + mock_engine.begin.return_value.__exit__ = mock.Mock(return_value=None) + + # Test with dict bind_params (pyformat style) + dict_params = {"id": 123, "name": "test"} + _execute_sql_on_engine( + mock_engine, + "SELECT * FROM test_table WHERE id = %(id)s AND name = %(name)s", + dict_params, + ) + + # Verify pandas.read_sql_query was called + self.assertTrue(mocked_read_sql.called) + + # Get the params argument passed to pandas.read_sql_query + call_kwargs = mocked_read_sql.call_args[1] + params_arg = call_kwargs.get("params") + + # Verify that dict was NOT converted (remains as dict) + self.assertIsInstance(params_arg, dict) + self.assertEqual(params_arg, {"id": 123, "name": "test"}) + + class TestSanitizeDataframe(unittest.TestCase): @parameterized.expand([(key, df) for key, df in testing_dataframes.items()]) def test_all_dataframes_serialize_to_parquet(self, key, df):