diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json
index 2504db607e46..95fef3e26ca2 100644
--- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json
+++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json
@@ -1,4 +1,4 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to run",
-  "modification": 12
+  "modification": 13
 }
diff --git a/sdks/python/apache_beam/ml/rag/ingestion/spanner.py b/sdks/python/apache_beam/ml/rag/ingestion/spanner.py
new file mode 100644
index 000000000000..f79db470bca4
--- /dev/null
+++ b/sdks/python/apache_beam/ml/rag/ingestion/spanner.py
@@ -0,0 +1,646 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Cloud Spanner vector store writer for RAG pipelines.
+
+This module provides a writer for storing embeddings and associated metadata
+in Google Cloud Spanner. It supports flexible schema configuration with the
+ability to flatten metadata fields into dedicated columns.
+
+Example usage:
+
+  Default schema (id, embedding, content, metadata):
+    >>> config = SpannerVectorWriterConfig(
+    ...     project_id="my-project",
+    ...     instance_id="my-instance",
+    ...     database_id="my-db",
+    ...     table_name="embeddings"
+    ... )
+
+  Flattened metadata fields:
+    >>> specs = (
+    ...     SpannerColumnSpecsBuilder()
+    ...     .with_id_spec()
+    ...     .with_embedding_spec()
+    ...     .with_content_spec()
+    ...     .add_metadata_field("source", str)
+    ...     .add_metadata_field("page_number", int, default=0)
+    ...     .with_metadata_spec()
+    ...     .build()
+    ... )
+    >>> config = SpannerVectorWriterConfig(
+    ...     project_id="my-project",
+    ...     instance_id="my-instance",
+    ...     database_id="my-db",
+    ...     table_name="embeddings",
+    ...     column_specs=specs
+    ... )
+
+Spanner schema example:
+
+  CREATE TABLE embeddings (
+    id STRING(1024) NOT NULL,
+    embedding ARRAY<FLOAT64>(vector_length=>768),
+    content STRING(MAX),
+    source STRING(MAX),
+    page_number INT64,
+    metadata JSON
+  ) PRIMARY KEY (id)
+"""
+
+import functools
+import json
+from dataclasses import dataclass
+from typing import Any
+from typing import Callable
+from typing import List
+from typing import Literal
+from typing import NamedTuple
+from typing import Optional
+from typing import Type
+
+import apache_beam as beam
+from apache_beam.coders import registry
+from apache_beam.coders.row_coder import RowCoder
+from apache_beam.io.gcp import spanner
+from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
+from apache_beam.ml.rag.types import Chunk
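+
+# End-to-end usage sketch (illustrative only; project/instance/database names
+# are placeholders). The writer consumes a PCollection of Chunk elements, so a
+# pipeline typically builds Chunks via an embedding step and then applies the
+# transform returned by create_write_transform():
+#
+#   from apache_beam.ml.rag.types import Content, Embedding
+#
+#   chunks = [
+#       Chunk(
+#           id="doc1",
+#           embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]),
+#           content=Content(text="hello world"),
+#           metadata={"source": "test"}),
+#   ]
+#   config = SpannerVectorWriterConfig(
+#       project_id="my-project",
+#       instance_id="my-instance",
+#       database_id="my-db",
+#       table_name="embeddings")
+#   with beam.Pipeline() as p:
+#     _ = (p | beam.Create(chunks) | config.create_write_transform())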
+
+
+@dataclass
+class SpannerColumnSpec:
+  """Column specification for Spanner vector writes.
+
+  Defines how to extract and format values from Chunks for insertion into
+  Spanner table columns. Each spec maps to one column in the target table.
+
+  Attributes:
+    column_name: Name of the Spanner table column
+    python_type: Python type for the NamedTuple field (required for RowCoder)
+    value_fn: Function to extract value from a Chunk
+
+  Examples:
+    String column:
+    >>> SpannerColumnSpec(
+    ...     column_name="id",
+    ...     python_type=str,
+    ...     value_fn=lambda chunk: chunk.id
+    ... )
+
+    Array column with conversion:
+    >>> SpannerColumnSpec(
+    ...     column_name="embedding",
+    ...     python_type=List[float],
+    ...     value_fn=lambda chunk: chunk.embedding.dense_embedding
+    ... )
+  """
+  column_name: str
+  python_type: Type
+  value_fn: Callable[[Chunk], Any]
+
+
+def _extract_and_convert(extract_fn, convert_fn, chunk):
+  if convert_fn:
+    return convert_fn(extract_fn(chunk))
+  return extract_fn(chunk)
+
+
+class SpannerColumnSpecsBuilder:
+  """Builder for creating Spanner column specifications.
+
+  Provides a fluent API for defining table schemas and how to populate them
+  from Chunk objects. Supports standard Chunk fields (id, embedding, content,
+  metadata) and flattening metadata fields into dedicated columns.
+
+  Example:
+    >>> specs = (
+    ...     SpannerColumnSpecsBuilder()
+    ...     .with_id_spec()
+    ...     .with_embedding_spec()
+    ...     .with_content_spec()
+    ...     .add_metadata_field("source", str)
+    ...     .with_metadata_spec()
+    ...     .build()
+    ... )
+  """
+  def __init__(self):
+    self._specs: List[SpannerColumnSpec] = []
+
+  @staticmethod
+  def with_defaults() -> 'SpannerColumnSpecsBuilder':
+    """Create builder with default schema.
+
+    Default schema includes:
+    - id (STRING): Chunk ID
+    - embedding (ARRAY<FLOAT64>): Dense embedding vector
+    - content (STRING): Chunk content text
+    - metadata (JSON): Full metadata as JSON
+
+    Returns:
+      Builder with default column specifications
+    """
+    return (
+        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec().
+        with_content_spec().with_metadata_spec())
+
+  def with_id_spec(
+      self,
+      column_name: str = "id",
+      python_type: Type = str,
+      convert_fn: Optional[Callable[[str], Any]] = None
+  ) -> 'SpannerColumnSpecsBuilder':
+    """Add ID column specification.
+
+    Args:
+      column_name: Column name (default: "id")
+      python_type: Python type (default: str)
+      convert_fn: Optional converter (e.g., to cast to int)
+
+    Returns:
+      Self for method chaining
+
+    Examples:
+      Default string ID:
+      >>> builder.with_id_spec()
+
+      Integer ID with conversion:
+      >>> builder.with_id_spec(
+      ...     python_type=int,
+      ...     convert_fn=lambda id: int(id.split('_')[1])
+      ... )
+    """
+
+    self._specs.append(
+        SpannerColumnSpec(
+            column_name=column_name,
+            python_type=python_type,
+            value_fn=functools.partial(
+                _extract_and_convert, lambda chunk: chunk.id, convert_fn)))
+    return self
+
+  def with_embedding_spec(
+      self,
+      column_name: str = "embedding",
+      convert_fn: Optional[Callable[[List[float]], List[float]]] = None
+  ) -> 'SpannerColumnSpecsBuilder':
+    """Add embedding array column (ARRAY<FLOAT32> or ARRAY<FLOAT64>).
+
+    Args:
+      column_name: Column name (default: "embedding")
+      convert_fn: Optional converter (e.g., normalize, quantize)
+
+    Returns:
+      Self for method chaining
+
+    Examples:
+      Default embedding:
+      >>> builder.with_embedding_spec()
+
+      Normalized embedding:
+      >>> def normalize(vec):
+      ...   norm = (sum(x**2 for x in vec) ** 0.5) or 1.0
+      ...   return [x/norm for x in vec]
+      >>> builder.with_embedding_spec(convert_fn=normalize)
+
+      Rounded precision:
+      >>> builder.with_embedding_spec(
+      ...     convert_fn=lambda vec: [round(x, 4) for x in vec]
+      ... 
) + """ + def extract_fn(chunk: Chunk) -> List[float]: + if chunk.embedding is None or chunk.embedding.dense_embedding is None: + raise ValueError(f'Chunk must contain embedding: {chunk}') + return chunk.embedding.dense_embedding + + self._specs.append( + SpannerColumnSpec( + column_name=column_name, + python_type=List[float], + value_fn=functools.partial( + _extract_and_convert, extract_fn, convert_fn))) + return self + + def with_content_spec( + self, + column_name: str = "content", + python_type: Type = str, + convert_fn: Optional[Callable[[str], Any]] = None + ) -> 'SpannerColumnSpecsBuilder': + """Add content column. + + Args: + column_name: Column name (default: "content") + python_type: Python type (default: str) + convert_fn: Optional converter + + Returns: + Self for method chaining + + Examples: + Default text content: + >>> builder.with_content_spec() + + Content length as integer: + >>> builder.with_content_spec( + ... column_name="content_length", + ... python_type=int, + ... convert_fn=lambda text: len(text.split()) + ... ) + + Truncated content: + >>> builder.with_content_spec( + ... convert_fn=lambda text: text[:1000] + ... ) + """ + def extract_fn(chunk: Chunk) -> str: + if chunk.content.text is None: + raise ValueError(f'Chunk must contain content: {chunk}') + return chunk.content.text + + self._specs.append( + SpannerColumnSpec( + column_name=column_name, + python_type=python_type, + value_fn=functools.partial( + _extract_and_convert, extract_fn, convert_fn))) + return self + + def with_metadata_spec( + self, column_name: str = "metadata") -> 'SpannerColumnSpecsBuilder': + """Add metadata JSON column. + + Stores the full metadata dictionary as a JSON string in Spanner. + + Args: + column_name: Column name (default: "metadata") + + Returns: + Self for method chaining + + Note: + Metadata is automatically converted to JSON string using json.dumps() + """ + value_fn = lambda chunk: json.dumps(chunk.metadata) + self._specs.append( + SpannerColumnSpec( + column_name=column_name, python_type=str, value_fn=value_fn)) + return self + + def add_metadata_field( + self, + field: str, + python_type: Type, + column_name: Optional[str] = None, + convert_fn: Optional[Callable[[Any], Any]] = None, + default: Any = None) -> 'SpannerColumnSpecsBuilder': + """Flatten a metadata field into its own column. + + Extracts a specific field from chunk.metadata and stores it in a + dedicated table column. + + Args: + field: Key in chunk.metadata to extract + python_type: Python type (must be explicitly specified) + column_name: Column name (default: same as field) + convert_fn: Optional converter for type casting/transformation + default: Default value if field is missing from metadata + + Returns: + Self for method chaining + + Examples: + String field: + >>> builder.add_metadata_field("source", str) + + Integer with default: + >>> builder.add_metadata_field( + ... "page_number", + ... int, + ... default=0 + ... ) + + Float with conversion: + >>> builder.add_metadata_field( + ... "confidence", + ... float, + ... convert_fn=lambda x: round(float(x), 2), + ... default=0.0 + ... ) + + List of strings: + >>> builder.add_metadata_field( + ... "tags", + ... List[str], + ... default=[] + ... ) + + Timestamp with conversion: + >>> builder.add_metadata_field( + ... "created_at", + ... str, + ... convert_fn=lambda ts: ts.isoformat() + ... 
) + """ + name = column_name or field + + def value_fn(chunk: Chunk) -> Any: + return chunk.metadata.get(field, default) + + self._specs.append( + SpannerColumnSpec( + column_name=name, + python_type=python_type, + value_fn=functools.partial( + _extract_and_convert, value_fn, convert_fn))) + return self + + def add_column( + self, + column_name: str, + python_type: Type, + value_fn: Callable[[Chunk], Any]) -> 'SpannerColumnSpecsBuilder': + """Add a custom column with full control. + + Args: + column_name: Column name + python_type: Python type (required) + value_fn: Value extraction function + + Returns: + Self for method chaining + + Examples: + Boolean flag: + >>> builder.add_column( + ... column_name="has_code", + ... python_type=bool, + ... value_fn=lambda chunk: "```" in chunk.content.text + ... ) + + Computed value: + >>> builder.add_column( + ... column_name="word_count", + ... python_type=int, + ... value_fn=lambda chunk: len(chunk.content.text.split()) + ... ) + """ + self._specs.append( + SpannerColumnSpec( + column_name=column_name, python_type=python_type, + value_fn=value_fn)) + return self + + def build(self) -> List[SpannerColumnSpec]: + """Build the final list of column specifications. + + Returns: + List of SpannerColumnSpec objects + """ + return self._specs.copy() + + +class _SpannerSchemaBuilder: + """Internal: Builds NamedTuple schema and registers RowCoder. + + Creates a NamedTuple type from column specifications and registers it + with Beam's RowCoder for serialization. + """ + def __init__(self, table_name: str, column_specs: List[SpannerColumnSpec]): + """Initialize schema builder. + + Args: + table_name: Table name (used in NamedTuple type name) + column_specs: List of column specifications + + Raises: + ValueError: If duplicate column names are found + """ + self.table_name = table_name + self.column_specs = column_specs + + # Validate no duplicates + names = [col.column_name for col in column_specs] + duplicates = set(name for name in names if names.count(name) > 1) + if duplicates: + raise ValueError(f"Duplicate column names: {duplicates}") + + # Create NamedTuple type + fields = [(col.column_name, col.python_type) for col in column_specs] + type_name = f"SpannerVectorRecord_{table_name}" + self.record_type = NamedTuple(type_name, fields) # type: ignore + + # Register coder + registry.register_coder(self.record_type, RowCoder) + + def create_converter(self) -> Callable[[Chunk], NamedTuple]: + """Create converter function from Chunk to NamedTuple record. + + Returns: + Function that converts a Chunk to a NamedTuple record + """ + def convert(chunk: Chunk) -> self.record_type: # type: ignore + values = { + col.column_name: col.value_fn(chunk) + for col in self.column_specs + } + return self.record_type(**values) # type: ignore + + return convert + + +class SpannerVectorWriterConfig(VectorDatabaseWriteConfig): + """Configuration for writing vectors to Cloud Spanner. + + Supports flexible schema configuration through column specifications and + provides control over Spanner-specific write parameters. + + Examples: + Default schema: + >>> config = SpannerVectorWriterConfig( + ... project_id="my-project", + ... instance_id="my-instance", + ... database_id="my-db", + ... table_name="embeddings" + ... ) + + Custom schema with flattened metadata: + >>> specs = ( + ... SpannerColumnSpecsBuilder() + ... .with_id_spec() + ... .with_embedding_spec() + ... .with_content_spec() + ... .add_metadata_field("source", str) + ... 
.add_metadata_field("page_number", int, default=0) + ... .with_metadata_spec() + ... .build() + ... ) + >>> config = SpannerVectorWriterConfig( + ... project_id="my-project", + ... instance_id="my-instance", + ... database_id="my-db", + ... table_name="embeddings", + ... column_specs=specs + ... ) + + With emulator: + >>> config = SpannerVectorWriterConfig( + ... project_id="test-project", + ... instance_id="test-instance", + ... database_id="test-db", + ... table_name="embeddings", + ... emulator_host="http://localhost:9010" + ... ) + """ + def __init__( + self, + project_id: str, + instance_id: str, + database_id: str, + table_name: str, + *, + # Schema configuration + column_specs: Optional[List[SpannerColumnSpec]] = None, + # Write operation type + write_mode: Literal["INSERT", "UPDATE", "REPLACE", + "INSERT_OR_UPDATE"] = "INSERT_OR_UPDATE", + # Batching configuration + max_batch_size_bytes: Optional[int] = None, + max_number_mutations: Optional[int] = None, + max_number_rows: Optional[int] = None, + grouping_factor: Optional[int] = None, + # Networking + host: Optional[str] = None, + emulator_host: Optional[str] = None, + expansion_service: Optional[str] = None, + # Retry/deadline configuration + commit_deadline: Optional[int] = None, + max_cumulative_backoff: Optional[int] = None, + # Error handling + failure_mode: Optional[ + spanner.FailureMode] = spanner.FailureMode.REPORT_FAILURES, + high_priority: bool = False, + # Additional Spanner arguments + **spanner_kwargs): + """Initialize Spanner vector writer configuration. + + Args: + project_id: GCP project ID + instance_id: Spanner instance ID + database_id: Spanner database ID + table_name: Target table name + column_specs: Schema configuration using SpannerColumnSpecsBuilder. + If None, uses default schema (id, embedding, content, metadata) + write_mode: Spanner write operation type: + - INSERT: Fail if row exists + - UPDATE: Fail if row doesn't exist + - REPLACE: Delete then insert + - INSERT_OR_UPDATE: Insert or update if exists (default) + max_batch_size_bytes: Maximum bytes per mutation batch (default: 1MB) + max_number_mutations: Maximum cell mutations per batch (default: 5000) + max_number_rows: Maximum rows per batch (default: 500) + grouping_factor: Multiple of max mutation for sorting (default: 1000) + host: Spanner host URL (usually not needed) + emulator_host: Spanner emulator host (e.g., "http://localhost:9010") + expansion_service: Java expansion service address (host:port) + commit_deadline: Commit API deadline in seconds (default: 15) + max_cumulative_backoff: Max retry backoff seconds (default: 900) + failure_mode: Error handling strategy: + - FAIL_FAST: Throw exception for any failure + - REPORT_FAILURES: Continue processing (default) + high_priority: Use high priority for operations (default: False) + **spanner_kwargs: Additional keyword arguments to pass to the + underlying Spanner write transform. Use this to pass any + Spanner-specific parameters not explicitly exposed by this config. 
+ """ + self.project_id = project_id + self.instance_id = instance_id + self.database_id = database_id + self.table_name = table_name + self.write_mode = write_mode + self.max_batch_size_bytes = max_batch_size_bytes + self.max_number_mutations = max_number_mutations + self.max_number_rows = max_number_rows + self.grouping_factor = grouping_factor + self.host = host + self.emulator_host = emulator_host + self.expansion_service = expansion_service + self.commit_deadline = commit_deadline + self.max_cumulative_backoff = max_cumulative_backoff + self.failure_mode = failure_mode + self.high_priority = high_priority + self.spanner_kwargs = spanner_kwargs + + # Use defaults if not provided + specs = column_specs or SpannerColumnSpecsBuilder.with_defaults().build() + + # Create schema builder (NamedTuple + RowCoder registration) + self.schema_builder = _SpannerSchemaBuilder(table_name, specs) + + def create_write_transform(self) -> beam.PTransform: + """Create the Spanner write PTransform. + + Returns: + PTransform for writing to Spanner + """ + return _WriteToSpannerVectorDatabase(self) + + +class _WriteToSpannerVectorDatabase(beam.PTransform): + """Internal: PTransform for writing to Spanner vector database.""" + def __init__(self, config: SpannerVectorWriterConfig): + """Initialize write transform. + + Args: + config: Spanner writer configuration + """ + self.config = config + self.schema_builder = config.schema_builder + + def expand(self, pcoll: beam.PCollection[Chunk]): + """Expand the transform. + + Args: + pcoll: PCollection of Chunks to write + """ + # Select appropriate Spanner write transform based on write_mode + write_transform_class = { + "INSERT": spanner.SpannerInsert, + "UPDATE": spanner.SpannerUpdate, + "REPLACE": spanner.SpannerReplace, + "INSERT_OR_UPDATE": spanner.SpannerInsertOrUpdate, + }[self.config.write_mode] + + return ( + pcoll + | "Convert to Records" >> beam.Map( + self.schema_builder.create_converter()).with_output_types( + self.schema_builder.record_type) + | "Write to Spanner" >> write_transform_class( + project_id=self.config.project_id, + instance_id=self.config.instance_id, + database_id=self.config.database_id, + table=self.config.table_name, + max_batch_size_bytes=self.config.max_batch_size_bytes, + max_number_mutations=self.config.max_number_mutations, + max_number_rows=self.config.max_number_rows, + grouping_factor=self.config.grouping_factor, + host=self.config.host, + emulator_host=self.config.emulator_host, + commit_deadline=self.config.commit_deadline, + max_cumulative_backoff=self.config.max_cumulative_backoff, + failure_mode=self.config.failure_mode, + expansion_service=self.config.expansion_service, + high_priority=self.config.high_priority, + **self.config.spanner_kwargs)) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/spanner_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/spanner_it_test.py new file mode 100644 index 000000000000..ab9a982a81f7 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/spanner_it_test.py @@ -0,0 +1,601 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Integration tests for Spanner vector store writer."""
+
+import logging
+import os
+import time
+import unittest
+import uuid
+
+import pytest
+
+import apache_beam as beam
+from apache_beam.ml.rag.ingestion.spanner import SpannerVectorWriterConfig
+from apache_beam.ml.rag.types import Chunk
+from apache_beam.ml.rag.types import Content
+from apache_beam.ml.rag.types import Embedding
+from apache_beam.testing.test_pipeline import TestPipeline
+
+# pylint: disable=wrong-import-order, wrong-import-position
+try:
+  from google.cloud import spanner
+except ImportError:
+  spanner = None
+
+try:
+  from testcontainers.core.container import DockerContainer
+except ImportError:
+  DockerContainer = None
+# pylint: enable=wrong-import-order, wrong-import-position
+
+
+def retry(fn, retries, err_msg, *args, **kwargs):
+  """Retry a function, sleeping 1 second between attempts."""
+  for _ in range(retries):
+    try:
+      return fn(*args, **kwargs)
+    except:  # pylint: disable=bare-except
+      time.sleep(1)
+  logging.error(err_msg)
+  raise RuntimeError(err_msg)
+
+
+class SpannerEmulatorHelper:
+  """Helper for managing Spanner emulator lifecycle."""
+  def __init__(self, project_id: str, instance_id: str, table_name: str):
+    self.project_id = project_id
+    self.instance_id = instance_id
+    self.table_name = table_name
+    self.host = None
+
+    # Start emulator
+    self.emulator = DockerContainer(
+        'gcr.io/cloud-spanner-emulator/emulator:latest').with_exposed_ports(
+            9010, 9020)
+    retry(self.emulator.start, 3, 'Could not start spanner emulator.')
+    time.sleep(3)
+
+    self.host = f'{self.emulator.get_container_host_ip()}:' \
+                f'{self.emulator.get_exposed_port(9010)}'
+    os.environ['SPANNER_EMULATOR_HOST'] = self.host
+
+    # Create client and instance
+    self.client = spanner.Client(project_id)
+    self.instance = self.client.instance(instance_id)
+    self.create_instance()
+
+  def create_instance(self):
+    """Create Spanner instance in emulator."""
+    self.instance.create().result(120)
+
+  def create_database(self, database_id: str):
+    """Create database with default vector table schema."""
+    database = self.instance.database(
+        database_id,
+        ddl_statements=[
+            f'''
+        CREATE TABLE {self.table_name} (
+          id STRING(1024) NOT NULL,
+          embedding ARRAY<FLOAT64>(vector_length=>3),
+          content STRING(MAX),
+          metadata JSON
+        ) PRIMARY KEY (id)'''
+        ])
+    database.create().result(120)
+
+  def read_data(self, database_id: str):
+    """Read all data from the table."""
+    database = self.instance.database(database_id)
+    with database.snapshot() as snapshot:
+      results = snapshot.execute_sql(
+          f'SELECT * FROM {self.table_name} ORDER BY id')
+      return list(results) if results else []
+
+  def drop_database(self, database_id: str):
+    """Drop the database."""
+    database = self.instance.database(database_id)
+    database.drop()
+
+  def shutdown(self):
+    """Stop the emulator."""
+    if self.emulator:
+      try:
+        self.emulator.stop()
+      except:  # pylint: disable=bare-except
+        logging.error('Could not stop Spanner emulator.')
+
+  def get_emulator_host(self) -> str:
+    """Get the emulator host URL."""
+    return f'http://{self.host}'
+
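+
+# Wiring note (for orientation; mirrors SpannerEmulatorHelper above): the
+# SPANNER_EMULATOR_HOST env var points the google-cloud-spanner client used
+# for DDL and verification at the container, while the pipeline under test is
+# given the address explicitly, e.g.:
+#
+#   helper = SpannerEmulatorHelper('test-project', 'test-instance', 'embeddings')
+#   config = SpannerVectorWriterConfig(
+#       project_id='test-project',
+#       instance_id='test-instance',
+#       database_id='test-db',
+#       table_name='embeddings',
+#       emulator_host=helper.get_emulator_host())
+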
+
+@pytest.mark.uses_gcp_java_expansion_service
+@unittest.skipUnless(
+    os.environ.get('EXPANSION_JARS'),
+    "EXPANSION_JARS environment var is not provided, "
+    "indicating that jars have not been built")
+@unittest.skipIf(spanner is None, 'GCP dependencies are not installed.')
+@unittest.skipIf(
+    DockerContainer is None, 'testcontainers package is not installed.')
+class SpannerVectorWriterTest(unittest.TestCase):
+  """Integration tests for Spanner vector writer."""
+  @classmethod
+  def setUpClass(cls):
+    """Set up Spanner emulator for all tests."""
+    cls.project_id = 'test-project'
+    cls.instance_id = 'test-instance'
+    cls.table_name = 'embeddings'
+
+    cls.spanner_helper = SpannerEmulatorHelper(
+        cls.project_id, cls.instance_id, cls.table_name)
+
+  @classmethod
+  def tearDownClass(cls):
+    """Tear down Spanner emulator."""
+    cls.spanner_helper.shutdown()
+
+  def setUp(self):
+    """Create a unique database for each test."""
+    self.database_id = f'test_db_{uuid.uuid4().hex}'[:30]
+    self.spanner_helper.create_database(self.database_id)
+
+  def tearDown(self):
+    """Drop the test database."""
+    self.spanner_helper.drop_database(self.database_id)
+
+  def test_write_default_schema(self):
+    """Test writing with default schema (id, embedding, content, metadata)."""
+    # Create test chunks
+    chunks = [
+        Chunk(
+            id='doc1',
+            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
+            content=Content(text='First document'),
+            metadata={
+                'source': 'test', 'page': 1
+            }),
+        Chunk(
+            id='doc2',
+            embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
+            content=Content(text='Second document'),
+            metadata={
+                'source': 'test', 'page': 2
+            }),
+    ]
+
+    # Create config with default schema
+    config = SpannerVectorWriterConfig(
+        project_id=self.project_id,
+        instance_id=self.instance_id,
+        database_id=self.database_id,
+        table_name=self.table_name,
+        emulator_host=self.spanner_helper.get_emulator_host(),
+    )
+
+    # Write chunks
+    with TestPipeline() as p:
+      p.not_use_test_runner_api = True
+      _ = (p | beam.Create(chunks) | config.create_write_transform())
+
+    # Verify data was written
+    results = self.spanner_helper.read_data(self.database_id)
+    self.assertEqual(len(results), 2)
+
+    # Check first row
+    row1 = results[0]
+    self.assertEqual(row1[0], 'doc1')  # id
+    self.assertEqual(list(row1[1]), [1.0, 2.0, 3.0])  # embedding
+    self.assertEqual(row1[2], 'First document')  # content
+    # metadata is JSON
+    metadata1 = row1[3]
+    self.assertEqual(metadata1['source'], 'test')
+    self.assertEqual(metadata1['page'], 1)
+
+    # Check second row
+    row2 = results[1]
+    self.assertEqual(row2[0], 'doc2')
+    self.assertEqual(list(row2[1]), [4.0, 5.0, 6.0])
+    self.assertEqual(row2[2], 'Second document')
+
+  def test_write_flattened_metadata(self):
+    """Test writing with flattened metadata fields."""
+    from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder
+
+    # Create custom database with flattened columns
+    self.spanner_helper.drop_database(self.database_id)
+    database = self.spanner_helper.instance.database(
+        self.database_id,
+        ddl_statements=[
+            f'''
+        CREATE TABLE {self.table_name} (
+          id STRING(1024) NOT NULL,
+          embedding ARRAY<FLOAT64>(vector_length=>3),
+          content STRING(MAX),
+          source STRING(MAX),
+          page_number INT64,
+          metadata JSON
+        ) PRIMARY KEY (id)'''
+        ])
+    database.create().result(120)
+
+    # Create test chunks
+    chunks = [
+        Chunk(
+            id='doc1',
+            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
+            content=Content(text='First document'),
+            metadata={
+                'source': 'book.pdf', 'page': 10, 'author': 'John'
+            }),
+        Chunk(
+            id='doc2',
+            embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
+            content=Content(text='Second document'),
+            metadata={
+                'source': 'article.txt', 'page': 5, 'author': 'Jane'
+            }),
+    ]
+
+    # Create config with flattened metadata
+    specs = (
+        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec().
+        with_content_spec().add_metadata_field(
+            'source', str, column_name='source').add_metadata_field(
+                'page', int,
+                column_name='page_number').with_metadata_spec().build())
+
+    config = SpannerVectorWriterConfig(
+        project_id=self.project_id,
+        instance_id=self.instance_id,
+        database_id=self.database_id,
+        table_name=self.table_name,
+        column_specs=specs,
+        emulator_host=self.spanner_helper.get_emulator_host(),
+    )
+
+    # Write chunks
+    with TestPipeline() as p:
+      p.not_use_test_runner_api = True
+      _ = (p | beam.Create(chunks) | config.create_write_transform())
+
+    # Verify data
+    database = self.spanner_helper.instance.database(self.database_id)
+    with database.snapshot() as snapshot:
+      results = snapshot.execute_sql(
+          f'SELECT id, embedding, content, source, page_number, metadata '
+          f'FROM {self.table_name} ORDER BY id')
+      rows = list(results)
+
+    self.assertEqual(len(rows), 2)
+
+    # Check first row
+    self.assertEqual(rows[0][0], 'doc1')
+    self.assertEqual(list(rows[0][1]), [1.0, 2.0, 3.0])
+    self.assertEqual(rows[0][2], 'First document')
+    self.assertEqual(rows[0][3], 'book.pdf')  # flattened source
+    self.assertEqual(rows[0][4], 10)  # flattened page_number
+
+    metadata1 = rows[0][5]
+    self.assertEqual(metadata1['author'], 'John')
+
+  def test_write_minimal_schema(self):
+    """Test writing with minimal schema (only id and embedding)."""
+    from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder
+
+    # Create custom database with minimal schema
+    self.spanner_helper.drop_database(self.database_id)
+    database = self.spanner_helper.instance.database(
+        self.database_id,
+        ddl_statements=[
+            f'''
+        CREATE TABLE {self.table_name} (
+          id STRING(1024) NOT NULL,
+          embedding ARRAY<FLOAT64>(vector_length=>3)
+        ) PRIMARY KEY (id)'''
+        ])
+    database.create().result(120)
+
+    # Create test chunks
+    chunks = [
+        Chunk(
+            id='doc1',
+            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
+            content=Content(text='First document'),
+            metadata={'source': 'test'}),
+        Chunk(
+            id='doc2',
+            embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
+            content=Content(text='Second document'),
+            metadata={'source': 'test'}),
+    ]
+
+    # Create config with minimal schema
+    specs = (
+        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec().build(
+        ))
+
+    config = SpannerVectorWriterConfig(
+        project_id=self.project_id,
+        instance_id=self.instance_id,
+        database_id=self.database_id,
+        table_name=self.table_name,
+        column_specs=specs,
+        emulator_host=self.spanner_helper.get_emulator_host(),
+    )
+
+    # Write chunks
+    with TestPipeline() as p:
+      p.not_use_test_runner_api = True
+      _ = (p | beam.Create(chunks) | config.create_write_transform())
+
+    # Verify data
+    results = self.spanner_helper.read_data(self.database_id)
+    self.assertEqual(len(results), 2)
+    self.assertEqual(results[0][0], 'doc1')
+    self.assertEqual(list(results[0][1]), [1.0, 2.0, 3.0])
+
+  def test_write_with_converter(self):
+    """Test writing with custom converter function."""
+    from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder
+
+    # Create test chunks with embeddings that need normalization
+    chunks = [
+        Chunk(
+            id='doc1',
+            embedding=Embedding(dense_embedding=[3.0, 4.0, 0.0]),
+            content=Content(text='First document'),
+            metadata={'source': 'test'}),
+    ]
+
+    # Define normalizer
+    def normalize(vec):
+      norm = (sum(x**2 for x in vec)**0.5) or 1.0
+      return [x / norm for x in vec]
+
+    # Create config with normalized embeddings
+    specs = (
+        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec(
+            convert_fn=normalize).with_content_spec().with_metadata_spec().
+        build())
+
+    config = SpannerVectorWriterConfig(
+        project_id=self.project_id,
+        instance_id=self.instance_id,
+        database_id=self.database_id,
+        table_name=self.table_name,
+        column_specs=specs,
+        emulator_host=self.spanner_helper.get_emulator_host(),
+    )
+
+    # Write chunks
+    with TestPipeline() as p:
+      p.not_use_test_runner_api = True
+      _ = (p | beam.Create(chunks) | config.create_write_transform())
+
+    # Verify data - embedding should be normalized
+    results = self.spanner_helper.read_data(self.database_id)
+    self.assertEqual(len(results), 1)
+
+    embedding = list(results[0][1])
+    # Original was [3.0, 4.0, 0.0], normalized should be [0.6, 0.8, 0.0]
+    self.assertAlmostEqual(embedding[0], 0.6, places=5)
+    self.assertAlmostEqual(embedding[1], 0.8, places=5)
+    self.assertAlmostEqual(embedding[2], 0.0, places=5)
+
+    # Check norm is 1.0
+    norm = sum(x**2 for x in embedding)**0.5
+    self.assertAlmostEqual(norm, 1.0, places=5)
+
+  def test_write_update_mode(self):
+    """Test writing with UPDATE mode."""
+    # First insert data
+    chunks_insert = [
+        Chunk(
+            id='doc1',
+            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
+            content=Content(text='Original content'),
+            metadata={'version': 1}),
+    ]
+
+    config_insert = SpannerVectorWriterConfig(
+        project_id=self.project_id,
+        instance_id=self.instance_id,
+        database_id=self.database_id,
+        table_name=self.table_name,
+        write_mode='INSERT',
+        emulator_host=self.spanner_helper.get_emulator_host(),
+    )
+
+    with TestPipeline() as p:
+      p.not_use_test_runner_api = True
+      _ = (
+          p
+          | beam.Create(chunks_insert)
+          | config_insert.create_write_transform())
+
+    # Update existing row
+    chunks_update = [
+        Chunk(
+            id='doc1',
+            embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
+            content=Content(text='Updated content'),
+            metadata={'version': 2}),
+    ]
+
+    config_update = SpannerVectorWriterConfig(
+        project_id=self.project_id,
+        instance_id=self.instance_id,
+        database_id=self.database_id,
+        table_name=self.table_name,
+        write_mode='UPDATE',
+        emulator_host=self.spanner_helper.get_emulator_host(),
+    )
+
+    with TestPipeline() as p:
+      p.not_use_test_runner_api = True
+      _ = (
+          p
+          | beam.Create(chunks_update)
+          | config_update.create_write_transform())
+
+    # Verify update succeeded
+    results = self.spanner_helper.read_data(self.database_id)
+    self.assertEqual(len(results), 1)
+    self.assertEqual(results[0][0], 'doc1')
+    self.assertEqual(list(results[0][1]), [4.0, 5.0, 6.0])
+    self.assertEqual(results[0][2], 'Updated content')
+
+    metadata = results[0][3]
+    self.assertEqual(metadata['version'], 2)
+
+  def test_write_custom_column(self):
+    """Test writing with custom computed column."""
+    from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder
+
+    # Create custom database with computed column
+    self.spanner_helper.drop_database(self.database_id)
+    database = self.spanner_helper.instance.database(
+        self.database_id,
+        ddl_statements=[
+            f'''
+        CREATE TABLE {self.table_name} (
+          id STRING(1024) NOT NULL,
+          embedding ARRAY<FLOAT64>(vector_length=>3),
+          content STRING(MAX),
+          word_count INT64,
+          metadata JSON
+        ) PRIMARY KEY (id)'''
+        ])
+    database.create().result(120)
+
+    # Create test chunks
+    chunks = [
+        Chunk(
+            id='doc1',
+            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
+            content=Content(text='Hello world test'),
+            metadata={}),
+        Chunk(
+            id='doc2',
+            embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
+            content=Content(text='This is a longer test document'),
+            metadata={}),
+    ]
+
+    # Create config with custom word_count column
+    specs = (
+        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec(
+        ).with_content_spec().add_column(
+            column_name='word_count',
+            python_type=int,
+            value_fn=lambda chunk: len(chunk.content.text.split())).
+        with_metadata_spec().build())
+
+    config = SpannerVectorWriterConfig(
+        project_id=self.project_id,
+        instance_id=self.instance_id,
+        database_id=self.database_id,
+        table_name=self.table_name,
+        column_specs=specs,
+        emulator_host=self.spanner_helper.get_emulator_host(),
+    )
+
+    # Write chunks
+    with TestPipeline() as p:
+      p.not_use_test_runner_api = True
+      _ = (p | beam.Create(chunks) | config.create_write_transform())
+
+    # Verify data
+    database = self.spanner_helper.instance.database(self.database_id)
+    with database.snapshot() as snapshot:
+      results = snapshot.execute_sql(
+          f'SELECT id, word_count FROM {self.table_name} ORDER BY id')
+      rows = list(results)
+
+    self.assertEqual(len(rows), 2)
+    self.assertEqual(rows[0][1], 3)  # "Hello world test" = 3 words
+    self.assertEqual(rows[1][1], 6)  # 6 words
+
+  def test_write_with_timestamp(self):
+    """Test writing with timestamp columns."""
+    from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder
+
+    # Create database with timestamp column
+    self.spanner_helper.drop_database(self.database_id)
+    database = self.spanner_helper.instance.database(
+        self.database_id,
+        ddl_statements=[
+            f'''
+        CREATE TABLE {self.table_name} (
+          id STRING(1024) NOT NULL,
+          embedding ARRAY<FLOAT64>(vector_length=>3),
+          content STRING(MAX),
+          created_at TIMESTAMP,
+          metadata JSON
+        ) PRIMARY KEY (id)'''
+        ])
+    database.create().result(120)
+
+    # Create chunks with timestamp
+    timestamp_str = "2025-10-28T09:45:00.123456Z"
+    chunks = [
+        Chunk(
+            id='doc1',
+            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
+            content=Content(text='Document with timestamp'),
+            metadata={'created_at': timestamp_str}),
+    ]
+
+    # Create config with timestamp field
+    specs = (
+        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec().
+        with_content_spec().add_metadata_field(
+            'created_at', str,
+            column_name='created_at').with_metadata_spec().build())
+
+    config = SpannerVectorWriterConfig(
+        project_id=self.project_id,
+        instance_id=self.instance_id,
+        database_id=self.database_id,
+        table_name=self.table_name,
+        column_specs=specs,
+        emulator_host=self.spanner_helper.get_emulator_host(),
+    )
+
+    # Write chunks
+    with TestPipeline() as p:
+      p.not_use_test_runner_api = True
+      _ = (p | beam.Create(chunks) | config.create_write_transform())
+
+    # Verify timestamp was written
+    database = self.spanner_helper.instance.database(self.database_id)
+    with database.snapshot() as snapshot:
+      results = snapshot.execute_sql(
+          f'SELECT id, created_at FROM {self.table_name}')
+      rows = list(results)
+
+    self.assertEqual(len(rows), 1)
+    self.assertEqual(rows[0][0], 'doc1')
+    # Timestamp is returned as datetime object by Spanner client
+    self.assertIsNotNone(rows[0][1])
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()
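+
+# Local run sketch (assumed invocation; the skip conditions above define the
+# actual requirements): Docker must be available for the emulator container
+# and the Java expansion service jars must be built, e.g.:
+#
+#   export EXPANSION_JARS=<paths to the built expansion service jars>
+#   pytest sdks/python/apache_beam/ml/rag/ingestion/spanner_it_test.py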