Skip to content

Commit 0de8945

Browse files
committed
adjust to new furnace API
1 parent 78c2ebc commit 0de8945

File tree

4 files changed

+93
-84
lines changed

4 files changed

+93
-84
lines changed

libs/foundry-dev-tools/src/foundry_dev_tools/clients/foundry_sql_server.py

Lines changed: 89 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
FoundrySqlQueryFailedError,
1515
FoundrySqlSerializationFormatNotImplementedError,
1616
)
17-
from foundry_dev_tools.utils.api_types import Ref, SqlDialect, SQLReturnType, assert_in_literal
17+
from foundry_dev_tools.utils.api_types import ArrowCompressionCodec, Ref, SqlDialect, SQLReturnType, assert_in_literal
1818

1919
if TYPE_CHECKING:
2020
import pandas as pd
@@ -316,9 +316,10 @@ class FoundrySqlServerClientV2(APIClient):
316316
def query_foundry_sql(
317317
self,
318318
query: str,
319-
application_id: str,
320319
return_type: Literal["pandas"],
321-
disable_arrow_compression: bool = ...,
320+
branch: Ref = ...,
321+
sql_dialect: SqlDialect = ...,
322+
arrow_compression_codec: ArrowCompressionCodec = ...,
322323
timeout: int = ...,
323324
) -> pd.core.frame.DataFrame: ...
324325

@@ -329,61 +330,68 @@ def query_foundry_sql(
329330
return_type: Literal["polars"],
330331
branch: Ref = ...,
331332
sql_dialect: SqlDialect = ...,
333+
arrow_compression_codec: ArrowCompressionCodec = ...,
332334
timeout: int = ...,
333335
) -> pl.DataFrame: ...
334336

335337
@overload
336338
def query_foundry_sql(
337339
self,
338340
query: str,
339-
application_id: str,
340341
return_type: Literal["spark"],
341-
disable_arrow_compression: bool = ...,
342+
branch: Ref = ...,
343+
sql_dialect: SqlDialect = ...,
344+
arrow_compression_codec: ArrowCompressionCodec = ...,
342345
timeout: int = ...,
343346
) -> pyspark.sql.DataFrame: ...
344347

345348
@overload
346349
def query_foundry_sql(
347350
self,
348351
query: str,
349-
application_id: str,
350352
return_type: Literal["arrow"],
351-
disable_arrow_compression: bool = ...,
353+
branch: Ref = ...,
354+
sql_dialect: SqlDialect = ...,
355+
arrow_compression_codec: ArrowCompressionCodec = ...,
352356
timeout: int = ...,
353357
) -> pa.Table: ...
354358

355359
@overload
356360
def query_foundry_sql(
357361
self,
358362
query: str,
359-
application_id: str,
360363
return_type: SQLReturnType = ...,
361-
disable_arrow_compression: bool = ...,
364+
branch: Ref = ...,
365+
sql_dialect: SqlDialect = ...,
366+
arrow_compression_codec: ArrowCompressionCodec = ...,
362367
timeout: int = ...,
363368
) -> tuple[dict, list[list]] | pd.core.frame.DataFrame | pl.DataFrame | pa.Table | pyspark.sql.DataFrame: ...
364369

365370
def query_foundry_sql(
366371
self,
367372
query: str,
368373
return_type: SQLReturnType = "pandas",
369-
disable_arrow_compression: bool = False,
370-
application_id: str | None = None,
374+
branch: Ref = "master",
375+
sql_dialect: SqlDialect = "SPARK",
376+
arrow_compression_codec: ArrowCompressionCodec = "NONE",
377+
timeout: int = 600,
371378
) -> tuple[dict, list[list]] | pd.core.frame.DataFrame | pl.DataFrame | pa.Table | pyspark.sql.DataFrame:
372379
"""Queries the Foundry SQL server using the V2 API.
373380
374381
Uses Arrow IPC to communicate with the Foundry SQL Server Endpoint.
375382
376383
Example:
377384
df = client.query_foundry_sql(
378-
query="SELECT * FROM `ri.foundry.main.dataset.abc` LIMIT 10",
379-
application_id="ri.foundry.main.dataset.abc"
385+
query="SELECT * FROM `ri.foundry.main.dataset.abc` LIMIT 10"
380386
)
381387
382388
Args:
383389
query: The SQL Query
384390
return_type: See :py:class:foundry_dev_tools.foundry_api_client.SQLReturnType
385-
disable_arrow_compression: Whether to disable Arrow compression
386-
application_id: The application/dataset RID, defaults to foundry-dev-tools User-Agent
391+
branch: The dataset branch to query
392+
sql_dialect: The SQL dialect to use
393+
arrow_compression_codec: Arrow compression codec (NONE, LZ4, ZSTD)
394+
timeout: Query timeout in seconds
387395
388396
Returns:
389397
:external+pandas:py:class:`~pandas.DataFrame` | :external+polars:py:class:`~polars.DataFrame` | :external+pyarrow:py:class:`~pyarrow.Table` | :external+spark:py:class:`~pyspark.sql.DataFrame`:
@@ -395,31 +403,29 @@ def query_foundry_sql(
395403
FoundrySqlQueryClientTimedOutError: If the query times out
396404
397405
""" # noqa: E501
398-
# Execute the query
399-
if not application_id:
400-
application_id = self.context.client.headers["User-Agent"]
401-
response_json = self.api_execute(
402-
sql=query,
403-
application_id=application_id,
404-
disable_arrow_compression=disable_arrow_compression,
406+
response_json = self.api_query(
407+
query=query, dialect=sql_dialect, branch=branch, arrow_compression_codec=arrow_compression_codec
405408
).json()
406409

407-
query_identifier = self._extract_query_identifier(response_json)
410+
query_handle = self._extract_query_handle(response_json)
411+
start_time = time.time()
408412

409413
# Poll for completion
410-
while response_json.get("type") != "success":
414+
while response_json.get("status", {}).get("type") != "ready":
411415
time.sleep(0.2)
412-
response = self.api_status(query_identifier)
416+
response = self.api_status(query_handle)
413417
response_json = response.json()
414418

415-
if response_json.get("type") == "failed":
419+
if response_json.get("status", {}).get("type") == "failed":
416420
raise FoundrySqlQueryFailedError(response)
421+
if time.time() > start_time + timeout:
422+
raise FoundrySqlQueryClientTimedOutError(response, timeout=timeout)
417423

418424
# Extract tickets from successful response
419-
tickets = self._extract_tickets(response_json)
425+
ticket = self._extract_ticket(response_json)
420426

421427
# Fetch Arrow data using tickets
422-
arrow_stream_reader = self.read_stream_results_arrow(tickets)
428+
arrow_stream_reader = self.read_stream_results_arrow(ticket)
423429

424430
if return_type == "pandas":
425431
return arrow_stream_reader.read_pandas()
@@ -446,22 +452,20 @@ def query_foundry_sql(
446452

447453
raise ValueError("The following return_type is not supported: " + return_type)
448454

449-
def _extract_query_identifier(self, response_json: dict[str, Any]) -> dict[str, Any]:
450-
"""Extract query identifier from execute response.
455+
def _extract_query_handle(self, response_json: dict[str, Any]) -> dict[str, Any]:
456+
"""Extract query handle from execute response.
451457
452458
Args:
453459
response_json: Response JSON from execute API
454460
461+
455462
Returns:
456-
Query identifier dict
463+
Query handle dict
457464
458465
"""
459-
if response_json["type"] == "triggered" and "plan" in response_json["triggered"]:
460-
plan = response_json["triggered"]["plan"]
461-
LOGGER.debug("plan %s", plan)
462-
return response_json[response_json["type"]]["query"]
466+
return response_json[response_json["type"]]["queryHandle"]
463467

464-
def _extract_tickets(self, response_json: dict[str, Any]) -> list[str]:
468+
def _extract_ticket(self, response_json: dict[str, Any]) -> dict[str, Any]:
465469
"""Extract the combined result ticket from a ready status response.
466470
467471
Args:
@@ -471,70 +475,84 @@ def _extract_tickets(self, response_json: dict[str, Any]) -> list[str]:
471475
Combined ticket dict for fetching results
472476
473477
"""
474-
if response_json.get("type") != "success":
475-
msg = f"Expected success response, got: {response_json.get('type')}"
476-
477-
raise ValueError(msg)
478-
479-
chunks = response_json["success"]["result"]["interactive"]["chunks"]
480-
return [chunk["ticket"] for chunk in chunks]
481-
482-
def read_stream_results_arrow(self, tickets: list[str]) -> pa.ipc.RecordBatchStreamReader:
478+
# we combine all tickets into one to get the full data
479+
# if performance is a concern this should be done in parallel
480+
return {
481+
"id": 0,
482+
"tickets": [
483+
ticket
484+
for ticket_group in response_json["status"]["ready"]["tickets"]
485+
for ticket in ticket_group["tickets"]
486+
],
487+
"type": "furnace",
488+
}
489+
490+
def read_stream_results_arrow(self, ticket: dict[str, Any]) -> pa.ipc.RecordBatchStreamReader:
483491
"""Fetch query results using tickets and return Arrow stream reader.
484492
485493
Args:
486-
tickets: List of tickets from status API success response
494+
ticket: dict of tickets e.g. { "id": 0, "tickets": ["ey...", ...], "type": "furnace", }
487495
488496
Returns:
489497
Arrow RecordBatchStreamReader
490498
491499
"""
492500
from foundry_dev_tools._optional.pyarrow import pa
493501

494-
response = self._api_stream_ticket(tickets)
502+
response = self.api_stream_ticket(ticket)
495503
response.raw.decode_content = True
496504

497505
return pa.ipc.RecordBatchStreamReader(response.raw)
498506

499-
def api_execute(
507+
def api_query(
500508
self,
501-
sql: str,
502-
application_id: str,
503-
disable_arrow_compression: bool = False,
509+
query: str,
510+
dialect: SqlDialect,
511+
branch: Ref,
512+
arrow_compression_codec: ArrowCompressionCodec = "NONE",
504513
**kwargs,
505514
) -> requests.Response:
506515
"""Execute a SQL query via the V2 API.
507516
508517
Args:
509-
sql: The SQL query to execute
510-
application_id: The application/dataset RID
511-
disable_arrow_compression: Whether to disable Arrow compression
518+
query: The SQL query string
519+
dialect: The SQL dialect to use
520+
branch: The dataset branch to query
521+
arrow_compression_codec: Arrow compression codec (NONE, LZ4, ZSTD)
512522
**kwargs: gets passed to :py:meth:`APIClient.api_request`
513523
514524
Returns:
515-
Response with query execution status
525+
Response with query handle and initial status
516526
517527
"""
518528
return self.api_request(
519529
"POST",
520-
"", # Root endpoint /api/
530+
"sql-endpoint/v1/queries/query",
521531
json={
522-
"applicationId": application_id,
523-
"sql": sql,
524-
"disableArrowCompression": disable_arrow_compression,
532+
"querySpec": {
533+
"query": query,
534+
"tableProviders": {},
535+
"dialect": dialect,
536+
"options": {"options": [{"option": "arrowCompressionCodec", "value": arrow_compression_codec}]},
537+
},
538+
"executionParams": {
539+
"defaultBranchIds": [{"type": "datasetBranch", "datasetBranch": branch}],
540+
"resultFormat": "ARROW",
541+
"resultMode": "AUTO",
542+
},
525543
},
526544
**kwargs,
527545
)
528546

529547
def api_status(
530548
self,
531-
query_identifier: dict[str, Any],
549+
query_handle: dict[str, Any],
532550
**kwargs,
533551
) -> requests.Response:
534552
"""Get the status of a SQL query via the V2 API.
535553
536554
Args:
537-
query_identifier: Query identifier dict (e.g., {"type": "interactive", "interactive": "query-id"})
555+
query_handle: Query handle dict from execute response
538556
**kwargs: gets passed to :py:meth:`APIClient.api_request`
539557
540558
Returns:
@@ -543,34 +561,31 @@ def api_status(
543561
"""
544562
return self.api_request(
545563
"POST",
546-
"status",
547-
json={
548-
"query": query_identifier,
549-
},
564+
"sql-endpoint/v1/queries/status",
565+
json=query_handle,
550566
**kwargs,
551567
)
552568

553-
def _api_stream_ticket(
569+
def api_stream_ticket(
554570
self,
555-
tickets: list[str],
571+
ticket: dict,
556572
**kwargs,
557573
) -> requests.Response:
558-
"""Fetch query results using tickets via the V2 API.
574+
"""Stream query results using a ticket via the V2 API.
559575
560576
Args:
561-
tickets: List of tickets from status API success response
577+
ticket: Ticket dict containing id, tickets list, and type.
578+
Example: {"id": 0, "tickets": ["eyJhbGc...", "eyJhbGc..."], "type": "furnace"}
562579
**kwargs: gets passed to :py:meth:`APIClient.api_request`
563580
564581
Returns:
565-
Response with Arrow-encoded query results
582+
Response with streaming Arrow data
566583
567584
"""
568585
return self.api_request(
569586
"POST",
570-
"stream",
571-
json={
572-
"tickets": tickets,
573-
},
587+
"sql-endpoint/v1/queries/stream",
588+
json=ticket,
574589
headers={
575590
"Accept": "application/octet-stream",
576591
},

libs/foundry-dev-tools/src/foundry_dev_tools/errors/handling.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
"DataProxy:FallbackBranchesNotSpecifiedInQuery": BranchNotFoundError,
5959
"DataProxy:BadSqlQuery": FoundrySqlQueryFailedError,
6060
"FurnaceSql:SqlParseError": FurnaceSqlSqlParseError,
61+
"SqlQueryService:SqlSyntaxError": FurnaceSqlSqlParseError,
6162
"DataProxy:DatasetNotFound": DatasetNotFoundError,
6263
"Catalog:DuplicateDatasetName": DatasetAlreadyExistsError,
6364
"Catalog:DatasetsNotFound": DatasetNotFoundError,

libs/foundry-dev-tools/src/foundry_dev_tools/utils/api_types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ def assert_in_literal(option, literal, variable_name) -> None: # noqa: ANN001
9595
SqlDialect = Literal["ANSI", "SPARK"]
9696
"""The SQL Dialect for Foundry SQL queries."""
9797

98+
ArrowCompressionCodec = Literal["NONE", "LZ4", "ZSTD"]
99+
"""The Arrow compression codec for Foundry SQL queries."""
100+
98101
SQLReturnType = Literal["pandas", "polars", "spark", "arrow", "raw"]
99102
"""The return_types for sql queries.
100103

0 commit comments

Comments
 (0)