diff --git a/knowledge_base/workflow_vector_search_ingestion_for_rag/.gitignore b/knowledge_base/workflow_vector_search_ingestion_for_rag/.gitignore new file mode 100644 index 0000000..0707725 --- /dev/null +++ b/knowledge_base/workflow_vector_search_ingestion_for_rag/.gitignore @@ -0,0 +1,28 @@ +# Databricks +.databricks/ + +# Python +build/ +dist/ +__pycache__/ +*.egg-info +.venv/ +*.py[cod] + +# Local configuration (keep your settings private) +databricks.local.yml + +# IDE +.idea/ +.vscode/ +.DS_Store + +# Scratch/temporary files +scratch/** +!scratch/README.md + +# Test documents (don't commit large PDFs) +*.pdf +*.png +*.jpg +*.jpeg diff --git a/knowledge_base/workflow_vector_search_ingestion_for_rag/README.md b/knowledge_base/workflow_vector_search_ingestion_for_rag/README.md new file mode 100644 index 0000000..4731162 --- /dev/null +++ b/knowledge_base/workflow_vector_search_ingestion_for_rag/README.md @@ -0,0 +1,163 @@ +# AI Document Processing Workflow for RAG with Structured Streaming + +A Databricks Asset Bundle demonstrating **incremental document processing** using `ai_parse_document`, chunking, and Databricks Workflows with Structured Streaming. + +## Overview + +This example shows how to build an incremental workflow that: +1. **Parses** PDFs and images using [`ai_parse_document`](https://docs.databricks.com/aws/en/sql/language-manual/functions/ai_parse_document) +2. **Extracts** clean text with incremental processing +3. **Chunks** text into smaller pieces suitable for embedding +4. **Indexes** chunks using Databricks Vector Search for RAG applications + +All stages run as Python notebook tasks in a Databricks Workflow using Structured Streaming with serverless compute. + +## Architecture + +``` +Source Documents (UC Volume) + ↓ + Task 1: ai_parse_document → parsed_documents_raw (variant) + ↓ + Task 2: text extraction → parsed_documents_text (string) + ↓ + Task 3: text chunking → parsed_documents_text_chunked (string) + ↓ + Task 4: vector search index → Vector Search Index +``` + +### Key Features + +- **Incremental processing**: Only new files are processed using Structured Streaming checkpoints +- **Serverless compute**: Runs on serverless compute for cost efficiency +- **Task dependencies**: Sequential execution with automatic dependency management +- **Parameterized**: Catalog, schema, volumes, and table names configurable via variables +- **Error handling**: Gracefully handles parsing failures +- **Token-aware chunking**: Smart text splitting based on embedding model tokenization +- **Change Data Feed**: Enables efficient incremental updates through the pipeline + +## Prerequisites + +- Databricks workspace with Unity Catalog +- Databricks CLI v0.218.0+ +- Unity Catalog volumes for: + - Source documents (PDFs/images) + - Parsed output images + - Streaming checkpoints + - Chunking cache +- AI functions (`ai_parse_document`) +- Embedding model endpoint (e.g., `databricks-gte-large-en`) +- Vector Search endpoint (or it will be created automatically) + +## Quick Start + +1. **Install and authenticate** + ```bash + databricks auth login --host https://your-workspace.cloud.databricks.com + ``` + +2. **Configure** `databricks.yml` with your workspace settings + +3. **Validate** the bundle configuration + ```bash + databricks bundle validate + ``` + +4. **Deploy** + ```bash + databricks bundle deploy + ``` + +5. **Upload documents** to your source volume + +6. 
**Run the workflow** from the Databricks UI (Workflows) or with `databricks bundle run ai_parse_document_workflow`
+
+## Configuration
+
+Edit `databricks.yml` (defaults shown; the names below match the variables defined there):
+
+```yaml
+variables:
+  catalog: main                                               # Your catalog
+  schema: rag                                                 # Your schema
+  source_volume_path: /Volumes/main/rag/documents             # Source PDFs/images
+  output_volume_path: /Volumes/main/rag/temp/parsed_output    # Parsed page images
+  checkpoint_base_path: /Volumes/main/rag/temp/checkpoints/ai_parse_workflow  # Checkpoints (the chunking cache is derived from this path)
+  raw_table_name: parsed_documents_raw                        # Table names
+  text_table_name: parsed_documents_text
+  chunk_table_name: parsed_documents_text_chunked
+  chunk_size_tokens: 1024                                     # Chunking parameters
+  chunk_overlap_tokens: 256
+  embedding_model_endpoint: databricks-gte-large-en           # Embedding model
+  vector_search_endpoint_name: vector-search-shared-endpoint  # Vector Search endpoint
+  vector_search_index_name: parsed_documents_vector_index
+  vector_search_primary_key: chunk_id
+  vector_search_embedding_source_column: chunked_text
+```
+
+## Workflow Tasks
+
+### Task 1: Document Parsing
+**File**: `src/transformations/01_parse_documents.py`
+
+Uses `ai_parse_document` to extract text, tables, and metadata from PDFs/images:
+- Reads files from the source volume using Structured Streaming
+- Stores variant output with bounding boxes
+- Incremental: checkpointed streaming prevents reprocessing
+- Supports PDF, JPG, JPEG, and PNG files
+
+### Task 2: Text Extraction
+**File**: `src/transformations/02_extract_text.py`
+
+Extracts clean, concatenated text from the raw parser output:
+- Reads from the previous task's table via streaming
+- Handles both parser v1.0 and v2.0 output formats
+- Uses `transform()` for efficient text extraction
+- Includes error handling for failed parses
+- Enables Change Data Feed (CDF) on the output table
+
+### Task 3: Text Chunking
+**File**: `src/transformations/03_chunk_text.py`
+
+Chunks text into smaller pieces suitable for embedding:
+- Reads from the text table via streaming
+- Uses LangChain's `RecursiveCharacterTextSplitter`
+- Token-aware chunking based on the embedding model's tokenizer
+- Supports multiple embedding models (GTE, BGE, OpenAI)
+- Configurable chunk size and overlap
+- Creates an MD5 hash of each chunk as `chunk_id` for deduplication
+- Enables Change Data Feed (CDF) on the output table
+
+### Task 4: Vector Search Index
+**File**: `src/transformations/04_vector_search_index.py`
+
+Creates and manages the Databricks Vector Search index:
+- Creates or validates the Vector Search endpoint
+- Creates a Delta Sync index from the chunked text table
+- Automatically generates embeddings using the specified model
+- Triggers an index sync to process new/updated chunks
+- Uses the triggered pipeline type for on-demand updates
+- Robust error handling and status checking
+
+## Project Structure
+
+```
+.
+├── databricks.yml # Bundle configuration +├── requirements.txt # Python dependencies +├── resources/ +│ └── vector_search_ingestion.job.yml # Workflow definition +├── src/ +│ └── transformations/ +│ ├── 01_parse_documents.py # Parse PDFs/images with ai_parse_document +│ ├── 02_extract_text.py # Extract text from parsed documents +│ ├── 03_chunk_text.py # Chunk text for embedding +│ └── 04_vector_search_index.py # Create Vector Search index +└── README.md +``` + +## Resources + +- [Databricks Asset Bundles](https://docs.databricks.com/dev-tools/bundles/) +- [Databricks Workflows](https://docs.databricks.com/workflows/) +- [Structured Streaming](https://docs.databricks.com/structured-streaming/) +- [`ai_parse_document` Function](https://docs.databricks.com/aws/en/sql/language-manual/functions/ai_parse_document) +- [Databricks Vector Search](https://docs.databricks.com/generative-ai/vector-search.html) +- [LangChain Text Splitters](https://python.langchain.com/docs/modules/data_connection/document_transformers/) diff --git a/knowledge_base/workflow_vector_search_ingestion_for_rag/databricks.yml b/knowledge_base/workflow_vector_search_ingestion_for_rag/databricks.yml new file mode 100644 index 0000000..28d0dbe --- /dev/null +++ b/knowledge_base/workflow_vector_search_ingestion_for_rag/databricks.yml @@ -0,0 +1,73 @@ +# This is a Databricks asset bundle definition for ai_parse_document_workflow. +# See https://docs.databricks.com/dev-tools/bundles/index.html for documentation. +bundle: + name: ai_parse_document_workflow + +variables: + catalog: + description: The catalog name for the workflow + default: main + schema: + description: The schema name for the workflow + default: rag + source_volume_path: + description: Source volume path for PDF files + default: /Volumes/main/rag/documents + output_volume_path: + description: Output volume path for processed images + default: /Volumes/main/rag/temp/parsed_output + checkpoint_base_path: + description: Base path for Structured Streaming checkpoints + default: /Volumes/main/rag/temp/checkpoints/ai_parse_workflow + raw_table_name: + description: Table name for raw parsed documents + default: parsed_documents_raw + text_table_name: + description: Table name for extracted text + default: parsed_documents_text + chunk_table_name: + description: Table name for chunked text + default: parsed_documents_text_chunked + embedding_model_endpoint: + description: Embedding model endpoint for chunking and vector search + default: databricks-gte-large-en + chunk_size_tokens: + description: Chunk size in tokens + default: 1024 + chunk_overlap_tokens: + description: Chunk overlap in tokens + default: 256 + vector_search_endpoint_name: + description: Vector Search endpoint name + default: vector-search-shared-endpoint + vector_search_index_name: + description: Vector Search index name + default: parsed_documents_vector_index + vector_search_primary_key: + description: Primary key column for vector search + default: chunk_id + vector_search_embedding_source_column: + description: Text column for embeddings in vector search + default: chunked_text + +include: + - resources/*.yml + +targets: + dev: + # The default target uses 'mode: development' to create a development copy. + # - Deployed resources get prefixed with '[dev my_user_name]' + # - Any job schedules and triggers are paused by default. + # See also https://docs.databricks.com/dev-tools/bundles/deployment-modes.html. 
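+    # Variables can also be overridden at deploy time without editing this file, e.g.:
+    #   databricks bundle deploy --var="catalog=main" --var="schema=rag"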
+ mode: development + default: true + workspace: + host: https://e2-demo-field-eng.cloud.databricks.com/ + + prod: + mode: production + workspace: + host: https://e2-demo-field-eng.cloud.databricks.com/ + permissions: + - group_name: users + level: CAN_VIEW diff --git a/knowledge_base/workflow_vector_search_ingestion_for_rag/requirements.txt b/knowledge_base/workflow_vector_search_ingestion_for_rag/requirements.txt new file mode 100644 index 0000000..343481f --- /dev/null +++ b/knowledge_base/workflow_vector_search_ingestion_for_rag/requirements.txt @@ -0,0 +1,4 @@ +databricks-vectorsearch +transformers +langchain-text-splitters +tiktoken \ No newline at end of file diff --git a/knowledge_base/workflow_vector_search_ingestion_for_rag/resources/vector_search_ingestion.job.yml b/knowledge_base/workflow_vector_search_ingestion_for_rag/resources/vector_search_ingestion.job.yml new file mode 100644 index 0000000..3ba1d0c --- /dev/null +++ b/knowledge_base/workflow_vector_search_ingestion_for_rag/resources/vector_search_ingestion.job.yml @@ -0,0 +1,76 @@ +resources: + jobs: + ai_parse_document_workflow: + name: ai_parse_document_workflow + + # Optional: Add a schedule + # schedule: + # quartz_cron_expression: "0 0 * * * ?" + # timezone_id: "UTC" + + # Job-level parameters shared across all tasks + parameters: + - name: catalog + default: ${var.catalog} + - name: schema + default: ${var.schema} + + environments: + - environment_key: serverless_env + spec: + client: "3" + dependencies: + - "databricks-vectorsearch" + + tasks: + - task_key: parse_documents + environment_key: serverless_env + notebook_task: + notebook_path: ../src/transformations/01_parse_documents.py + base_parameters: + source_volume_path: ${var.source_volume_path} + output_volume_path: ${var.output_volume_path} + checkpoint_location: ${var.checkpoint_base_path}/01_parse_documents + table_name: ${var.raw_table_name} + + - task_key: extract_text + depends_on: + - task_key: parse_documents + environment_key: serverless_env + notebook_task: + notebook_path: ../src/transformations/02_extract_text.py + base_parameters: + checkpoint_location: ${var.checkpoint_base_path}/02_extract_text + source_table_name: ${var.raw_table_name} + table_name: ${var.text_table_name} + + - task_key: chunk_text + depends_on: + - task_key: extract_text + environment_key: serverless_env + notebook_task: + notebook_path: ../src/transformations/03_chunk_text.py + base_parameters: + checkpoint_location: ${var.checkpoint_base_path}/03_chunk_text + chunking_cache_location: ${var.checkpoint_base_path}/03_chunk_text_cache + source_table_name: ${var.text_table_name} + table_name: ${var.chunk_table_name} + embedding_model_endpoint: ${var.embedding_model_endpoint} + chunk_size_tokens: ${var.chunk_size_tokens} + chunk_overlap_tokens: ${var.chunk_overlap_tokens} + + - task_key: create_vector_search_index + depends_on: + - task_key: chunk_text + environment_key: serverless_env + notebook_task: + notebook_path: ../src/transformations/04_vector_search_index.py + base_parameters: + source_table_name: ${var.chunk_table_name} + endpoint_name: ${var.vector_search_endpoint_name} + index_name: ${var.vector_search_index_name} + primary_key: ${var.vector_search_primary_key} + embedding_source_column: ${var.vector_search_embedding_source_column} + embedding_model_endpoint: ${var.embedding_model_endpoint} + + max_concurrent_runs: 1 diff --git a/knowledge_base/workflow_vector_search_ingestion_for_rag/src/explorations/ai_parse_document -- debug output.py 
b/knowledge_base/workflow_vector_search_ingestion_for_rag/src/explorations/ai_parse_document -- debug output.py new file mode 100644 index 0000000..2f39afa --- /dev/null +++ b/knowledge_base/workflow_vector_search_ingestion_for_rag/src/explorations/ai_parse_document -- debug output.py @@ -0,0 +1,782 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # 🔍 AI Parse Document Debug Interface +# MAGIC +# MAGIC Version 1.3 +# MAGIC +# MAGIC Last update: Oct 6, 2025 +# MAGIC +# MAGIC Changelog: +# MAGIC - Simplified widget parameters: `input_file` and `image_output_path` now accept full volume paths +# MAGIC - Removed separate `catalog`, `schema`, `volume` widgets +# MAGIC - `input_file` supports wildcards for processing multiple files (e.g., `/Volumes/catalog/schema/volume/input/*`) +# MAGIC +# MAGIC ## Overview +# MAGIC This notebook provides a **visual debugging interface** for analyzing the output of Databricks' `ai_parse_document` function. It renders parsed documents with interactive bounding box overlays, allowing you to inspect what content was extracted from each region of your documents. +# MAGIC +# MAGIC ## Features +# MAGIC - 📊 **Visual Bounding Boxes**: Color-coded overlays showing the exact regions where text/elements were detected +# MAGIC - 🎯 **Interactive Tooltips**: Hover over any bounding box to see the parsed content from that region +# MAGIC - 📐 **Automatic Scaling**: Large documents are automatically scaled to fit within 1024px width for optimal viewing +# MAGIC - 🎨 **Element Type Visualization**: Different colors for different element types (text, headers, tables, figures, etc.) +# MAGIC +# MAGIC ## Required Parameters +# MAGIC +# MAGIC This interface requires widget parameters to be configured before running: +# MAGIC +# MAGIC ### 1. `input_file` +# MAGIC - **Description**: Full Unity Catalog volume path to the document(s) you want to parse and visualize +# MAGIC - **Examples**: +# MAGIC - Single file: `/Volumes/catalog/schema/volume/input/document.pdf` +# MAGIC - All files in directory: `/Volumes/catalog/schema/volume/input/*` +# MAGIC - Pattern matching: `/Volumes/catalog/schema/volume/input/*.pdf` +# MAGIC - **Requirements**: Read access to the volume containing your PDF/image files +# MAGIC +# MAGIC ### 2. `image_output_path` +# MAGIC - **Description**: Full Unity Catalog volume path where `ai_parse_document` will store the extracted page images +# MAGIC - **Example**: `/Volumes/catalog/schema/volume/output/` +# MAGIC - **Requirements**: Write access required for storing intermediate image outputs +# MAGIC - **Note**: As documented in the [official Databricks documentation](https://docs.databricks.com/aws/en/sql/language-manual/functions/ai_parse_document), this path is used by the parsing function to store page images that are referenced in the output +# MAGIC +# MAGIC ### 3. `page_selection` +# MAGIC - **Description**: Specifies which pages to display in the visualization +# MAGIC - **Supported formats**: +# MAGIC - `"all"` or leave empty: Display all pages +# MAGIC - `"3"`: Display only page 3 (1-indexed) +# MAGIC - `"1-5"`: Display pages 1 through 5 (inclusive, 1-indexed) +# MAGIC - `"1,3,5"`: Display specific pages (1-indexed) +# MAGIC - `"1-3,7,10-12"`: Mixed ranges and individual pages +# MAGIC +# MAGIC ## Usage Instructions +# MAGIC +# MAGIC 1. 
**Clone this notebook** to your workspace: +# MAGIC - Select **"File -> Clone"** button in the top toolbar +# MAGIC - Choose your desired location in your workspace +# MAGIC - This ensures you have a personal copy you can modify and run +# MAGIC +# MAGIC 2. **Prepare your Unity Catalog volumes**: +# MAGIC - Create or identify a volume for your PDF/image files +# MAGIC - Create or identify a volume for output images +# MAGIC - Upload your PDF files to the input location +# MAGIC +# MAGIC 3. **Configure the widget parameters** at the top of this notebook: +# MAGIC - Set `input_file` to the full volume path (file or directory with wildcard) +# MAGIC - Set `image_output_path` to the full volume path for outputs +# MAGIC - Set `page_selection` to control which pages to visualize +# MAGIC +# MAGIC 4. **Run all code cells** which will generate visual debugging results. +# MAGIC +# MAGIC ## What You'll See +# MAGIC +# MAGIC - **Document Summary**: Overview of pages, element counts, and document metadata +# MAGIC - **Color Legend**: Visual guide showing which colors represent which element types +# MAGIC - **Annotated Images**: Each page with overlaid bounding boxes +# MAGIC - Hover over any box to see the extracted content +# MAGIC - Yellow highlight indicates the currently hovered element +# MAGIC - **Parsed Elements List**: Complete list of all extracted elements with their content + +# COMMAND ---------- + +# Exec Parameters + +dbutils.widgets.text("input_file", "/Volumes/main/default/source_documents/sample.pdf") +dbutils.widgets.text("image_output_path", "/Volumes/main/default/parsed_output/") +dbutils.widgets.text("page_selection", "all") + +input_file = dbutils.widgets.get("input_file") +image_output_path = dbutils.widgets.get("image_output_path") +page_selection = dbutils.widgets.get("page_selection") + +# COMMAND ---------- + +# DBTITLE 1,Configuration Parameters +# Path configuration - use widget values as-is + +source_files = input_file + +# Parse page selection string and return list of page indices to display. 
+#
+# Supported formats:
+# - "all" or None: Display all pages
+# - "3": Display specific page (1-indexed)
+# - "1-5": Display page range (inclusive, 1-indexed)
+# - "1,3,5": Display list of specific pages (1-indexed)
+# - "1-3,7,10-12": Mixed ranges and individual pages
+page_selection = f"{page_selection}"
+
+# COMMAND ----------
+
+# DBTITLE 1,Run Document Parse Code (may take some time)
+# SQL statement with ai_parse_document()
+# Note: input_file can be a single file path or a directory path with wildcard
+sql = f'''
+with parsed_documents AS (
+  SELECT
+    path,
+    ai_parse_document(content,
+      map(
+        'version', '2.0',
+        'imageOutputPath', '{image_output_path}',
+        'descriptionElementTypes', '*'
+      )
+    ) as parsed
+  FROM
+    read_files('{source_files}', format => 'binaryFile')
+)
+select * from parsed_documents
+'''
+
+parsed_results = [row.parsed for row in spark.sql(sql).collect()]
+
+# COMMAND ----------
+
+import json
+from typing import Dict, List, Any, Optional, Tuple, Set, Union
+from IPython.display import HTML, display
+import base64
+import os
+from PIL import Image
+import io
+
+class DocumentRenderer:
+    def __init__(self):
+        # Color mapping for different element types
+        self.element_colors = {
+            'section_header': '#FF6B6B',
+            'text': '#4ECDC4',
+            'figure': '#45B7D1',
+            'caption': '#96CEB4',
+            'page_footer': '#FFEAA7',
+            'page_header': '#DDA0DD',
+            'table': '#98D8C8',
+            'list': '#F7DC6F',
+            'default': '#BDC3C7'
+        }
+
+    def _parse_page_selection(self, page_selection: Union[str, None], total_pages: int) -> Set[int]:
+        """Parse page selection string and return set of page indices (0-based).
+
+        Args:
+            page_selection: Selection string or None
+            total_pages: Total number of pages available
+
+        Returns:
+            Set of 0-based page indices to display
+        """
+        # Handle None or "all" - return all pages
+        if page_selection is None or page_selection.lower() == "all":
+            return set(range(total_pages))
+
+        selected_pages = set()
+
+        # Clean the input
+        page_selection = page_selection.strip()
+
+        # Split by commas for multiple selections
+        parts = page_selection.split(',')
+
+        for part in parts:
+            part = part.strip()
+
+            # Check if it's a range (contains hyphen)
+            if '-' in part:
+                try:
+                    # Split range and convert to integers
+                    range_parts = part.split('-')
+                    if len(range_parts) == 2:
+                        start = int(range_parts[0].strip())
+                        end = int(range_parts[1].strip())
+
+                        # Convert from 1-indexed to 0-indexed
+                        start_idx = start - 1
+                        end_idx = end - 1
+
+                        # Add all pages in range (inclusive)
+                        for i in range(start_idx, end_idx + 1):
+                            if 0 <= i < total_pages:
+                                selected_pages.add(i)
+                except ValueError:
+                    print(f"Warning: Invalid range '{part}' in page selection")
+            else:
+                # Single page number
+                try:
+                    page_num = int(part.strip())
+                    # Convert from 1-indexed to 0-indexed
+                    page_idx = page_num - 1
+                    if 0 <= page_idx < total_pages:
+                        selected_pages.add(page_idx)
+                    else:
+                        print(f"Warning: Page {page_num} is out of range (1-{total_pages})")
+                except ValueError:
+                    print(f"Warning: Invalid page number '{part}' in page selection")
+
+        # If no valid pages were selected, default to all pages
+        if not selected_pages:
+            print(f"Warning: No valid pages in selection '{page_selection}'. Showing all pages.")
+            return set(range(total_pages))
+
+        return selected_pages
+
+    def _get_element_color(self, element_type: str) -> str:
+        """Get color for element type."""
+        return self.element_colors.get(element_type.lower(), self.element_colors['default'])
+
+    def _get_image_dimensions(self, image_path: str) -> Optional[Tuple[int, int]]:
+        """Get dimensions of an image file."""
+        try:
+            if os.path.exists(image_path):
+                with Image.open(image_path) as img:
+                    return img.size  # Returns (width, height)
+            return None
+        except Exception as e:
+            print(f"Error getting image dimensions for {image_path}: {e}")
+            return None
+
+    def _load_image_as_base64(self, image_path: str) -> Optional[str]:
+        """Load image from file path and convert to base64."""
+        try:
+            if os.path.exists(image_path):
+                with open(image_path, 'rb') as img_file:
+                    img_data = img_file.read()
+                img_base64 = base64.b64encode(img_data).decode('utf-8')
+                ext = os.path.splitext(image_path)[1].lower()
+                if ext in ['.jpg', '.jpeg']:
+                    return f"data:image/jpeg;base64,{img_base64}"
+                elif ext in ['.png']:
+                    return f"data:image/png;base64,{img_base64}"
+                else:
+                    return f"data:image/jpeg;base64,{img_base64}"
+            return None
+        except Exception as e:
+            print(f"Error loading image {image_path}: {e}")
+            return None
+
+    def _render_element_content(self, element: Dict, for_tooltip: bool = False) -> str:
+        """Render element content with appropriate formatting for both tooltip and element list display.
+
+        Args:
+            element: The element dictionary containing content/description
+            for_tooltip: Whether this is for tooltip display (affects styling and truncation)
+        """
+        element_type = element.get('type', 'unknown')
+        content = element.get('content', '')
+        description = element.get('description', '')
+
+        display_content = ""
+
+        if content:
+            if element_type == 'table':
+                # Render the HTML table with styling
+                table_html = content
+
+                # Apply different styling based on context
+                if for_tooltip:
+                    # Compact styling for tooltips with light theme
+                    # Use full width available for tooltip tables
+                    table_style = '''style="width: 100%; border-collapse: collapse; margin: 5px 0; font-size: 10px;"'''
+                    th_style = 'style="border: 1px solid #ddd; padding: 4px; background: #f8f9fa; color: #333; font-weight: bold; text-align: left; font-size: 10px;"'
+                    td_style = 'style="border: 1px solid #ddd; padding: 4px; color: #333; font-size: 10px;"'
+                    thead_style = 'style="background: #e9ecef;"'
+                else:
+                    # Full styling for element list
+                    table_style = '''style="width: 100%; border-collapse: collapse; margin: 10px 0; font-size: 13px;"'''
+                    th_style = 'style="border: 1px solid #ddd; padding: 8px; background: #f5f5f5; font-weight: bold; text-align: left;"'
+                    td_style = 'style="border: 1px solid #ddd; padding: 8px;"'
+                    thead_style = 'style="background: #f0f0f0;"'
+
+                # Apply styling transformations
+                # (replace '<thead' before '<th' so the prefix match does not corrupt it)
+                if '<table' in table_html:
+                    table_html = table_html.replace('<table', f'<table {table_style}')
+                if '<thead' in table_html:
+                    table_html = table_html.replace('<thead', f'<thead {thead_style}')
+                if '<th ' in table_html or '<th>' in table_html:
+                    table_html = table_html.replace('<th', f'<th {th_style}')
+                if '<td' in table_html:
+                    table_html = table_html.replace('<td', f'<td {td_style}')
+
+                if for_tooltip:
+                    display_content = table_html
+                else:
+                    display_content = f"<div style='overflow-x: auto;'>{table_html}</div>"
+            else:
+                # Regular content handling
+                if for_tooltip and len(content) > 500:
+                    # Truncate for tooltip display and escape HTML for safety
+                    display_content = self._escape_for_html_attribute(content[:500] + "...")
+                else:
+                    display_content = self._escape_for_html_attribute(content) if for_tooltip else content
+        elif description:
+            desc_content = description
+            if for_tooltip and len(desc_content) > 500:
+                desc_content = desc_content[:500] + "..."
+
+            if for_tooltip:
+                display_content = self._escape_for_html_attribute(f"Description: {desc_content}")
+            else:
+                display_content = f"Description: {desc_content}"
+        else:
+            display_content = "No content available" if for_tooltip else "No content"
+
+        return display_content
+
+    def _escape_for_html_attribute(self, text: str) -> str:
+        """Escape text for safe use in HTML attributes."""
+        return (text.replace('&', '&amp;')
+                    .replace('<', '&lt;')
+                    .replace('>', '&gt;')
+                    .replace('"', '&quot;')
+                    .replace("'", '&#39;')
+                    .replace('\n', '<br>'))
+
+    def _calculate_tooltip_width(self, element: Dict, image_width: int) -> int:
+        """Calculate dynamic tooltip width based on table content."""
+        element_type = element.get('type', 'unknown')
+        content = element.get('content', '')
+
+        if element_type == 'table' and content:
+            # Count columns by looking for <th> or <td> tags in first row
+            import re
+
+            # Find first row (either in thead or tbody)
+            first_row_match = re.search(r'<tr[^>]*>(.*?)</tr>', content, re.DOTALL | re.IGNORECASE)
+            if first_row_match:
+                first_row = first_row_match.group(1)
+                # Count th or td tags
+                th_count = len(re.findall(r'<th[^>]*>', first_row, re.IGNORECASE))
+                td_count = len(re.findall(r'<td[^>]*>', first_row, re.IGNORECASE))
+                column_count = max(th_count, td_count)
+
+                if column_count > 0:
+                    # Base width + additional width per column
+                    base_width = 300
+                    width_per_column = 80
+                    calculated_width = base_width + (column_count * width_per_column)
+
+                    # Cap at 4/5th of image width
+                    max_width = int(image_width * 0.8)
+                    return min(calculated_width, max_width)
+
+        # Default width for non-tables or when calculation fails
+        return 400
+
+    def _create_annotated_image(self, page: Dict, elements: List[Dict]) -> str:
+        """Create annotated image with SCALING to fit within 1024px width."""
+        image_uri = page.get('image_uri', '')
+        page_id = page.get('id', 0)
+
+        if not image_uri:
+            return "<div style='color: #c00; padding: 10px;'>No image URI found for this page</div>"
+
+        # Load image
+        img_data_uri = self._load_image_as_base64(image_uri)
+        if not img_data_uri:
+            return f"""
+            <div style='color: #c00; padding: 10px;'>
+                <b>Could not load image:</b> {image_uri}<br>
+                Make sure the file exists and is accessible.
+            </div>
+            """
+
+        # Get original image dimensions
+        original_dimensions = self._get_image_dimensions(image_uri)
+        if not original_dimensions:
+            # Fallback: display without explicit scaling
+            original_width, original_height = 1024, 768  # Default fallback
+        else:
+            original_width, original_height = original_dimensions
+
+        # Calculate scaling factor to fit within 1024px width
+        max_display_width = 1024
+        scale_factor = 1.0
+        display_width = original_width
+        display_height = original_height
+
+        if original_width > max_display_width:
+            scale_factor = max_display_width / original_width
+            display_width = max_display_width
+            display_height = int(original_height * scale_factor)
+
+        # Filter elements for this page and collect their bounding boxes
+        page_elements = []
+
+        for elem in elements:
+            elem_bboxes = []
+            for bbox in elem.get('bbox', []):
+                if bbox.get('page_id', 0) == page_id:
+                    coord = bbox.get('coord', [])
+                    if len(coord) >= 4:
+                        elem_bboxes.append(bbox)
+
+            if elem_bboxes:
+                page_elements.append({
+                    'element': elem,
+                    'bboxes': elem_bboxes
+                })
+
+        if not page_elements:
+            return f"<div style='padding: 10px;'>No elements found for page {page_id}</div>"
+
+        header_info = f"""
+        <div style='margin: 10px 0; font-size: 13px;'>
+            <b>Page {page_id + 1}:</b> {len(page_elements)} elements<br>
+            Original size: {original_width}×{original_height}px |
+            Display size: {display_width}×{display_height}px |
+            Scale factor: {scale_factor:.3f}
+        </div>
+        """
+
+        # Generate unique container ID for this page
+        container_id = f"page_container_{page_id}_{id(self)}"
+
+        # Create bounding box overlays using SCALED coordinates with hover functionality
+        overlays = []
+
+        for idx, item in enumerate(page_elements):
+            element = item['element']
+            element_id = element.get('id', 'N/A')
+            element_type = element.get('type', 'unknown')
+            color = self._get_element_color(element_type)
+
+            # Use the shared content renderer for tooltip
+            tooltip_content = self._render_element_content(element, for_tooltip=True)
+
+            # Calculate dynamic tooltip width
+            tooltip_width = self._calculate_tooltip_width(element, display_width)
+
+            # Tables render as HTML; other content is already escaped
+            for bbox_idx, bbox in enumerate(item['bboxes']):
+                coord = bbox.get('coord', [])
+                if len(coord) >= 4:
+                    x1, y1, x2, y2 = coord
+
+                    # Apply scaling to coordinates
+                    scaled_x1 = x1 * scale_factor
+                    scaled_y1 = y1 * scale_factor
+                    scaled_x2 = x2 * scale_factor
+                    scaled_y2 = y2 * scale_factor
+
+                    width = scaled_x2 - scaled_x1
+                    height = scaled_y2 - scaled_y1
+
+                    # Skip invalid boxes
+                    if width <= 0 or height <= 0:
+                        continue
+
+                    # Position label above box when possible
+                    label_top = -18 if scaled_y1 >= 18 else 2
+
+                    # Unique ID for this bounding box
+                    box_id = f"bbox_{page_id}_{idx}_{bbox_idx}"
+
+                    # Calculate tooltip position (prefer right side, but switch to left if needed)
+                    tooltip_left = 10
+
+                    overlay = f"""
+                    <div id="{box_id}" class="bbox" style="position: absolute; left: {scaled_x1:.0f}px; top: {scaled_y1:.0f}px; width: {width:.0f}px; height: {height:.0f}px; border: 2px solid {color};">
+                        <span style="position: absolute; top: {label_top}px; left: 0; background: {color}; color: #fff; font-size: 10px; padding: 0 3px; white-space: nowrap;">
+                            {element_type.upper()[:6]}#{element_id}
+                        </span>
+                        <div class="bbox-tooltip" style="left: {tooltip_left}px; width: {tooltip_width}px;">
+                            <div style="font-weight: bold; margin-bottom: 4px;">
+                                {element_type.upper()} #{element_id}
+                            </div>
+                            <div>
+                                {tooltip_content}
+                            </div>
+                        </div>
+                    </div>
+                    """
+                    overlays.append(overlay)
+
+        # Pure CSS hover functionality (works in Databricks)
+        styles = f"""
+        <style>
+            #{container_id} .bbox-tooltip {{
+                display: none; position: absolute; top: 100%; z-index: 1000;
+                background: #fff; border: 1px solid #ccc; padding: 8px;
+                font-size: 11px; color: #333; box-shadow: 0 2px 8px rgba(0,0,0,0.3);
+            }}
+            #{container_id} .bbox:hover {{ background: rgba(255, 255, 0, 0.25); }}
+            #{container_id} .bbox:hover .bbox-tooltip {{ display: block; }}
+        </style>
+        """
+
+        return f"""
+        {header_info}
+        {styles}
+        <div id="{container_id}" style="position: relative; width: {display_width}px; height: {display_height}px;">
+            <img src="{img_data_uri}" width="{display_width}" height="{display_height}" alt="Page {page_id + 1}"/>
+            {''.join(overlays)}
+        </div>
+        """
+
+    def _create_page_elements_list(self, page_id: int, elements: List[Dict]) -> str:
+        """Create a detailed list of elements for a specific page."""
+        # Filter elements for this page
+        page_elements = []
+
+        for elem in elements:
+            elem_bboxes = []
+            for bbox in elem.get('bbox', []):
+                if bbox.get('page_id', 0) == page_id:
+                    elem_bboxes.append(bbox)
+
+            if elem_bboxes:
+                page_elements.append(elem)
+
+        if not page_elements:
+            return f"<div style='padding: 10px;'>No elements found for page {page_id + 1}</div>"
+
+        html_parts = []
+
+        for element in page_elements:
+            element_id = element.get('id', 'N/A')
+            element_type = element.get('type', 'unknown')
+            color = self._get_element_color(element_type)
+
+            # Get bounding box info for this page only
+            bbox_info = "No bbox"
+            bbox_list = element.get('bbox', [])
+            if bbox_list:
+                bbox_details = []
+                for bbox in bbox_list:
+                    if bbox.get('page_id', 0) == page_id:
+                        coord = bbox.get('coord', [])
+                        if len(coord) >= 4:
+                            bbox_details.append(f"[{coord[0]:.0f}, {coord[1]:.0f}, {coord[2]:.0f}, {coord[3]:.0f}]")
+                bbox_info = "; ".join(bbox_details) if bbox_details else "Invalid bbox"
+
+            # Use the shared content renderer for element list display
+            display_content = self._render_element_content(element, for_tooltip=False)
+
+            element_html = f"""
+            <div style="border-left: 4px solid {color}; margin: 8px 0; padding: 8px; background: #fafafa;">
+                <div style="display: flex; justify-content: space-between;">
+                    <span>
+                        <span style="display: inline-block; width: 10px; height: 10px; background: {color};"></span>
+                        <b>{element_type.upper().replace('_', ' ')} (ID: {element_id})</b>
+                    </span>
+                    <span style="font-size: 11px; color: #666;">
+                        {bbox_info}
+                    </span>
+                </div>
+                <div style="margin-top: 6px; font-size: 13px;">
+                    {display_content}
+                </div>
+            </div>
+            """
+            html_parts.append(element_html)
+
+        return f"""
+        <div style="margin: 10px 0;">
+            <h4>📋 Page {page_id + 1} Elements ({len(page_elements)} items)</h4>
+            {''.join(html_parts)}
+        </div>
+        """
+
+    def _create_summary(self, document: Dict, metadata: Dict, selected_pages: Set[int], total_pages: int) -> str:
+        """Create a summary with page selection info."""
+        elements = document.get('elements', [])
+
+        # Count elements only on selected pages
+        selected_elements = []
+        for elem in elements:
+            for bbox in elem.get('bbox', []):
+                if bbox.get('page_id', 0) in selected_pages:
+                    selected_elements.append(elem)
+                    break
+
+        # Count by type (for selected pages)
+        type_counts = {}
+        for elem in selected_elements:
+            elem_type = elem.get('type', 'unknown')
+            type_counts[elem_type] = type_counts.get(elem_type, 0) + 1
+
+        type_list = ', '.join([f"{t}: {c}" for t, c in type_counts.items()])
+
+        # Create page selection info
+        if len(selected_pages) == total_pages:
+            page_info = f"All {total_pages} pages"
+        else:
+            # Convert to 1-indexed for display
+            page_nums = sorted([p + 1 for p in selected_pages])
+            if len(page_nums) <= 10:
+                page_info = f"Pages {', '.join(map(str, page_nums))} ({len(selected_pages)} of {total_pages})"
+            else:
+                page_info = f"{len(selected_pages)} of {total_pages} pages selected"
+
+        return f"""
+        <div style="border: 1px solid #ddd; padding: 10px; margin: 10px 0; background: #f9f9f9;">
+            <h3>📄 Document Summary</h3>
+            <div><b>Displaying:</b> {page_info}</div>
+            <div><b>Elements on selected pages:</b> {len(selected_elements)}</div>
+            <div><b>Element Types:</b> {type_list if type_list else 'None'}</div>
+            <div><b>Document ID:</b> {str(metadata.get('id', 'N/A'))[:12]}...</div>
+        </div>
+        """
+
+    def render_document(self, parsed_result: Any, page_selection: Union[str, None] = None) -> None:
+        """Main render function with page selection support.
+
+        Args:
+            parsed_result: The parsed document result
+            page_selection: Page selection string. Supported formats:
+                - "all" or None: Display all pages
+                - "3": Display only page 3 (1-indexed)
+                - "1-5": Display pages 1 through 5 (inclusive)
+                - "1,3,5": Display specific pages
+                - "1-3,7,10-12": Mixed format
+        """
+        try:
+            # Convert to dict
+            if hasattr(parsed_result, 'toPython'):
+                parsed_dict = parsed_result.toPython()
+            elif hasattr(parsed_result, 'toJson'):
+                parsed_dict = json.loads(parsed_result.toJson())
+            elif isinstance(parsed_result, dict):
+                parsed_dict = parsed_result
+            else:
+                display(HTML(f"<div style='color: #c00;'>❌ Could not convert result. Type: {type(parsed_result)}</div>"))
+                return
+
+            # Extract components
+            document = parsed_dict.get('document', {})
+            pages = document.get('pages', [])
+            elements = document.get('elements', [])
+            metadata = parsed_dict.get('metadata', {})
+
+            if not elements:
+                display(HTML("<div style='color: #c00;'>❌ No elements found in document</div>"))
+                return
+
+            # Parse page selection
+            selected_pages = self._parse_page_selection(page_selection, len(pages))
+
+            # Display title
+            display(HTML("<h2>🔍 AI Parse Document Results</h2>"))
+
+            # Display summary with page selection info
+            summary_html = self._create_summary(document, metadata, selected_pages, len(pages))
+            display(HTML(summary_html))
+
+            # Display color legend
+            legend_items = []
+            for elem_type, color in self.element_colors.items():
+                if elem_type != 'default':
+                    legend_items.append(f"""
+                    <span style="margin-right: 12px; white-space: nowrap;">
+                        <span style="display: inline-block; width: 12px; height: 12px; background: {color};"></span>
+                        {elem_type.replace('_', ' ').title()}
+                    </span>
+                    """)
+
+            display(HTML(f"""
+            <div style="margin: 10px 0;">
+                <b>🎨 Element Colors:</b><br>
+                {''.join(legend_items)}
+            </div>
+            """))
+
+            # Display annotated images with their corresponding elements (filtered by selection)
+            if pages:
+                display(HTML("<h3>🖼️ Annotated Images & Elements</h3>"))
+
+                # Sort selected pages for display
+                sorted_selected = sorted(selected_pages)
+
+                for page_idx in sorted_selected:
+                    if page_idx < len(pages):
+                        page = pages[page_idx]
+
+                        # Display the annotated image
+                        annotated_html = self._create_annotated_image(page, elements)
+                        display(HTML(f"<div style='margin: 20px 0;'>{annotated_html}</div>"))
+
+                        # Display elements for this page immediately after the image
+                        page_id = page.get('id', page_idx)
+                        page_elements_html = self._create_page_elements_list(page_id, elements)
+                        display(HTML(page_elements_html))
+
+        except Exception as e:
+            display(HTML(f"<div style='color: #c00;'>❌ Error: {str(e)}</div>"))
+            import traceback
+            display(HTML(f"<pre>{traceback.format_exc()}</pre>"))
+
+
+# Simple usage functions
+def render_ai_parse_output(parsed_result, page_selection=None):
+    """Simple function to render ai_parse_document output with page selection.
+
+    Args:
+        parsed_result: The parsed document result
+        page_selection: Optional page selection string. Examples:
+            - None or "all": Display all pages
+            - "3": Display only page 3
+            - "1-5": Display pages 1 through 5
+            - "1,3,5": Display specific pages
+            - "1-3,7,10-12": Mixed format
+    """
+    renderer = DocumentRenderer()
+    renderer.render_document(parsed_result, page_selection)
+
+# COMMAND ----------
+
+# DBTITLE 1,Debug Visualization Results
+for parsed_result in parsed_results:
+    render_ai_parse_output(parsed_result, page_selection)
\ No newline at end of file
diff --git a/knowledge_base/workflow_vector_search_ingestion_for_rag/src/transformations/01_parse_documents.py b/knowledge_base/workflow_vector_search_ingestion_for_rag/src/transformations/01_parse_documents.py
new file mode 100644
index 0000000..39b1c80
--- /dev/null
+++ b/knowledge_base/workflow_vector_search_ingestion_for_rag/src/transformations/01_parse_documents.py
@@ -0,0 +1,87 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC # Parse Documents using ai_parse_document
+# MAGIC
+# MAGIC This notebook uses Structured Streaming to incrementally parse PDFs and images using the ai_parse_document function.
+
+# COMMAND ----------
+
+# Get parameters
+dbutils.widgets.text("catalog", "", "Catalog name")
+dbutils.widgets.text("schema", "", "Schema name")
+dbutils.widgets.text("source_volume_path", "", "Source volume path")
+dbutils.widgets.text("output_volume_path", "", "Output volume path")
+dbutils.widgets.text("checkpoint_location", "", "Checkpoint location")
+dbutils.widgets.text("table_name", "parsed_documents_raw", "Output table name")
+
+catalog = dbutils.widgets.get("catalog")
+schema = dbutils.widgets.get("schema")
+source_volume_path = dbutils.widgets.get("source_volume_path")
+output_volume_path = dbutils.widgets.get("output_volume_path")
+checkpoint_location = dbutils.widgets.get("checkpoint_location")
+table_name = dbutils.widgets.get("table_name")
+
+# COMMAND ----------
+
+print(source_volume_path)
+print(output_volume_path)
+print(checkpoint_location)
+print(table_name)
+
+# COMMAND ----------
+
+# Set catalog and schema
+spark.sql(f"USE CATALOG {catalog}")
+spark.sql(f"USE SCHEMA {schema}")
+
+# COMMAND ----------
+
+from pyspark.sql.functions import col, current_timestamp, expr
+from pyspark.sql.types import StructType, StructField, StringType, BinaryType, TimestampType, LongType
+
+# Define schema for binary files (must match exact schema expected by binaryFile format)
+binary_file_schema = StructType([
+    StructField("path", StringType(), False),
+    StructField("modificationTime", TimestampType(), False),
+    StructField("length", LongType(), False),
+    StructField("content", BinaryType(), True)
+])
+
+# Read files using Structured Streaming
+files_df = (spark.readStream
+    .format("binaryFile")
+    .schema(binary_file_schema)
+    .option("pathGlobFilter", "*.{pdf,jpg,jpeg,png}")
+    .option("recursiveFileLookup", "true")
+    .load(source_volume_path)
+)
+
+# Parse documents with ai_parse_document
+parsed_df = (files_df
+    .repartition(8, expr("crc32(path) % 8"))
+    .withColumn("parsed",
+        expr(f"""
+            ai_parse_document(
+                content,
+                map(
+                    'version', '2.0',
+                    'imageOutputPath', '{output_volume_path}',
+                    'descriptionElementTypes', '*'
+                )
+            )
+        """)
+    )
+    .withColumn("parsed_at", current_timestamp())
+    .select("path", "parsed", "parsed_at")
+)
+
+# Write to Delta table with streaming
+query = (parsed_df.writeStream
+    .format("delta")
+    .outputMode("append")
+    .option("checkpointLocation", checkpoint_location)
+    .option("delta.feature.variantType-preview", "supported")
+    .option("mergeSchema", "true")
+    .trigger(availableNow=True)
+    .toTable(table_name)
+)
+
+# Block until the availableNow batch completes so downstream tasks see all data
+query.awaitTermination()
diff --git a/knowledge_base/workflow_vector_search_ingestion_for_rag/src/transformations/02_extract_text.py b/knowledge_base/workflow_vector_search_ingestion_for_rag/src/transformations/02_extract_text.py
new file mode 100644
index 0000000..ad5c419
--- /dev/null
+++ b/knowledge_base/workflow_vector_search_ingestion_for_rag/src/transformations/02_extract_text.py
@@ -0,0 +1,77 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC # Extract Text from Parsed Documents
+# MAGIC
+# MAGIC This notebook uses Structured Streaming to extract clean text from parsed documents.
+
+# COMMAND ----------
+
+# Get parameters
+dbutils.widgets.text("catalog", "", "Catalog name")
+dbutils.widgets.text("schema", "", "Schema name")
+dbutils.widgets.text("checkpoint_location", "", "Checkpoint location")
+dbutils.widgets.text("source_table_name", "parsed_documents_raw", "Source table name")
+dbutils.widgets.text("table_name", "parsed_documents_text", "Output table name")
+
+catalog = dbutils.widgets.get("catalog")
+schema = dbutils.widgets.get("schema")
+checkpoint_location = dbutils.widgets.get("checkpoint_location")
+source_table_name = dbutils.widgets.get("source_table_name")
+table_name = dbutils.widgets.get("table_name")
+
+# COMMAND ----------
+
+# Set catalog and schema
+spark.sql(f"USE CATALOG {catalog}")
+spark.sql(f"USE SCHEMA {schema}")
+
+# COMMAND ----------
+
+from pyspark.sql.functions import col, concat_ws, expr, lit, when
+
+# Read from source table using Structured Streaming
+parsed_stream = (spark.readStream
+    .format("delta")
+    .table(source_table_name)
+)
+
+# Extract text from parsed documents
+text_df = parsed_stream.withColumn(
+    "text",
+    when(
+        expr("try_cast(parsed:error_status AS STRING)").isNotNull(),
+        lit(None)
+    ).otherwise(
+        concat_ws(
+            "\n\n",
+            expr("""
+                transform(
+                    CASE
+                        WHEN try_cast(parsed:metadata:version AS STRING) = '1.0'
+                        THEN try_cast(parsed:document:pages AS ARRAY<VARIANT>)
+                        ELSE try_cast(parsed:document:elements AS ARRAY<VARIANT>)
+                    END,
+                    element -> try_cast(element:content AS STRING)
+                )
+            """)
+        )
+    )
+).withColumn(
+    "error_status",
+    expr("try_cast(parsed:error_status AS STRING)")
+).select("path", "text", "error_status", "parsed_at")
+
+# Write to Delta table with streaming
+query = (text_df.writeStream
+    .format("delta")
+    .outputMode("append")
+    .option("checkpointLocation", checkpoint_location)
+    .option("mergeSchema", "true")
+    .trigger(availableNow=True)
+    .toTable(table_name)
+)
+
+# Block until the availableNow batch completes so the ALTER TABLE below
+# (and the downstream chunking task) see a fully written table
+query.awaitTermination()
+
+# COMMAND ----------
+
+# Enable CDF on the existing table
+spark.sql(f"ALTER TABLE {catalog}.{schema}.{table_name} SET TBLPROPERTIES (delta.enableChangeDataFeed = true)")
\ No newline at end of file
diff --git a/knowledge_base/workflow_vector_search_ingestion_for_rag/src/transformations/03_chunk_text.py b/knowledge_base/workflow_vector_search_ingestion_for_rag/src/transformations/03_chunk_text.py
new file mode 100644
index 0000000..676b6fb
--- /dev/null
+++ b/knowledge_base/workflow_vector_search_ingestion_for_rag/src/transformations/03_chunk_text.py
@@ -0,0 +1,284 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC # Chunk Text from Parsed Documents
+# MAGIC
+# MAGIC This notebook uses Structured Streaming to chunk text from parsed documents into smaller pieces suitable for embedding.
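+# MAGIC
+# MAGIC As a rough sketch of the idea (standalone, with assumed values; the notebook below derives the tokenizer and sizes from its widget parameters), token-aware chunking looks like this:
+# MAGIC
+# MAGIC ```python
+# MAGIC from transformers import AutoTokenizer
+# MAGIC from langchain_text_splitters import RecursiveCharacterTextSplitter
+# MAGIC
+# MAGIC tokenizer = AutoTokenizer.from_pretrained("Alibaba-NLP/gte-large-en-v1.5")
+# MAGIC splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
+# MAGIC     tokenizer, chunk_size=1024, chunk_overlap=256  # sizes are in tokens, not characters
+# MAGIC )
+# MAGIC some_long_text = "Lorem ipsum dolor sit amet. " * 2000
+# MAGIC chunks = splitter.split_text(some_long_text)
+# MAGIC ```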
+ +# COMMAND ---------- + +# MAGIC %pip install -r ../../requirements.txt -q +# MAGIC %restart_python + +# COMMAND ---------- + +# Get parameters +dbutils.widgets.text("catalog", "", "Catalog name") +dbutils.widgets.text("schema", "", "Schema name") +dbutils.widgets.text("checkpoint_location", "", "Checkpoint location") +dbutils.widgets.text("chunking_cache_location", "", "Chunking cache location") +dbutils.widgets.text("source_table_name", "parsed_documents_text", "Source table name") +dbutils.widgets.text("table_name", "parsed_documents_text_chunked", "Output table name") +dbutils.widgets.text("embedding_model_endpoint", "databricks-gte-large-en", "Embedding model endpoint") +dbutils.widgets.text("chunk_size_tokens", "1024", "Chunk size in tokens") +dbutils.widgets.text("chunk_overlap_tokens", "256", "Chunk overlap in tokens") + +catalog = dbutils.widgets.get("catalog") +schema = dbutils.widgets.get("schema") +checkpoint_location = dbutils.widgets.get("checkpoint_location") +chunking_cache_location = dbutils.widgets.get("chunking_cache_location") +source_table_name = dbutils.widgets.get("source_table_name") +table_name = dbutils.widgets.get("table_name") +embedding_model_endpoint = dbutils.widgets.get("embedding_model_endpoint") +chunk_size_tokens = int(dbutils.widgets.get("chunk_size_tokens")) +chunk_overlap_tokens = int(dbutils.widgets.get("chunk_overlap_tokens")) + +# COMMAND ---------- + +# Set catalog and schema +spark.sql(f"USE CATALOG {catalog}") +spark.sql(f"USE SCHEMA {schema}") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Define chunking utilities + +# COMMAND ---------- + +from typing import Optional +import os + +# Set the TRANSFORMERS_CACHE environment variable to a writable directory +os.environ['TRANSFORMERS_CACHE'] = chunking_cache_location +HF_CACHE_DIR = chunking_cache_location + +# Embedding Models Configuration +EMBEDDING_MODELS = { + "gte-large-en-v1.5": { + "context_window": 8192, + "type": "SENTENCE_TRANSFORMER", + }, + "bge-large-en-v1.5": { + "context_window": 512, + "type": "SENTENCE_TRANSFORMER", + }, + "bge_large_en_v1_5": { + "context_window": 512, + "type": "SENTENCE_TRANSFORMER", + }, + "text-embedding-ada-002": { + "context_window": 8192, + "type": "OPENAI", + }, + "text-embedding-3-small": { + "context_window": 8192, + "type": "OPENAI", + }, + "text-embedding-3-large": { + "context_window": 8192, + "type": "OPENAI", + }, +} + +def get_embedding_model_tokenizer(endpoint_type: str) -> Optional[dict]: + from transformers import AutoTokenizer + import tiktoken + + EMBEDDING_MODELS_W_TOKENIZER = { + "gte-large-en-v1.5": { + "tokenizer": lambda: AutoTokenizer.from_pretrained( + "Alibaba-NLP/gte-large-en-v1.5", cache_dir=HF_CACHE_DIR + ), + "context_window": 8192, + "type": "SENTENCE_TRANSFORMER", + }, + "bge-large-en-v1.5": { + "tokenizer": lambda: AutoTokenizer.from_pretrained( + "BAAI/bge-large-en-v1.5", cache_dir=HF_CACHE_DIR + ), + "context_window": 512, + "type": "SENTENCE_TRANSFORMER", + }, + "bge_large_en_v1_5": { + "tokenizer": lambda: AutoTokenizer.from_pretrained( + "BAAI/bge-large-en-v1.5", cache_dir=HF_CACHE_DIR + ), + "context_window": 512, + "type": "SENTENCE_TRANSFORMER", + }, + "text-embedding-ada-002": { + "context_window": 8192, + "tokenizer": lambda: tiktoken.encoding_for_model("text-embedding-ada-002"), + "type": "OPENAI", + }, + "text-embedding-3-small": { + "context_window": 8192, + "tokenizer": lambda: tiktoken.encoding_for_model("text-embedding-3-small"), + "type": "OPENAI", + }, + "text-embedding-3-large": { + "context_window": 
8192,
+            "tokenizer": lambda: tiktoken.encoding_for_model("text-embedding-3-large"),
+            "type": "OPENAI",
+        },
+    }
+    model_config = EMBEDDING_MODELS_W_TOKENIZER.get(endpoint_type)
+    return model_config.get("tokenizer") if model_config else None
+
+def get_embedding_model_config(endpoint_type: str) -> Optional[dict]:
+    return EMBEDDING_MODELS.get(endpoint_type)
+
+def extract_endpoint_type(llm_endpoint) -> Optional[str]:
+    try:
+        return llm_endpoint.config.served_entities[0].external_model.name
+    except AttributeError:
+        try:
+            return llm_endpoint.config.served_entities[0].foundation_model.name
+        except AttributeError:
+            return None
+
+def detect_fmapi_embedding_model_type(model_serving_endpoint: str):
+    from databricks.sdk import WorkspaceClient
+
+    client = WorkspaceClient()
+    try:
+        llm_endpoint = client.serving_endpoints.get(name=model_serving_endpoint)
+        endpoint_type = extract_endpoint_type(llm_endpoint)
+    except Exception:
+        endpoint_type = None
+
+    embedding_config = (
+        get_embedding_model_config(endpoint_type) if endpoint_type else None
+    )
+
+    if embedding_config:
+        # Copy so we don't mutate the module-level configuration
+        embedding_config = dict(embedding_config)
+        embedding_config["tokenizer"] = (
+            get_embedding_model_tokenizer(endpoint_type) if endpoint_type else None
+        )
+
+    return (endpoint_type, embedding_config)
+
+def get_recursive_character_text_splitter(
+    model_serving_endpoint: str,
+    embedding_model_name: str = None,
+    chunk_size_tokens: int = None,
+    chunk_overlap_tokens: int = 0,
+):
+    from langchain_text_splitters import RecursiveCharacterTextSplitter
+    from transformers import AutoTokenizer
+    import tiktoken
+
+    # Detect the embedding model behind the serving endpoint and its configuration
+    detected_model_name, chunk_spec = detect_fmapi_embedding_model_type(
+        model_serving_endpoint
+    )
+
+    if chunk_spec is None or detected_model_name is None:
+        # Fall back to the caller-provided embedding_model_name
+        chunk_spec = EMBEDDING_MODELS.get(embedding_model_name)
+        if chunk_spec is None:
+            raise ValueError(
+                f"Embedding model `{embedding_model_name}` not found. Available models: {EMBEDDING_MODELS.keys()}"
+            )
+        chunk_spec = dict(chunk_spec)
+        chunk_spec["tokenizer"] = get_embedding_model_tokenizer(embedding_model_name)
+
+    # Update chunk specification based on provided parameters
+    chunk_spec["chunk_size_tokens"] = (
+        chunk_size_tokens or chunk_spec["context_window"]
+    )
+    chunk_spec["chunk_overlap_tokens"] = chunk_overlap_tokens
+
+    def _recursive_character_text_splitter(text: str):
+        if text is None or text == "":
+            return []
+
+        tokenizer = chunk_spec["tokenizer"]()
+        if chunk_spec["type"] == "SENTENCE_TRANSFORMER":
+            splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
+                tokenizer,
+                chunk_size=chunk_spec["chunk_size_tokens"],
+                chunk_overlap=chunk_spec["chunk_overlap_tokens"],
+            )
+        elif chunk_spec["type"] == "OPENAI":
+            splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+                tokenizer.name,
+                chunk_size=chunk_spec["chunk_size_tokens"],
+                chunk_overlap=chunk_spec["chunk_overlap_tokens"],
+            )
+        else:
+            raise ValueError(f"Unsupported model type: {chunk_spec['type']}")
+        return splitter.split_text(text)
+
+    return _recursive_character_text_splitter
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Create and apply chunking function
+
+# COMMAND ----------
+
+from pyspark.sql.functions import col, explode, udf, md5
+from pyspark.sql.types import ArrayType, StringType
+
+# Create the chunking function
+print(f"Creating chunking function for embedding model: {embedding_model_endpoint}")
+print(f"Chunk size: {chunk_size_tokens} tokens, Overlap: {chunk_overlap_tokens} tokens")
+
+recursive_character_text_splitter_fn = get_recursive_character_text_splitter(
+    model_serving_endpoint=embedding_model_endpoint,
+    chunk_size_tokens=chunk_size_tokens,
+    chunk_overlap_tokens=chunk_overlap_tokens,
+)
+
+# Create a UDF from the chunking function
+chunking_udf = udf(
+    recursive_character_text_splitter_fn,
+    returnType=ArrayType(StringType())
+)
+
+# Read from source table using Structured Streaming
+parsed_stream = (spark.readStream
+    .format("delta")
+    .table(source_table_name)
+)
+
+# Apply chunking: only process rows without errors and with non-null text
+chunked_array_df = (parsed_stream
+    .filter(col("error_status").isNull())
+    .filter(col("text").isNotNull())
+    .withColumn("text_chunked_array", chunking_udf(col("text")))
+)
+
+# Explode the array of chunks into separate rows
+chunked_df = (chunked_array_df
+    .select(
+        col("path"),
+        col("parsed_at"),
+        explode("text_chunked_array").alias("chunked_text")
+    )
+)
+
+# Add chunk_id as MD5 hash of the chunked text
+text_df = (chunked_df
+    .withColumn("chunk_id", md5(col("chunked_text")))
+    .select(
+        col("chunk_id"),
+        col("chunked_text"),
+        col("path"),
+        col("parsed_at")
+    )
+)
+
+# Write to Delta table with streaming
+query = (text_df.writeStream
+    .format("delta")
+    .outputMode("append")
+    .option("checkpointLocation", checkpoint_location)
+    .option("mergeSchema", "true")
+    .trigger(availableNow=True)
+    .toTable(table_name)
+)
+
+# Block until the availableNow batch completes so the index sync sees all chunks
+query.awaitTermination()
+
+# COMMAND ----------
+
+# Enable CDF on the existing table
+spark.sql(f"ALTER TABLE {catalog}.{schema}.{table_name} SET TBLPROPERTIES (delta.enableChangeDataFeed = true)")
\ No newline at end of file
diff --git a/knowledge_base/workflow_vector_search_ingestion_for_rag/src/transformations/04_vector_search_index.py b/knowledge_base/workflow_vector_search_ingestion_for_rag/src/transformations/04_vector_search_index.py
new file mode 100644
index 0000000..2e1e2cb
--- /dev/null
+++ b/knowledge_base/workflow_vector_search_ingestion_for_rag/src/transformations/04_vector_search_index.py
@@ -0,0 +1,197 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC # Create Vector 
Search Index +# MAGIC +# MAGIC This notebook creates a Databricks Vector Search endpoint and index from document text. + +# COMMAND ---------- + +# MAGIC %pip install -r ../../requirements.txt -q +# MAGIC %restart_python + +# COMMAND ---------- + +# Get parameters +dbutils.widgets.text("catalog", "", "Catalog name") +dbutils.widgets.text("schema", "", "Schema name") +dbutils.widgets.text("source_table_name", "parsed_documents_text_chunked", "Source table name") +dbutils.widgets.text("endpoint_name", "vector-search-shared-endpoint", "Vector Search endpoint name") +dbutils.widgets.text("index_name", "parsed_documents_vector_index", "Vector Search index name") +dbutils.widgets.text("primary_key", "chunk_id", "Primary key column") +dbutils.widgets.text("embedding_source_column", "chunked_text", "Text column for embeddings") +dbutils.widgets.text("embedding_model_endpoint", "databricks-gte-large-en", "Embedding model endpoint") + +catalog = dbutils.widgets.get("catalog") +schema = dbutils.widgets.get("schema") +source_table_name = dbutils.widgets.get("source_table_name") +endpoint_name = dbutils.widgets.get("endpoint_name") +index_name = dbutils.widgets.get("index_name") +primary_key = dbutils.widgets.get("primary_key") +embedding_source_column = dbutils.widgets.get("embedding_source_column") +embedding_model_endpoint = dbutils.widgets.get("embedding_model_endpoint") + +# COMMAND ---------- + +# Set catalog and schema +spark.sql(f"USE CATALOG {catalog}") +spark.sql(f"USE SCHEMA {schema}") + +# COMMAND ---------- + +from databricks.vector_search.client import VectorSearchClient +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.vectorsearch import ( + EndpointType, + VectorSearchIndexesAPI, + DeltaSyncVectorIndexSpecRequest, + EmbeddingSourceColumn, + PipelineType, + VectorIndexType, +) +from databricks.sdk.errors.platform import ResourceDoesNotExist, BadRequest +import time + +# Initialize clients +w = WorkspaceClient() +vsc = VectorSearchClient() + +# COMMAND ---------- + +# Helper functions for robust endpoint and index management + +def find_endpoint(endpoint_name: str): + """Check if vector search endpoint exists.""" + try: + vector_search_endpoints = w.vector_search_endpoints.list_endpoints() + for ve in vector_search_endpoints: + if ve.name == endpoint_name: + return ve + return None + except Exception: + return None + +def create_or_validate_endpoint(endpoint_name: str): + """Create or validate vector search endpoint exists and is online.""" + print(f"Checking if endpoint '{endpoint_name}' exists...") + + endpoint = find_endpoint(endpoint_name) + + if endpoint: + print(f"Endpoint '{endpoint_name}' already exists.") + # Use built-in wait method to ensure it's online + try: + w.vector_search_endpoints.wait_get_endpoint_vector_search_endpoint_online(endpoint_name) + print(f"Endpoint '{endpoint_name}' is online!") + except Exception as e: + print(f"Warning: Could not verify endpoint is online: {e}") + else: + print(f"Endpoint does not exist. 
Creating endpoint '{endpoint_name}'...") + print("Please wait, this can take up to 20 minutes...") + + # Use the built-in create_endpoint_and_wait method + w.vector_search_endpoints.create_endpoint_and_wait( + endpoint_name, + endpoint_type=EndpointType.STANDARD + ) + + # Ensure it's online + w.vector_search_endpoints.wait_get_endpoint_vector_search_endpoint_online(endpoint_name) + print(f"Endpoint '{endpoint_name}' is online!") + +def find_index(index_name: str): + """Check if vector search index exists.""" + try: + return w.vector_search_indexes.get_index(index_name=index_name) + except ResourceDoesNotExist: + return None + +def wait_for_index_to_be_ready(index_name: str): + """Wait for index to be ready using status.ready boolean.""" + index = find_index(index_name) + if not index: + raise Exception(f"Index {index_name} not found") + + while not index.status.ready: + print(f"Index {index_name} is not ready yet, waiting 30 seconds...") + print(f"Current status: {index.status}") + time.sleep(30) + index = find_index(index_name) + if not index: + raise Exception(f"Index {index_name} disappeared during wait") + + print(f"Index {index_name} is ready!") + +# COMMAND ---------- + +# Create or validate Vector Search endpoint +create_or_validate_endpoint(endpoint_name) + +# COMMAND ---------- + +# Create Vector Search Index +full_index_name = f"{catalog}.{schema}.{index_name}" +full_source_table_name = f"{catalog}.{schema}.{source_table_name}" + +print(f"Managing vector search index '{full_index_name}' on table '{full_source_table_name}'...") + +try: + existing_index = find_index(full_index_name) + + if existing_index: + print(f"Index '{full_index_name}' already exists.") + + # Wait for index to be ready if it's not + wait_for_index_to_be_ready(full_index_name) + + print(f"Index status: {existing_index.status}") + + # Try to sync the index + if existing_index.status.ready: + print("Syncing existing index...") + try: + w.vector_search_indexes.sync_index(index_name=full_index_name) + print("Index sync kicked off successfully!") + except BadRequest as e: + print(f"Index sync already in progress or failed: {e}") + print("Please wait for the index to finish syncing.") + else: + print("Index is not ready for syncing yet.") + else: + print(f"Index does not exist, creating '{full_index_name}'...") + print("Computing document embeddings and Vector Search Index. This can take 15 minutes or longer.") + + # Create delta sync index spec + delta_sync_spec = DeltaSyncVectorIndexSpecRequest( + source_table=full_source_table_name, + pipeline_type=PipelineType.TRIGGERED, + embedding_source_columns=[ + EmbeddingSourceColumn( + name=embedding_source_column, + embedding_model_endpoint_name=embedding_model_endpoint, + ) + ], + ) + + # Create the index + w.vector_search_indexes.create_index( + name=full_index_name, + endpoint_name=endpoint_name, + primary_key=primary_key, + index_type=VectorIndexType.DELTA_SYNC, + delta_sync_index_spec=delta_sync_spec, + ) + + print(f"Vector search index '{full_index_name}' created successfully!") + print("Note: The index will continue syncing/embedding in the background.") + print("Wait for the index to be ready before querying it.") + +except Exception as e: + print(f"Error managing vector search index: {e}") + raise + +# COMMAND ---------- + +print(f"Vector Search setup complete!") +print(f"Endpoint: {endpoint_name}") +print(f"Index: {full_index_name}") +print(f"Note: Check the index status in the Databricks UI to ensure it's fully synced before querying.")
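+
+# COMMAND ----------
+
+# Optional sanity check: an illustrative sketch, not part of the workflow itself.
+# Once the index reports ready, a quick similarity search verifies end-to-end
+# ingestion. The query text below is an assumption; the columns come from the
+# chunked table produced by task 3.
+#
+# index = vsc.get_index(endpoint_name=endpoint_name, index_name=full_index_name)
+# results = index.similarity_search(
+#     query_text="example question about your documents",
+#     columns=["chunk_id", "chunked_text", "path"],
+#     num_results=5,
+# )
+# print(results)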