Skip to content

Commit 5065a2f

Browse files
committed
add unit tests
1 parent 0de8945 commit 5065a2f

File tree

4 files changed

+224
-10
lines changed

4 files changed

+224
-10
lines changed

libs/foundry-dev-tools/src/foundry_dev_tools/clients/foundry_sql_server.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,14 @@
1414
FoundrySqlQueryFailedError,
1515
FoundrySqlSerializationFormatNotImplementedError,
1616
)
17-
from foundry_dev_tools.utils.api_types import ArrowCompressionCodec, Ref, SqlDialect, SQLReturnType, assert_in_literal
17+
from foundry_dev_tools.utils.api_types import (
18+
ArrowCompressionCodec,
19+
FurnaceSqlDialect,
20+
Ref,
21+
SqlDialect,
22+
SQLReturnType,
23+
assert_in_literal,
24+
)
1825

1926
if TYPE_CHECKING:
2027
import pandas as pd
@@ -318,9 +325,10 @@ def query_foundry_sql(
318325
query: str,
319326
return_type: Literal["pandas"],
320327
branch: Ref = ...,
321-
sql_dialect: SqlDialect = ...,
328+
sql_dialect: FurnaceSqlDialect = ...,
322329
arrow_compression_codec: ArrowCompressionCodec = ...,
323330
timeout: int = ...,
331+
experimental_use_trino: bool = ...,
324332
) -> pd.core.frame.DataFrame: ...
325333

326334
@overload
@@ -329,9 +337,10 @@ def query_foundry_sql(
329337
query: str,
330338
return_type: Literal["polars"],
331339
branch: Ref = ...,
332-
sql_dialect: SqlDialect = ...,
340+
sql_dialect: FurnaceSqlDialect = ...,
333341
arrow_compression_codec: ArrowCompressionCodec = ...,
334342
timeout: int = ...,
343+
experimental_use_trino: bool = ...,
335344
) -> pl.DataFrame: ...
336345

337346
@overload
@@ -340,9 +349,10 @@ def query_foundry_sql(
340349
query: str,
341350
return_type: Literal["spark"],
342351
branch: Ref = ...,
343-
sql_dialect: SqlDialect = ...,
352+
sql_dialect: FurnaceSqlDialect = ...,
344353
arrow_compression_codec: ArrowCompressionCodec = ...,
345354
timeout: int = ...,
355+
experimental_use_trino: bool = ...,
346356
) -> pyspark.sql.DataFrame: ...
347357

348358
@overload
@@ -351,9 +361,10 @@ def query_foundry_sql(
351361
query: str,
352362
return_type: Literal["arrow"],
353363
branch: Ref = ...,
354-
sql_dialect: SqlDialect = ...,
364+
sql_dialect: FurnaceSqlDialect = ...,
355365
arrow_compression_codec: ArrowCompressionCodec = ...,
356366
timeout: int = ...,
367+
experimental_use_trino: bool = ...,
357368
) -> pa.Table: ...
358369

359370
@overload
@@ -362,19 +373,21 @@ def query_foundry_sql(
362373
query: str,
363374
return_type: SQLReturnType = ...,
364375
branch: Ref = ...,
365-
sql_dialect: SqlDialect = ...,
376+
sql_dialect: FurnaceSqlDialect = ...,
366377
arrow_compression_codec: ArrowCompressionCodec = ...,
367378
timeout: int = ...,
379+
experimental_use_trino: bool = ...,
368380
) -> tuple[dict, list[list]] | pd.core.frame.DataFrame | pl.DataFrame | pa.Table | pyspark.sql.DataFrame: ...
369381

370382
def query_foundry_sql(
371383
self,
372384
query: str,
373385
return_type: SQLReturnType = "pandas",
374386
branch: Ref = "master",
375-
sql_dialect: SqlDialect = "SPARK",
387+
sql_dialect: FurnaceSqlDialect = "SPARK",
376388
arrow_compression_codec: ArrowCompressionCodec = "NONE",
377389
timeout: int = 600,
390+
experimental_use_trino: bool = False,
378391
) -> tuple[dict, list[list]] | pd.core.frame.DataFrame | pl.DataFrame | pa.Table | pyspark.sql.DataFrame:
379392
"""Queries the Foundry SQL server using the V2 API.
380393
@@ -389,9 +402,10 @@ def query_foundry_sql(
389402
query: The SQL Query
390403
return_type: See :py:class:foundry_dev_tools.foundry_api_client.SQLReturnType
391404
branch: The dataset branch to query
392-
sql_dialect: The SQL dialect to use
405+
sql_dialect: The SQL dialect to use (only SPARK is supported for V2)
393406
arrow_compression_codec: Arrow compression codec (NONE, LZ4, ZSTD)
394407
timeout: Query timeout in seconds
408+
experimental_use_trino: If True, modifies the query to use the Trino backend by adding a /*+ backend(trino) */ hint
395409
396410
Returns:
397411
:external+pandas:py:class:`~pandas.DataFrame` | :external+polars:py:class:`~polars.DataFrame` | :external+pyarrow:py:class:`~pyarrow.Table` | :external+spark:py:class:`~pyspark.sql.DataFrame`:
@@ -403,6 +417,11 @@ def query_foundry_sql(
403417
FoundrySqlQueryClientTimedOutError: If the query times out
404418
405419
""" # noqa: E501
420+
assert_in_literal(sql_dialect, FurnaceSqlDialect, "sql_dialect")
421+
422+
if experimental_use_trino:
423+
query = query.replace("SELECT ", "SELECT /*+ backend(trino) */ ", 1)
424+
406425
response_json = self.api_query(
407426
query=query, dialect=sql_dialect, branch=branch, arrow_compression_codec=arrow_compression_codec
408427
).json()
@@ -507,7 +526,7 @@ def read_stream_results_arrow(self, ticket: dict[str, Any]) -> pa.ipc.RecordBatc
507526
def api_query(
508527
self,
509528
query: str,
510-
dialect: SqlDialect,
529+
dialect: FurnaceSqlDialect,
511530
branch: Ref,
512531
arrow_compression_codec: ArrowCompressionCodec = "NONE",
513532
**kwargs,
@@ -516,7 +535,7 @@ def api_query(
516535
517536
Args:
518537
query: The SQL query string
519-
dialect: The SQL dialect to use
538+
dialect: The SQL dialect to use (only SPARK is supported)
520539
branch: The dataset branch to query
521540
arrow_compression_codec: Arrow compression codec (NONE, LZ4, ZSTD)
522541
**kwargs: gets passed to :py:meth:`APIClient.api_request`

libs/foundry-dev-tools/src/foundry_dev_tools/utils/api_types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ def assert_in_literal(option, literal, variable_name) -> None: # noqa: ANN001
9595
SqlDialect = Literal["ANSI", "SPARK"]
9696
"""The SQL Dialect for Foundry SQL queries."""
9797

98+
FurnaceSqlDialect = Literal["SPARK"]
99+
"""The SQL Dialect for Furnace SQL queries (V2 API). Only SPARK is supported."""
100+
98101
ArrowCompressionCodec = Literal["NONE", "LZ4", "ZSTD"]
99102
"""The Arrow compression codec for Foundry SQL queries."""
100103

tests/integration/clients/test_foundry_sql_server.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,25 @@ def test_legacy_fallback(mocker):
7373
query_foundry_sql_legacy_spy.assert_called()
7474

7575

76+
def test_v1_ansi_sql_dialect():
77+
"""Test V1 client with ANSI SQL dialect (uses double quotes instead of backticks)."""
78+
# Test basic query with ANSI dialect - note the use of double quotes instead of backticks
79+
result = TEST_SINGLETON.ctx.foundry_sql_server.query_foundry_sql(
80+
query=f'SELECT sepal_width, sepal_length FROM "{TEST_SINGLETON.iris_new.rid}" LIMIT 5',
81+
sql_dialect="ANSI",
82+
)
83+
assert result.shape[0] == 5
84+
assert result.shape[1] == 2
85+
86+
# Test with aggregation using ANSI dialect
87+
result_agg = TEST_SINGLETON.ctx.foundry_sql_server.query_foundry_sql(
88+
query=f'SELECT COUNT(*) as cnt FROM "{TEST_SINGLETON.iris_new.rid}"',
89+
sql_dialect="ANSI",
90+
)
91+
assert result_agg.shape[0] == 1
92+
assert "cnt" in result_agg.columns
93+
94+
7695
# V2 Client Tests
7796

7897

@@ -176,3 +195,73 @@ def test_v2_polars_return_type():
176195
assert isinstance(polars_df, pl.DataFrame)
177196
assert polars_df.height == 2
178197
assert polars_df.width == 1
198+
199+
200+
def test_v2_polars_parquet():
201+
polars_df = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
202+
f"SELECT sepal_length FROM `{TEST_SINGLETON.iris_parquet.rid}` LIMIT 2",
203+
return_type="polars",
204+
)
205+
assert isinstance(polars_df, pl.DataFrame)
206+
assert polars_df.height == 2
207+
assert polars_df.width == 1
208+
209+
polars_df = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
210+
f"SELECT sepal_length FROM `{TEST_SINGLETON.iris_parquet.rid}` LIMIT 2",
211+
return_type="polars",
212+
experimental_use_trino=True,
213+
)
214+
assert isinstance(polars_df, pl.DataFrame)
215+
assert polars_df.height == 2
216+
assert polars_df.width == 1
217+
218+
219+
def test_v2_polars_parquet_hive_partitioning():
220+
polars_df = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
221+
f"SELECT sepal_length FROM `{TEST_SINGLETON.iris_hive_partitioned.rid}` LIMIT 2",
222+
return_type="polars",
223+
experimental_use_trino=True,
224+
)
225+
assert isinstance(polars_df, pl.DataFrame)
226+
assert polars_df.height == 2
227+
assert polars_df.width == 1
228+
229+
polars_df = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
230+
f"SELECT sepal_length FROM `{TEST_SINGLETON.iris_hive_partitioned.rid}` LIMIT 2", return_type="polars"
231+
)
232+
assert isinstance(polars_df, pl.DataFrame)
233+
assert polars_df.height == 2
234+
assert polars_df.width == 1
235+
236+
237+
def test_v2_arrow_compression_codecs():
238+
"""Test V2 client with different arrow compression codecs."""
239+
# Test with LZ4 compression
240+
result_lz4 = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
241+
query=f"SELECT * FROM `{TEST_SINGLETON.iris_new.rid}` LIMIT 10",
242+
arrow_compression_codec="LZ4",
243+
)
244+
assert result_lz4.shape[0] == 10
245+
assert result_lz4.shape[1] == 5
246+
247+
# Test with ZSTD compression
248+
result_zstd = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
249+
query=f"SELECT * FROM `{TEST_SINGLETON.iris_new.rid}` LIMIT 10",
250+
arrow_compression_codec="ZSTD",
251+
)
252+
assert result_zstd.shape[0] == 10
253+
assert result_zstd.shape[1] == 5
254+
255+
# Test with NONE compression (default)
256+
result_none = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
257+
query=f"SELECT * FROM `{TEST_SINGLETON.iris_new.rid}` LIMIT 10",
258+
arrow_compression_codec="NONE",
259+
)
260+
assert result_none.shape[0] == 10
261+
assert result_none.shape[1] == 5
262+
263+
# Verify all results have the same data
264+
import pandas as pd
265+
266+
pd.testing.assert_frame_equal(result_lz4, result_zstd)
267+
pd.testing.assert_frame_equal(result_lz4, result_none)

tests/unit/clients/test_foundry_sql_server.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,106 @@ def test_exception_unknown_json(mocker, test_context_mock):
121121
timeout=0.001,
122122
)
123123
assert exception.value.error_message == ""
124+
125+
126+
def test_v2_experimental_use_trino(mocker, test_context_mock):
127+
"""Test that experimental_use_trino parameter modifies the query correctly."""
128+
import pandas as pd
129+
130+
mocker.patch("time.sleep") # we do not want to wait in tests
131+
132+
# Mock the arrow stream reader to return a simple pandas DataFrame
133+
mock_arrow_reader = mocker.MagicMock()
134+
mock_arrow_reader.read_pandas.return_value = pd.DataFrame({"col1": [1, 2, 3]})
135+
mocker.patch.object(
136+
test_context_mock.foundry_sql_server_v2,
137+
"read_stream_results_arrow",
138+
return_value=mock_arrow_reader,
139+
)
140+
141+
# Mock the api_query endpoint (initial query execution)
142+
query_matcher = mocker.MagicMock()
143+
test_context_mock.mock_adapter.register_uri(
144+
"POST",
145+
build_api_url(TEST_HOST.url, "foundry-sql-server", "sql-endpoint/v1/queries/query"),
146+
json={"type": "running", "running": {"queryHandle": {"queryId": "test-query-id-123", "type": "foundry"}}},
147+
additional_matcher=query_matcher,
148+
)
149+
150+
# Mock the api_status endpoint (poll for completion - returns ready immediately)
151+
test_context_mock.mock_adapter.register_uri(
152+
"POST",
153+
build_api_url(TEST_HOST.url, "foundry-sql-server", "sql-endpoint/v1/queries/status"),
154+
json={
155+
"status": {
156+
"type": "ready",
157+
"ready": {"tickets": [{"tickets": ["eyJhbGc...mock-ticket-1", "eyJhbGc...mock-ticket-2"]}]},
158+
}
159+
},
160+
)
161+
162+
# Test with experimental_use_trino=True
163+
df = test_context_mock.foundry_sql_server_v2.query_foundry_sql(
164+
"SELECT * FROM `ri.foundry.main.dataset.test-dataset`",
165+
experimental_use_trino=True,
166+
)
167+
168+
# Verify the query was modified to include the Trino backend hint
169+
call_args = query_matcher.call_args_list[0]
170+
request = call_args[0][0]
171+
request_json = request.json()
172+
173+
assert "SELECT /*+ backend(trino) */ * FROM" in request_json["querySpec"]["query"]
174+
assert df.shape[0] == 3
175+
176+
# Reset for second test
177+
query_matcher.reset_mock()
178+
179+
# Test with experimental_use_trino=False (default)
180+
df = test_context_mock.foundry_sql_server_v2.query_foundry_sql(
181+
"SELECT * FROM `ri.foundry.main.dataset.test-dataset`",
182+
experimental_use_trino=False,
183+
)
184+
185+
# Verify the query was NOT modified
186+
call_args = query_matcher.call_args_list[0]
187+
request = call_args[0][0]
188+
request_json = request.json()
189+
190+
assert request_json["querySpec"]["query"] == "SELECT * FROM `ri.foundry.main.dataset.test-dataset`"
191+
assert "backend(trino)" not in request_json["querySpec"]["query"]
192+
assert df.shape[0] == 3
193+
194+
195+
def test_v2_poll_for_query_completion_timeout(mocker, test_context_mock):
196+
"""Test that V2 query times out correctly when polling takes too long."""
197+
mocker.patch("time.sleep") # we do not want to wait in tests
198+
199+
# Mock the api_query endpoint (initial query execution)
200+
test_context_mock.mock_adapter.register_uri(
201+
"POST",
202+
build_api_url(TEST_HOST.url, "foundry-sql-server", "sql-endpoint/v1/queries/query"),
203+
json={"type": "running", "running": {"queryHandle": {"queryId": "test-query-timeout-123", "type": "foundry"}}},
204+
)
205+
206+
# Mock the api_status endpoint to always return running status
207+
test_context_mock.mock_adapter.register_uri(
208+
"POST",
209+
build_api_url(TEST_HOST.url, "foundry-sql-server", "sql-endpoint/v1/queries/status"),
210+
json={"status": {"type": "running", "running": {}}},
211+
)
212+
213+
with pytest.raises(FoundrySqlQueryClientTimedOutError):
214+
test_context_mock.foundry_sql_server_v2.query_foundry_sql(
215+
"SELECT * FROM `ri.foundry.main.dataset.test-dataset`",
216+
timeout=0.001,
217+
)
218+
219+
220+
def test_v2_ansi_dialect_not_supported(test_context_mock):
221+
"""Test that V2 client rejects ANSI SQL dialect."""
222+
with pytest.raises(TypeError, match="'ANSI' is not a valid option for sql_dialect"):
223+
test_context_mock.foundry_sql_server_v2.query_foundry_sql(
224+
"SELECT * FROM `ri.foundry.main.dataset.test-dataset`",
225+
sql_dialect="ANSI", # type: ignore[arg-type]
226+
)

0 commit comments

Comments
 (0)