Commit 361735c

refined fetch table function

1 parent fda9a8c

1 file changed: src/webapp/databricks.py (+108 −123 lines)

@@ -18,7 +18,7 @@
 from typing import List, Any, Dict, IO, cast, Optional
 from databricks.sdk.errors import DatabricksError
 from fastapi import HTTPException
-import json
+import json, time, requests
 
 try:
     import tomllib as _toml  # Py 3.11+
@@ -309,136 +309,121 @@ def delete_inst(self, inst_name: str) -> None:
                 f"Tables or schemas could not be deleted for {medallion}{e}"
             )
 
-    def fetch_table_data(
-        self,
-        catalog_name: str,
-        inst_name: str,
-        table_name: str,
-        warehouse_id: str,
-    ) -> List[Dict[str, Any]]:
-        """
-        Executes a SELECT * query on the specified table within the given catalog and schema,
-        using the provided SQL warehouse. Returns the result as a list of dictionaries.
-        """
-        try:
-            w = WorkspaceClient(
-                host=databricks_vars["DATABRICKS_HOST_URL"],
-                google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
-            )
-            LOGGER.info("Successfully created Databricks WorkspaceClient.")
-        except Exception as e:
-            LOGGER.exception(
-                "Failed to create Databricks WorkspaceClient with host: %s and service account: %s",
-                databricks_vars["DATABRICKS_HOST_URL"],
-                gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
-            )
-            raise ValueError(
-                f"fetch_table_data(): Workspace client initialization failed: {e}"
-            )
-
-        # Construct the fully qualified table name
-        schema_name = databricksify_inst_name(inst_name)
-        fully_qualified_table = (
-            f"`{catalog_name}`.`{schema_name}_silver`.`{table_name}`"
+    def fetch_table_data(self, catalog_name: str, inst_name: str, table_name: str, warehouse_id: str) -> List[Dict[str, Any]]:
+        w = WorkspaceClient(
+            host=databricks_vars["DATABRICKS_HOST_URL"],
+            google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
+        )
+        schema = databricksify_inst_name(inst_name)
+        table_fqn = f"`{catalog_name}`.`{schema}_silver`.`{table_name}`"
+        sql = f"SELECT * FROM {table_fqn}"
+
+        # 1) Execute INLINE and poll until the statement finishes
+        resp = w.statement_execution.execute_statement(
+            warehouse_id=warehouse_id, statement=sql,
+            disposition=Disposition.INLINE, format=Format.JSON_ARRAY,
+            wait_timeout="30s", on_wait_timeout=ExecuteStatementRequestOnWaitTimeout.CONTINUE,
         )
-        sql_query = f"SELECT * FROM {fully_qualified_table}"
-        LOGGER.info(f"Executing SQL: {sql_query}")
-
-        try:
-            # Execute the SQL statement
-            response = w.statement_execution.execute_statement(
-                warehouse_id=warehouse_id,
-                statement=sql_query,
-                disposition=Disposition.INLINE,  # Use Enum member
-                format=Format.JSON_ARRAY,  # Use Enum member
-                wait_timeout="30s",  # Wait up to 30 seconds for execution
-                on_wait_timeout=ExecuteStatementRequestOnWaitTimeout.CANCEL,  # Use Enum member
-            )
-            LOGGER.info("Databricks SQL execution successful.")
-        except DatabricksError as e:
-            LOGGER.exception("Databricks API call failed.")
-            raise ValueError(f"Databricks API call failed: {e}")
-
-        # Check if the query execution was successful
-        status = response.status
-        if not status or status.state != StatementState.SUCCEEDED:
-            error_message = (
-                status.error.message
-                if status and status.error
-                else "No additional error info."
-            )
-            raise ValueError(
-                f"Query did not succeed (state={status.state if status else 'None'}): {error_message}"
-            )
-
-        if (
-            not response.manifest
-            or not response.manifest.schema
-            or not response.manifest.schema.columns
-            or not response.result
-            or not response.result.data_array
-        ):
-            raise ValueError("Query succeeded but schema or result data is missing.")
 
-        column_names = [str(column.name) for column in response.manifest.schema.columns]
-        rows: List[List[Any]] = []
-        first_chunk = response.result
-        if getattr(first_chunk, "data_array", None):
-            rows.extend(first_chunk.data_array)
+        MAX_BYTES = 20 * 1024 * 1024  # 20 MiB
+        POLL_INTERVAL_S = 1.0
+        POLL_TIMEOUT_S = 300.0  # 5 minutes
+
+        start = time.time()
+        while not resp.status or resp.status.state not in {StatementState.SUCCEEDED, StatementState.FAILED, StatementState.CANCELED}:
+            if time.time() - start > POLL_TIMEOUT_S:
+                raise TimeoutError("Timed out waiting for statement to finish (INLINE)")
+            time.sleep(POLL_INTERVAL_S)
+            resp = w.statement_execution.get_statement(statement_id=resp.statement_id)
+        if resp.status.state != StatementState.SUCCEEDED:
+            msg = resp.status.error.message if resp.status and resp.status.error else "no details"
+            raise ValueError(f"Statement ended in {resp.status.state}: {msg}")
+
+        if not (resp.manifest and resp.manifest.schema and resp.manifest.schema.columns):
+            raise ValueError("Schema/columns missing.")
+        cols = [c.name for c in resp.manifest.schema.columns]
+
+        # 2) Build INLINE records until ~20 MiB; if projected to exceed, switch to EXTERNAL_LINKS
+        records: List[Dict[str, Any]] = []
+        bytes_so_far, have_items = 0, False
+
+        def add_row(rd: Dict[str, Any]) -> bool:
+            nonlocal bytes_so_far, have_items
+            b = json.dumps(rd, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
+            projected = bytes_so_far + (1 if have_items else 0) + len(b) + 2  # comma + '[' and ']'
+            if projected > MAX_BYTES:
+                return False
+            records.append(rd)
+            bytes_so_far += (1 if have_items else 0) + len(b)
+            have_items = True
+            return True
 
-        if getattr(first_chunk, "truncated", False):
-            LOGGER.warning(
-                "Databricks marked the result as truncated by server limits."
-            )
+        # Consume INLINE chunks
+        def consume_inline_chunk(chunk_obj) -> bool:
+            if getattr(chunk_obj, "truncated", False):
+                raise ValueError("Server truncated INLINE result.")
+            arr = getattr(chunk_obj, "data_array", None) or []
+            for row in arr:
+                if not add_row(dict(zip(cols, row))):
+                    return False
+            return True
 
-        next_idx = getattr(first_chunk, "next_chunk_index", None)
-        stmt_id = response.statement_id
+        first = resp.result
+        if first and not consume_inline_chunk(first):
+            inline_over_limit = True
+        else:
+            inline_over_limit = False
+            next_idx = getattr(first, "next_chunk_index", None) if first else None
+            while next_idx is not None:
+                chunk = w.statement_execution.get_statement_result_chunk_n(
+                    statement_id=resp.statement_id, chunk_index=next_idx
+                )
+                if not consume_inline_chunk(chunk):
+                    inline_over_limit = True
+                    break
+                next_idx = getattr(chunk, "next_chunk_index", None)
+
+        if not inline_over_limit:
+            return records  # INLINE fit under 20 MiB
+
+        # 3) Re-execute with EXTERNAL_LINKS, then download each presigned URL (no auth header)
+        resp = w.statement_execution.execute_statement(
+            warehouse_id=warehouse_id, statement=sql,
+            disposition=Disposition.EXTERNAL_LINKS, format=Format.JSON_ARRAY,
+            wait_timeout="30s", on_wait_timeout=ExecuteStatementRequestOnWaitTimeout.CONTINUE,
+        )
+        start = time.time()
+        while not resp.status or resp.status.state not in {StatementState.SUCCEEDED, StatementState.FAILED, StatementState.CANCELED}:
+            if time.time() - start > POLL_TIMEOUT_S:
+                raise TimeoutError("Timed out waiting for statement to finish (EXTERNAL_LINKS)")
+            time.sleep(POLL_INTERVAL_S)
+            resp = w.statement_execution.get_statement(statement_id=resp.statement_id)
+        if resp.status.state != StatementState.SUCCEEDED:
+            msg = resp.status.error.message if resp.status and resp.status.error else "no details"
+            raise ValueError(f"Statement (EXTERNAL_LINKS) ended in {resp.status.state}: {msg}")
+
+        if not (resp.manifest and resp.manifest.schema and resp.manifest.schema.columns):
+            raise ValueError("Schema/columns missing (EXTERNAL_LINKS).")
+        cols = [c.name for c in resp.manifest.schema.columns]
+
+        def consume_external_result(result_obj):
+            links = getattr(result_obj, "external_links", None) or []
+            for link in links:
+                url = link.external_link if hasattr(link, "external_link") else link.get("external_link")
+                r = requests.get(url, timeout=120)  # presigned URL: fetched without Databricks auth headers
+                r.raise_for_status()
+                for row in r.json():
+                    records.append(dict(zip(cols, row)))
+            return getattr(result_obj, "next_chunk_index", None)
+
+        records.clear()
+        next_idx = consume_external_result(resp.result)
 
         while next_idx is not None:
             chunk = w.statement_execution.get_statement_result_chunk_n(
-                statement_id=stmt_id,
-                chunk_index=next_idx,
-            )
-            if getattr(chunk, "data_array", None):
-                rows.extend(chunk.data_array)
-
-            if getattr(chunk, "truncated", False):
-                LOGGER.warning("A result chunk was marked truncated by the server.")
-
-            next_idx = getattr(chunk, "next_chunk_index", None)
-
-        print("Fetched %d rows from table: %s", len(rows), fully_qualified_table)
-        LOGGER.info("Fetched %d rows from table: %s", len(rows), fully_qualified_table)
-
-        # Build list of dicts
-        records: List[Dict[str, Any]] = [dict(zip(column_names, r)) for r in rows]
-
-        try:
-            encoded = json.dumps(
-                records, ensure_ascii=False, separators=(",", ":")
-            ).encode("utf-8")
-        except Exception as e:
-            LOGGER.exception("Failed to serialize records to JSON.")
-            raise ValueError(f"Failed to serialize records to JSON: {e}")
-
-        payload_bytes = len(encoded)
-        print(
-            "Final JSON payload size: %.2f MiB (%d bytes)",
-            payload_bytes / (1024 * 1024),
-            payload_bytes,
-        )
-        LOGGER.info(
-            "Final JSON payload size: %.2f MiB (%d bytes)",
-            payload_bytes / (1024 * 1024),
-            payload_bytes,
-        )
-
-        max_json_size = 25 * 1024 * 1024
-        if payload_bytes > max_json_size:
-            raise ValueError(
-                f"Result exceeds maximum allowed JSON payload of {max_json_size} bytes "
-                f"({max_json_size / (1024 * 1024):.2f} MiB). Got {payload_bytes} bytes."
+                statement_id=resp.statement_id, chunk_index=next_idx
             )
+            next_idx = consume_external_result(chunk)
 
         return records
 
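The INLINE and EXTERNAL_LINKS branches above poll the statement with two identical loops. A follow-up could hoist that into a single helper; a minimal sketch against the same databricks-sdk client (the helper name _wait_for_statement and the inclusion of CLOSED as a terminal state are assumptions, not code from this commit):

    import time

    from databricks.sdk import WorkspaceClient
    from databricks.sdk.service.sql import StatementState

    _TERMINAL = {
        StatementState.SUCCEEDED,
        StatementState.FAILED,
        StatementState.CANCELED,
        StatementState.CLOSED,
    }

    def _wait_for_statement(
        w: WorkspaceClient,
        statement_id: str,
        poll_interval_s: float = 1.0,
        timeout_s: float = 300.0,
    ):
        # Poll get_statement until a terminal state or the deadline passes.
        deadline = time.time() + timeout_s
        resp = w.statement_execution.get_statement(statement_id=statement_id)
        while not resp.status or resp.status.state not in _TERMINAL:
            if time.time() > deadline:
                raise TimeoutError(f"Statement {statement_id} still running after {timeout_s}s")
            time.sleep(poll_interval_s)
            resp = w.statement_execution.get_statement(statement_id=statement_id)
        if resp.status.state != StatementState.SUCCEEDED:
            msg = resp.status.error.message if resp.status.error else "no details"
            raise ValueError(f"Statement ended in {resp.status.state}: {msg}")
        return resp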

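The 20 MiB budget in add_row tracks compact json.dumps output byte for byte: each record's encoding, one comma between records, and two bytes for the enclosing brackets. A standalone check of that arithmetic (the sample rows are illustrative):

    # Per-record compact bytes + commas between records + 2 bracket bytes
    # equals the size of the serialized JSON array.
    import json

    rows = [{"id": 1, "name": "a"}, {"id": 2, "name": "é"}]
    per_record = [
        len(json.dumps(r, ensure_ascii=False, separators=(",", ":")).encode("utf-8"))
        for r in rows
    ]
    projected = sum(per_record) + (len(rows) - 1) + 2  # commas + '[' and ']'
    actual = len(json.dumps(rows, ensure_ascii=False, separators=(",", ":")).encode("utf-8"))
    assert projected == actual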
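The module imports HTTPException, so the ValueError/TimeoutError failures raised here are presumably translated into HTTP responses by the web layer. A hypothetical FastAPI route showing one way to do that (the path, status-code mapping, and storage object are illustrative assumptions, not code from this commit):

    from typing import Any, Dict, List

    from fastapi import FastAPI, HTTPException

    app = FastAPI()
    storage: Any = None  # hypothetical: the instance whose class defines fetch_table_data

    @app.get("/tables/{table_name}")
    def read_table(
        table_name: str, catalog_name: str, inst_name: str, warehouse_id: str
    ) -> List[Dict[str, Any]]:
        try:
            return storage.fetch_table_data(
                catalog_name=catalog_name,
                inst_name=inst_name,
                table_name=table_name,
                warehouse_id=warehouse_id,
            )
        except TimeoutError as e:
            # Statement polling exceeded its deadline.
            raise HTTPException(status_code=504, detail=str(e))
        except ValueError as e:
            # Statement failed or returned an unusable result.
            raise HTTPException(status_code=502, detail=str(e))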