|
18 | 18 | from typing import List, Any, Dict, IO, cast, Optional |
19 | 19 | from databricks.sdk.errors import DatabricksError |
20 | 20 | from fastapi import HTTPException |
21 | | -import json |
| 21 | +import json |
| | +import time |
| | + |
| | +import requests |
22 | 22 |
|
23 | 23 | try: |
24 | 24 | import tomllib as _toml # Py 3.11+ |
@@ -309,136 +309,121 @@ def delete_inst(self, inst_name: str) -> None: |
309 | 309 | f"Tables or schemas could not be deleted for {medallion} — {e}" |
310 | 310 | ) |
311 | 311 |
|
312 | | - def fetch_table_data( |
313 | | - self, |
314 | | - catalog_name: str, |
315 | | - inst_name: str, |
316 | | - table_name: str, |
317 | | - warehouse_id: str, |
318 | | - ) -> List[Dict[str, Any]]: |
319 | | - """ |
320 | | - Executes a SELECT * query on the specified table within the given catalog and schema, |
321 | | - using the provided SQL warehouse. Returns the result as a list of dictionaries. |
322 | | - """ |
323 | | - try: |
324 | | - w = WorkspaceClient( |
325 | | - host=databricks_vars["DATABRICKS_HOST_URL"], |
326 | | - google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"], |
327 | | - ) |
328 | | - LOGGER.info("Successfully created Databricks WorkspaceClient.") |
329 | | - except Exception as e: |
330 | | - LOGGER.exception( |
331 | | - "Failed to create Databricks WorkspaceClient with host: %s and service account: %s", |
332 | | - databricks_vars["DATABRICKS_HOST_URL"], |
333 | | - gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"], |
334 | | - ) |
335 | | - raise ValueError( |
336 | | - f"fetch_table_data(): Workspace client initialization failed: {e}" |
337 | | - ) |
338 | | - |
339 | | - # Construct the fully qualified table name |
340 | | - schema_name = databricksify_inst_name(inst_name) |
341 | | - fully_qualified_table = ( |
342 | | - f"`{catalog_name}`.`{schema_name}_silver`.`{table_name}`" |
| 312 | +    def fetch_table_data( |
| | +        self, |
| | +        catalog_name: str, |
| | +        inst_name: str, |
| | +        table_name: str, |
| | +        warehouse_id: str, |
| | +    ) -> List[Dict[str, Any]]: |
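| | +        """Return every row of the institution's silver table as a list of dicts. |
| | + |
| | +        Runs SELECT * via the SQL Statement Execution API: first with an INLINE |
| | +        disposition and, if the accumulated JSON would exceed the in-memory budget |
| | +        below, re-runs the query with EXTERNAL_LINKS and downloads the result files. |
| | +        """ |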
| 313 | + w = WorkspaceClient( |
| 314 | + host=databricks_vars["DATABRICKS_HOST_URL"], |
| 315 | + google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"], |
| 316 | + ) |
| 317 | + schema = databricksify_inst_name(inst_name) |
| 318 | + table_fqn = f"`{catalog_name}`.`{schema}_silver`.`{table_name}`" |
| 319 | + sql = f"SELECT * FROM {table_fqn}" |
| 320 | + |
| 321 | +        # 1) Execute with INLINE disposition and poll until the statement reaches a terminal state |
| 322 | + resp = w.statement_execution.execute_statement( |
| 323 | + warehouse_id=warehouse_id, statement=sql, |
| 324 | + disposition=Disposition.INLINE, format=Format.JSON_ARRAY, |
| 325 | + wait_timeout="30s", on_wait_timeout=ExecuteStatementRequestOnWaitTimeout.CONTINUE, |
343 | 326 | ) |
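| | +        # With on_wait_timeout=CONTINUE the call returns once wait_timeout elapses even if |
| | +        # the statement is still PENDING/RUNNING, so the loop below keeps polling it. |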
344 | | - sql_query = f"SELECT * FROM {fully_qualified_table}" |
345 | | - LOGGER.info(f"Executing SQL: {sql_query}") |
346 | | - |
347 | | - try: |
348 | | - # Execute the SQL statement |
349 | | - response = w.statement_execution.execute_statement( |
350 | | - warehouse_id=warehouse_id, |
351 | | - statement=sql_query, |
352 | | - disposition=Disposition.INLINE, # Use Enum member |
353 | | - format=Format.JSON_ARRAY, # Use Enum member |
354 | | - wait_timeout="30s", # Wait up to 30 seconds for execution |
355 | | - on_wait_timeout=ExecuteStatementRequestOnWaitTimeout.CANCEL, # Use Enum member |
356 | | - ) |
357 | | - LOGGER.info("Databricks SQL execution successful.") |
358 | | - except DatabricksError as e: |
359 | | - LOGGER.exception("Databricks API call failed.") |
360 | | - raise ValueError(f"Databricks API call failed: {e}") |
361 | | - |
362 | | - # Check if the query execution was successful |
363 | | - status = response.status |
364 | | - if not status or status.state != StatementState.SUCCEEDED: |
365 | | - error_message = ( |
366 | | - status.error.message |
367 | | - if status and status.error |
368 | | - else "No additional error info." |
369 | | - ) |
370 | | - raise ValueError( |
371 | | - f"Query did not succeed (state={status.state if status else 'None'}): {error_message}" |
372 | | - ) |
373 | | - |
374 | | - if ( |
375 | | - not response.manifest |
376 | | - or not response.manifest.schema |
377 | | - or not response.manifest.schema.columns |
378 | | - or not response.result |
379 | | - or not response.result.data_array |
380 | | - ): |
381 | | - raise ValueError("Query succeeded but schema or result data is missing.") |
382 | 327 |
|
383 | | - column_names = [str(column.name) for column in response.manifest.schema.columns] |
384 | | - rows: List[List[Any]] = [] |
385 | | - first_chunk = response.result |
386 | | - if getattr(first_chunk, "data_array", None): |
387 | | - rows.extend(first_chunk.data_array) |
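| | +        # NOTE: MAX_BYTES is an application-level budget for the JSON held in memory and |
| | +        # returned to the caller; 20 MiB is presumably chosen to stay safely under the |
| | +        # service's inline result-size limits. Tune it as needed. |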
| 328 | +        MAX_BYTES = 20 * 1024 * 1024  # 20 MiB |
| 329 | +        POLL_INTERVAL_S = 1.0 |
| 330 | +        POLL_TIMEOUT_S = 300.0  # 5 minutes |
| | +        TERMINAL_STATES = { |
| | +            StatementState.SUCCEEDED, StatementState.FAILED, |
| | +            StatementState.CANCELED, StatementState.CLOSED, |
| | +        } |
| 331 | + |
| 332 | +        start = time.time() |
| 333 | +        while not resp.status or resp.status.state not in TERMINAL_STATES: |
| 334 | +            if time.time() - start > POLL_TIMEOUT_S: |
| 335 | +                raise TimeoutError("Timed out waiting for statement to finish (INLINE)") |
| 336 | +            time.sleep(POLL_INTERVAL_S) |
| 337 | +            resp = w.statement_execution.get_statement(statement_id=resp.statement_id) |
| 338 | +        if resp.status.state != StatementState.SUCCEEDED: |
| 339 | +            msg = resp.status.error.message if resp.status.error else "no details" |
| 340 | +            raise ValueError(f"Statement ended in state {resp.status.state}: {msg}") |
| 341 | + |
| 342 | + if not (resp.manifest and resp.manifest.schema and resp.manifest.schema.columns): |
| 343 | + raise ValueError("Schema/columns missing.") |
| 344 | + cols = [c.name for c in resp.manifest.schema.columns] |
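| | +        # With Format.JSON_ARRAY each row arrives as a JSON array of string values |
| | +        # (NULLs as null), ordered like the manifest columns, so zip(cols, row) rebuilds |
| | +        # each record; callers must cast non-string types themselves. |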
| 345 | + |
| 346 | +        # 2) Build INLINE records up to the ~20 MiB budget; if a row would push past it, switch to EXTERNAL_LINKS |
| 347 | + records: List[Dict[str, Any]] = [] |
| 348 | + bytes_so_far, have_items = 0, False |
| 349 | + |
| 350 | + def add_row(rd: Dict[str, Any]) -> bool: |
| 351 | + nonlocal bytes_so_far, have_items |
| 352 | + b = json.dumps(rd, ensure_ascii=False, separators=(",", ":")).encode("utf-8") |
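| | +            # +1 accounts for the comma separator before this item (when it is not the |
| | +            # first), and +2 for the enclosing "[" and "]" of the final JSON array. |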
| 353 | + projected = bytes_so_far + (1 if have_items else 0) + len(b) + 2 |
| 354 | + if projected > MAX_BYTES: |
| 355 | + return False |
| 356 | + records.append(rd) |
| 357 | + bytes_so_far += (1 if have_items else 0) + len(b) |
| 358 | + have_items = True |
| 359 | + return True |
388 | 360 |
|
389 | | - if getattr(first_chunk, "truncated", False): |
390 | | - LOGGER.warning( |
391 | | - "Databricks marked the result as truncated by server limits." |
392 | | - ) |
| 361 | +        # Consume one INLINE chunk; return False as soon as the budget would be exceeded |
| | +        # or the server reports the chunk as truncated, which triggers the fallback below. |
| 362 | +        def consume_inline_chunk(chunk_obj) -> bool: |
| 363 | +            if getattr(chunk_obj, "truncated", False): |
| 364 | +                return False |
| 365 | +            arr = getattr(chunk_obj, "data_array", None) or [] |
| 366 | +            for row in arr: |
| 367 | +                if not add_row(dict(zip(cols, row))): |
| 368 | +                    return False |
| 369 | +            return True |
393 | 370 |
|
394 | | - next_idx = getattr(first_chunk, "next_chunk_index", None) |
395 | | - stmt_id = response.statement_id |
| | +        # A manifest flagged as truncated means INLINE could not return every row, |
| | +        # so treat it like exceeding the local budget and fall back to EXTERNAL_LINKS. |
| 371 | +        inline_over_limit = bool(getattr(resp.manifest, "truncated", False)) |
| 372 | +        first = resp.result |
| 373 | +        if not inline_over_limit and first and not consume_inline_chunk(first): |
| 374 | +            inline_over_limit = True |
| 375 | +        if not inline_over_limit and first: |
| 376 | +            next_idx = getattr(first, "next_chunk_index", None) |
| 377 | +            while next_idx is not None: |
| 378 | +                chunk = w.statement_execution.get_statement_result_chunk_n( |
| 379 | +                    statement_id=resp.statement_id, chunk_index=next_idx |
| 380 | +                ) |
| 381 | +                if not consume_inline_chunk(chunk): |
| 382 | +                    inline_over_limit = True |
| 383 | +                    break |
| 384 | +                next_idx = getattr(chunk, "next_chunk_index", None) |
| 385 | + |
| 386 | +        if not inline_over_limit: |
| 387 | +            return records  # INLINE result fit within the 20 MiB budget |
| 388 | + |
| 389 | +        # 3) Re-execute with EXTERNAL_LINKS, then download each presigned result file (no auth header) |
| 390 | + resp = w.statement_execution.execute_statement( |
| 391 | + warehouse_id=warehouse_id, statement=sql, |
| 392 | + disposition=Disposition.EXTERNAL_LINKS, format=Format.JSON_ARRAY, |
| 393 | + wait_timeout="30s", on_wait_timeout=ExecuteStatementRequestOnWaitTimeout.CONTINUE, |
| 394 | + ) |
| 395 | + start = time.time() |
| 396 | +        while not resp.status or resp.status.state not in TERMINAL_STATES: |
| 397 | +            if time.time() - start > POLL_TIMEOUT_S: |
| 398 | +                raise TimeoutError("Timed out waiting for statement to finish (EXTERNAL_LINKS)") |
| 399 | +            time.sleep(POLL_INTERVAL_S) |
| 400 | +            resp = w.statement_execution.get_statement(statement_id=resp.statement_id) |
| 401 | +        if resp.status.state != StatementState.SUCCEEDED: |
| 402 | +            msg = resp.status.error.message if resp.status.error else "no details" |
| 403 | +            raise ValueError(f"Statement (EXTERNAL_LINKS) ended in state {resp.status.state}: {msg}") |
| 404 | + |
| 405 | + if not (resp.manifest and resp.manifest.schema and resp.manifest.schema.columns): |
| 406 | + raise ValueError("Schema/columns missing (EXTERNAL_LINKS).") |
| 407 | + cols = [c.name for c in resp.manifest.schema.columns] |
| 408 | + |
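| | +        # Each external link is a short-lived presigned cloud-storage URL (see its |
| | +        # `expiration` field); download it promptly and do not attach the Databricks |
| | +        # bearer token to these requests. |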
| 409 | +        def consume_external_result(result_obj): |
| 410 | +            links = getattr(result_obj, "external_links", None) or [] |
| | +            next_chunk = getattr(result_obj, "next_chunk_index", None) |
| 411 | +            for link in links: |
| 412 | +                r = requests.get(link.external_link, timeout=120) |
| 413 | +                r.raise_for_status() |
| 414 | +                for row in r.json(): |
| 415 | +                    records.append(dict(zip(cols, row))) |
| | +                # Prefer the pagination pointer carried by the link itself, if present. |
| 416 | +                if getattr(link, "next_chunk_index", None) is not None: |
| | +                    next_chunk = link.next_chunk_index |
| 417 | +            return next_chunk |
| 418 | + |
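| | +        # Discard any rows collected during the INLINE attempt and rebuild everything |
| | +        # from the external result files. |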
| 419 | + records.clear() |
| 420 | + next_idx = consume_external_result(resp.result) |
396 | 421 |
|
397 | 422 | while next_idx is not None: |
398 | 423 | chunk = w.statement_execution.get_statement_result_chunk_n( |
399 | | - statement_id=stmt_id, |
400 | | - chunk_index=next_idx, |
401 | | - ) |
402 | | - if getattr(chunk, "data_array", None): |
403 | | - rows.extend(chunk.data_array) |
404 | | - |
405 | | - if getattr(chunk, "truncated", False): |
406 | | - LOGGER.warning("A result chunk was marked truncated by the server.") |
407 | | - |
408 | | - next_idx = getattr(chunk, "next_chunk_index", None) |
409 | | - |
410 | | - print("Fetched %d rows from table: %s", len(rows), fully_qualified_table) |
411 | | - LOGGER.info("Fetched %d rows from table: %s", len(rows), fully_qualified_table) |
412 | | - |
413 | | - # Build list of dicts |
414 | | - records: List[Dict[str, Any]] = [dict(zip(column_names, r)) for r in rows] |
415 | | - |
416 | | - try: |
417 | | - encoded = json.dumps( |
418 | | - records, ensure_ascii=False, separators=(",", ":") |
419 | | - ).encode("utf-8") |
420 | | - except Exception as e: |
421 | | - LOGGER.exception("Failed to serialize records to JSON.") |
422 | | - raise ValueError(f"Failed to serialize records to JSON: {e}") |
423 | | - |
424 | | - payload_bytes = len(encoded) |
425 | | - print( |
426 | | - "Final JSON payload size: %.2f MiB (%d bytes)", |
427 | | - payload_bytes / (1024 * 1024), |
428 | | - payload_bytes, |
429 | | - ) |
430 | | - LOGGER.info( |
431 | | - "Final JSON payload size: %.2f MiB (%d bytes)", |
432 | | - payload_bytes / (1024 * 1024), |
433 | | - payload_bytes, |
434 | | - ) |
435 | | - |
436 | | - max_json_size = 25 * 1024 * 1024 |
437 | | - if payload_bytes > max_json_size: |
438 | | - raise ValueError( |
439 | | - f"Result exceeds maximum allowed JSON payload of {max_json_size} bytes " |
440 | | - f"({max_json_size / (1024 * 1024):.2f} MiB). Got {payload_bytes} bytes." |
| 424 | + statement_id=resp.statement_id, chunk_index=next_idx |
441 | 425 | ) |
| 426 | + next_idx = consume_external_result(chunk) |
442 | 427 |
|
443 | 428 | return records |
444 | 429 |
|
|