
Commit 6ab9d30

claudevdm authored and Claude committed

Alloy language connector (#34156)

* Add AlloyDB language connector support.
* Add test.
* Trigger test.
* Add link to WriteToJdbc.

Co-authored-by: Claude <cvandermerwe@google.com>
1 parent 3594d2a commit 6ab9d30

File tree: 3 files changed, +217 -2 lines changed
Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to run",
-  "modification": 4
+  "modification": 5
 }

sdks/python/apache_beam/ml/rag/ingestion/alloydb.py

Lines changed: 128 additions & 1 deletion
@@ -17,6 +17,7 @@
 import json
 import logging
 from dataclasses import dataclass
+from dataclasses import field
 from typing import Any
 from typing import Callable
 from typing import Dict
@@ -37,6 +38,73 @@
 _LOGGER = logging.getLogger(__name__)
 
 
+@dataclass
+class AlloyDBLanguageConnectorConfig:
+  """Configuration options for the AlloyDB Java language connector.
+
+  Contains all parameters needed to configure a connection using the AlloyDB
+  Java connector via JDBC. For details see
+  https://github.com/GoogleCloudPlatform/alloydb-java-connector/blob/main/docs/jdbc.md
+
+  Attributes:
+    database_name: Name of the database to connect to.
+    instance_name: Fully qualified instance name. Format:
+      'projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances
+      /<INSTANCE>'
+    ip_type: IP type to use for the connection. One of 'PRIVATE' (default),
+      'PUBLIC', or 'PSC'.
+    enable_iam_auth: Whether to enable IAM authentication. Default is False.
+    target_principal: Optional service account to impersonate for the
+      connection.
+    delegates: Optional comma-separated list of service accounts for
+      delegated impersonation.
+    admin_service_endpoint: Optional custom API service endpoint.
+    quota_project: Optional project ID for quota and billing.
+  """
+  database_name: str
+  instance_name: str
+  ip_type: str = "PRIVATE"
+  enable_iam_auth: bool = False
+  target_principal: Optional[str] = None
+  delegates: Optional[List[str]] = None
+  admin_service_endpoint: Optional[str] = None
+  quota_project: Optional[str] = None
+
+  def to_jdbc_url(self) -> str:
+    """Convert options to a properly formatted JDBC URL.
+
+    Returns:
+      JDBC URL string configured with all options.
+    """
+    # Base URL with database name
+    url = f"jdbc:postgresql:///{self.database_name}?"
+
+    # Add required properties
+    properties = {
+        "socketFactory": "com.google.cloud.alloydb.SocketFactory",
+        "alloydbInstanceName": self.instance_name,
+        "alloydbIpType": self.ip_type
+    }
+
+    if self.enable_iam_auth:
+      properties["alloydbEnableIAMAuth"] = "true"
+
+    if self.target_principal:
+      properties["alloydbTargetPrincipal"] = self.target_principal
+
+    if self.delegates:
+      properties["alloydbDelegates"] = ",".join(self.delegates)
+
+    if self.admin_service_endpoint:
+      properties["alloydbAdminServiceEndpoint"] = self.admin_service_endpoint
+
+    if self.quota_project:
+      properties["alloydbQuotaProject"] = self.quota_project
+
+    property_string = "&".join(f"{k}={v}" for k, v in properties.items())
+    return url + property_string
+
+
 @dataclass
 class AlloyDBConnectionConfig:
   """Configuration for AlloyDB database connection.
@@ -58,6 +126,10 @@ class AlloyDBConnectionConfig:
     max_connections: Optional number of connections in the pool.
       Use negative for no limit.
     write_batch_size: Optional write batch size for bulk operations.
+    additional_jdbc_args: Additional arguments that will be passed to
+      WriteToJdbc. These may include 'driver_jars', 'expansion_service',
+      'classpath', etc. See the full set of args at
+      :class:`~apache_beam.io.jdbc.WriteToJdbc`.
 
   Example:
     >>> config = AlloyDBConnectionConfig(
@@ -76,6 +148,60 @@ class AlloyDBConnectionConfig:
   autosharding: Optional[bool] = None
   max_connections: Optional[int] = None
   write_batch_size: Optional[int] = None
+  additional_jdbc_args: Dict[str, Any] = field(default_factory=dict)
+
+  @classmethod
+  def with_language_connector(
+      cls,
+      connector_options: AlloyDBLanguageConnectorConfig,
+      username: str,
+      password: str,
+      connection_properties: Optional[Dict[str, str]] = None,
+      connection_init_sqls: Optional[List[str]] = None,
+      autosharding: Optional[bool] = None,
+      max_connections: Optional[int] = None,
+      write_batch_size: Optional[int] = None) -> 'AlloyDBConnectionConfig':
+    """Create an AlloyDBConnectionConfig using the AlloyDB language connector.
+
+    Args:
+      connector_options: AlloyDB language connector configuration options.
+      username: Database username. For IAM auth, this should be the IAM
+        user email.
+      password: Database password. Can be an empty string when using IAM
+        auth.
+      connection_properties: Additional JDBC connection properties.
+      connection_init_sqls: SQL statements to execute on connection.
+      autosharding: Enable autosharding.
+      max_connections: Max connections in pool.
+      write_batch_size: Write batch size.
+
+    Returns:
+      Configured AlloyDBConnectionConfig instance.
+
+    Example:
+      >>> options = AlloyDBLanguageConnectorConfig(
+      ...     database_name="mydb",
+      ...     instance_name="projects/my-project/locations/us-central1"
+      ...     "/clusters/my-cluster/instances/my-instance",
+      ...     ip_type="PUBLIC",
+      ...     enable_iam_auth=True
+      ... )
+    """
+    return cls(
+        jdbc_url=connector_options.to_jdbc_url(),
+        username=username,
+        password=password,
+        connection_properties=connection_properties,
+        connection_init_sqls=connection_init_sqls,
+        autosharding=autosharding,
+        max_connections=max_connections,
+        write_batch_size=write_batch_size,
+        additional_jdbc_args={
+            'classpath': [
+                "org.postgresql:postgresql:42.2.16",
+                "com.google.cloud:alloydb-jdbc-connector:1.2.0"
+            ]
+        })
 
 
 @dataclass
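
Putting the pieces together, a minimal end-to-end sketch of configuring a writer via the language connector (instance path, credentials, and table name are placeholders; per the docstring above, with IAM auth the username is the IAM user email and the password may be empty):

    from apache_beam.ml.rag.ingestion.alloydb import AlloyDBConnectionConfig
    from apache_beam.ml.rag.ingestion.alloydb import AlloyDBLanguageConnectorConfig
    from apache_beam.ml.rag.ingestion.alloydb import AlloyDBVectorWriterConfig

    options = AlloyDBLanguageConnectorConfig(
        database_name="mydb",
        instance_name="projects/my-project/locations/us-central1/"
        "clusters/my-cluster/instances/my-instance",
        ip_type="PUBLIC",
        enable_iam_auth=True)
    connection_config = AlloyDBConnectionConfig.with_language_connector(
        connector_options=options,
        username="iam-user@example.com",  # IAM user email (placeholder)
        password="")  # may be empty when IAM auth is enabled
    writer_config = AlloyDBVectorWriterConfig(
        connection_config=connection_config, table_name="embeddings")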
@@ -713,4 +839,5 @@ def expand(self, pcoll: beam.PCollection[Chunk]):
             connection_init_sqls,
             autosharding=self.config.connection_config.autosharding,
             max_connections=self.config.connection_config.max_connections,
-            write_batch_size=self.config.connection_config.write_batch_size))
+            write_batch_size=self.config.connection_config.write_batch_size,
+            **self.config.connection_config.additional_jdbc_args))
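
Because additional_jdbc_args is unpacked straight into the WriteToJdbc call (the hunk above), callers constructing AlloyDBConnectionConfig directly can forward extra JDBC expansion arguments themselves. A hedged sketch, with placeholder connection values:

    from apache_beam.ml.rag.ingestion.alloydb import AlloyDBConnectionConfig

    config = AlloyDBConnectionConfig(
        jdbc_url="jdbc:postgresql://10.0.0.1:5432/mydb",  # placeholder URL
        username="user",
        password="pass",
        additional_jdbc_args={
            # 'classpath' is one of the WriteToJdbc arguments the docstring
            # above names as forwardable.
            'classpath': ["org.postgresql:postgresql:42.2.16"],
        })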

sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py

Lines changed: 88 additions & 0 deletions
@@ -33,6 +33,7 @@
 from apache_beam.coders.row_coder import RowCoder
 from apache_beam.io.jdbc import ReadFromJdbc
 from apache_beam.ml.rag.ingestion.alloydb import AlloyDBConnectionConfig
+from apache_beam.ml.rag.ingestion.alloydb import AlloyDBLanguageConnectorConfig
 from apache_beam.ml.rag.ingestion.alloydb import AlloyDBVectorWriterConfig
 from apache_beam.ml.rag.ingestion.alloydb import ColumnSpec
 from apache_beam.ml.rag.ingestion.alloydb import ColumnSpecsBuilder
@@ -328,6 +329,93 @@ def test_default_schema(self):
         equal_to([expected_last_n]),
         label=f"last_{sample_size}_check")
 
+  def test_language_connector(self):
+    """Test language connector."""
+    self.skip_if_dataflow_runner()
+
+    connector_options = AlloyDBLanguageConnectorConfig(
+        database_name=self.database,
+        instance_name="projects/apache-beam-testing/locations/us-central1/"
+        "clusters/testing-psc/instances/testing-psc-1",
+        ip_type="PSC")
+    connection_config = AlloyDBConnectionConfig.with_language_connector(
+        connector_options=connector_options,
+        username=self.username,
+        password=self.password)
+    config = AlloyDBVectorWriterConfig(
+        connection_config=connection_config, table_name=self.default_table_name)
+
+    # Create test chunks
+    num_records = 150
+    sample_size = min(500, num_records // 2)
+    chunks = ChunkTestUtils.get_expected_values(0, num_records)
+
+    self.write_test_pipeline.not_use_test_runner_api = True
+
+    with self.write_test_pipeline as p:
+      _ = (p | beam.Create(chunks) | config.create_write_transform())
+
+    self.read_test_pipeline.not_use_test_runner_api = True
+    read_query = f"""
+        SELECT
+          CAST(id AS VARCHAR(255)),
+          CAST(content AS VARCHAR(255)),
+          CAST(embedding AS text),
+          CAST(metadata AS text)
+        FROM {self.default_table_name}
+        """
+
+    with self.read_test_pipeline as p:
+      rows = (
+          p
+          | ReadFromJdbc(
+              table_name=self.default_table_name,
+              driver_class_name="org.postgresql.Driver",
+              jdbc_url=connector_options.to_jdbc_url(),
+              username=self.username,
+              password=self.password,
+              query=read_query,
+              classpath=[
+                  "org.postgresql:postgresql:42.2.16",
+                  "com.google.cloud:alloydb-jdbc-connector:1.2.0"
+              ]))
+
+      count_result = rows | "Count All" >> beam.combiners.Count.Globally()
+      assert_that(count_result, equal_to([num_records]), label='count_check')
+
+      chunks = (rows | "To Chunks" >> beam.Map(row_to_chunk))
+      chunk_hashes = chunks | "Hash Chunks" >> beam.CombineGlobally(HashingFn())
+      assert_that(
+          chunk_hashes,
+          equal_to([generate_expected_hash(num_records)]),
+          label='hash_check')
+
+      # Sample validation
+      first_n = (
+          chunks
+          | "Key on Index" >> beam.Map(key_on_id)
+          | f"Get First {sample_size}" >> beam.transforms.combiners.Top.Of(
+              sample_size, key=lambda x: x[0], reverse=True)
+          | "Remove Keys 1" >> beam.Map(lambda xs: [x[1] for x in xs]))
+      expected_first_n = ChunkTestUtils.get_expected_values(0, sample_size)
+      assert_that(
+          first_n,
+          equal_to([expected_first_n]),
+          label=f"first_{sample_size}_check")
+
+      last_n = (
+          chunks
+          | "Key on Index 2" >> beam.Map(key_on_id)
+          | f"Get Last {sample_size}" >> beam.transforms.combiners.Top.Of(
+              sample_size, key=lambda x: x[0])
+          | "Remove Keys 2" >> beam.Map(lambda xs: [x[1] for x in xs]))
+      expected_last_n = ChunkTestUtils.get_expected_values(
+          num_records - sample_size, num_records)[::-1]
+      assert_that(
+          last_n,
+          equal_to([expected_last_n]),
+          label=f"last_{sample_size}_check")
+
   def test_custom_specs(self):
     """Test custom specifications for ID, embedding, and content."""
     self.skip_if_dataflow_runner()
