diff --git a/.asf.yaml b/.asf.yaml index 547296d1c67d..8f1d49fbb6f7 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -51,6 +51,7 @@ github: protected_branches: master: {} + release-2.66.0-postrelease: {} release-2.66: {} release-2.65.0-postrelease: {} release-2.65: {} diff --git a/.github/workflows/beam_PostCommit_Python_Xlang_IO_Direct.yml b/.github/workflows/beam_PostCommit_Python_Xlang_IO_Direct.yml index 35c250e8627b..c5781ee6a66d 100644 --- a/.github/workflows/beam_PostCommit_Python_Xlang_IO_Direct.yml +++ b/.github/workflows/beam_PostCommit_Python_Xlang_IO_Direct.yml @@ -17,7 +17,7 @@ name: PostCommit Python Xlang IO Direct on: schedule: - - cron: '30 5/6 * * *' + - cron: '30 4/6 * * *' pull_request_target: paths: ['release/trigger_all_tests.json', '.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Direct.json'] workflow_dispatch: diff --git a/.github/workflows/load-tests-pipeline-options/go_GBK_Flink_Batch_10b.txt b/.github/workflows/load-tests-pipeline-options/go_GBK_Flink_Batch_10b.txt index 0bb35b26436d..4691e2d92ced 100644 --- a/.github/workflows/load-tests-pipeline-options/go_GBK_Flink_Batch_10b.txt +++ b/.github/workflows/load-tests-pipeline-options/go_GBK_Flink_Batch_10b.txt @@ -16,7 +16,7 @@ --influx_namespace=flink --influx_measurement=go_batch_gbk_1 ---input_options=''{\"num_records\":5000000,\"key_size\":1,\"value_size\":9}'' +--input_options=''{\"num_records\":500000,\"key_size\":1,\"value_size\":9}'' --iterations=1 --fanout=1 --parallelism=5 diff --git a/CHANGES.md b/CHANGES.md index 5cde77cfebf8..dda05276be48 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -75,6 +75,9 @@ * X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). * Add pip-based install support for JupyterLab Sidepanel extension ([#35397](https://github.com/apache/beam/issues/#35397)). +* Milvus enrichment handler added (Python) ([#35216](https://github.com/apache/beam/pull/35216)). + Beam now supports Milvus enrichment handler capabilities for vector, keyword, + and hybrid search operations. 
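A minimal usage sketch for the new handler is shown below. It is illustrative only and not part of the patch: the handler classes come from the `apache_beam/ml/rag/enrichment/milvus_search.py` module introduced later in this diff, the `Chunk`/`Content`/`Embedding` construction follows the `apache_beam.ml.rag.types` dataclasses, and the Milvus URI, collection name, and field names are placeholder assumptions.

```python
import apache_beam as beam
from apache_beam.ml.rag.enrichment.milvus_search import (
    MilvusCollectionLoadParameters,
    MilvusConnectionParameters,
    MilvusSearchEnrichmentHandler,
    MilvusSearchParameters,
    VectorSearchParameters,
)
from apache_beam.ml.rag.types import Chunk, Content, Embedding
from apache_beam.transforms.enrichment import Enrichment

# Handler construction mirrors the "Example Usage" in the handler's docstring;
# the URI, collection, and field names below are placeholders.
handler = MilvusSearchEnrichmentHandler(
    MilvusConnectionParameters(uri="http://localhost:19530"),
    MilvusSearchParameters(
        collection_name="my_collection",
        search_strategy=VectorSearchParameters(anns_field="embedding", limit=3),
        output_fields=["content", "metadata"]),
    collection_load_parameters=MilvusCollectionLoadParameters(
        load_fields=["embedding", "content", "metadata"]))

with beam.Pipeline() as p:
    _ = (
        p
        # A query chunk whose dense embedding was computed upstream
        # (e.g. by an embedding MLTransform); the values here are dummies.
        | beam.Create([
            Chunk(
                content=Content(text="wireless headphones"),
                embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]))
        ])
        # Each output element pairs the input Chunk with the matching
        # Milvus search results.
        | Enrichment(handler))
```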
## Breaking Changes diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index 9c44b526a1b5..52fcd2a63a88 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -606,16 +606,16 @@ class BeamModulePlugin implements Plugin { def dbcp2_version = "2.9.0" def errorprone_version = "2.10.0" // [bomupgrader] determined by: com.google.api:gax, consistent with: google_cloud_platform_libraries_bom - def gax_version = "2.65.0" + def gax_version = "2.67.0" def google_ads_version = "33.0.0" def google_clients_version = "2.0.0" def google_cloud_bigdataoss_version = "2.2.26" // [bomupgrader] determined by: com.google.cloud:google-cloud-spanner, consistent with: google_cloud_platform_libraries_bom - def google_cloud_spanner_version = "6.93.0" + def google_cloud_spanner_version = "6.95.1" def google_code_gson_version = "2.10.1" def google_oauth_clients_version = "1.34.1" // [bomupgrader] determined by: io.grpc:grpc-netty, consistent with: google_cloud_platform_libraries_bom - def grpc_version = "1.70.0" + def grpc_version = "1.71.0" def guava_version = "33.1.0-jre" def hadoop_version = "3.4.1" def hamcrest_version = "2.1" @@ -631,7 +631,7 @@ class BeamModulePlugin implements Plugin { def log4j2_version = "2.20.0" def nemo_version = "0.1" // [bomupgrader] determined by: io.grpc:grpc-netty, consistent with: google_cloud_platform_libraries_bom - def netty_version = "4.1.118.Final" + def netty_version = "4.1.110.Final" def postgres_version = "42.2.16" // [bomupgrader] determined by: com.google.protobuf:protobuf-java, consistent with: google_cloud_platform_libraries_bom def protobuf_version = "4.29.4" @@ -733,12 +733,12 @@ class BeamModulePlugin implements Plugin { google_api_client_gson : "com.google.api-client:google-api-client-gson:$google_clients_version", google_api_client_java6 : "com.google.api-client:google-api-client-java6:$google_clients_version", google_api_common : "com.google.api:api-common", // google_cloud_platform_libraries_bom sets version - google_api_services_bigquery : "com.google.apis:google-api-services-bigquery:v2-rev20250427-2.0.0", // [bomupgrader] sets version + google_api_services_bigquery : "com.google.apis:google-api-services-bigquery:v2-rev20250511-2.0.0", // [bomupgrader] sets version google_api_services_cloudresourcemanager : "com.google.apis:google-api-services-cloudresourcemanager:v1-rev20240310-2.0.0", // [bomupgrader] sets version google_api_services_dataflow : "com.google.apis:google-api-services-dataflow:v1b3-rev20250519-$google_clients_version", google_api_services_healthcare : "com.google.apis:google-api-services-healthcare:v1-rev20240130-$google_clients_version", google_api_services_pubsub : "com.google.apis:google-api-services-pubsub:v1-rev20220904-$google_clients_version", - google_api_services_storage : "com.google.apis:google-api-services-storage:v1-rev20250424-2.0.0", // [bomupgrader] sets version + google_api_services_storage : "com.google.apis:google-api-services-storage:v1-rev20250524-2.0.0", // [bomupgrader] sets version google_auth_library_credentials : "com.google.auth:google-auth-library-credentials", // google_cloud_platform_libraries_bom sets version google_auth_library_oauth2_http : "com.google.auth:google-auth-library-oauth2-http", // google_cloud_platform_libraries_bom sets version google_cloud_bigquery : 
"com.google.cloud:google-cloud-bigquery", // google_cloud_platform_libraries_bom sets version @@ -750,13 +750,13 @@ class BeamModulePlugin implements Plugin { google_cloud_core_grpc : "com.google.cloud:google-cloud-core-grpc", // google_cloud_platform_libraries_bom sets version google_cloud_datacatalog_v1beta1 : "com.google.cloud:google-cloud-datacatalog", // google_cloud_platform_libraries_bom sets version google_cloud_dataflow_java_proto_library_all: "com.google.cloud.dataflow:google-cloud-dataflow-java-proto-library-all:0.5.160304", - google_cloud_datastore_v1_proto_client : "com.google.cloud.datastore:datastore-v1-proto-client:2.28.1", // [bomupgrader] sets version + google_cloud_datastore_v1_proto_client : "com.google.cloud.datastore:datastore-v1-proto-client:2.29.1", // [bomupgrader] sets version google_cloud_firestore : "com.google.cloud:google-cloud-firestore", // google_cloud_platform_libraries_bom sets version google_cloud_pubsub : "com.google.cloud:google-cloud-pubsub", // google_cloud_platform_libraries_bom sets version google_cloud_pubsublite : "com.google.cloud:google-cloud-pubsublite", // google_cloud_platform_libraries_bom sets version // [bomupgrader] the BOM version is set by scripts/tools/bomupgrader.py. If update manually, also update // libraries-bom version on sdks/java/container/license_scripts/dep_urls_java.yaml - google_cloud_platform_libraries_bom : "com.google.cloud:libraries-bom:26.60.0", + google_cloud_platform_libraries_bom : "com.google.cloud:libraries-bom:26.62.0", google_cloud_secret_manager : "com.google.cloud:google-cloud-secretmanager", // google_cloud_platform_libraries_bom sets version google_cloud_spanner : "com.google.cloud:google-cloud-spanner", // google_cloud_platform_libraries_bom sets version google_cloud_spanner_test : "com.google.cloud:google-cloud-spanner:$google_cloud_spanner_version:tests", diff --git a/sdks/go.mod b/sdks/go.mod index 7616403b7a9b..ab5b5398de1c 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -46,7 +46,7 @@ require ( github.com/johannesboyne/gofakes3 v0.0.0-20250106100439-5c39aecd6999 github.com/lib/pq v1.10.9 github.com/linkedin/goavro/v2 v2.14.0 - github.com/nats-io/nats-server/v2 v2.11.4 + github.com/nats-io/nats-server/v2 v2.11.5 github.com/nats-io/nats.go v1.43.0 github.com/proullon/ramsql v0.1.4 github.com/spf13/cobra v1.9.1 @@ -162,7 +162,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cncf/xds/go v0.0.0-20250326154945-ae57f3c0d45f // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect - github.com/docker/docker v28.2.2+incompatible // but required to resolve issue docker has with go1.20 + github.com/docker/docker v28.3.0+incompatible // but required to resolve issue docker has with go1.20 github.com/docker/go-units v0.5.0 // indirect github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect diff --git a/sdks/go.sum b/sdks/go.sum index 9374cc624c2d..0d57eaf618ae 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -890,8 +890,8 @@ github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5Qvfr github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= github.com/dnaeon/go-vcr v1.1.0/go.mod h1:M7tiix8f0r6mKKJ3Yq/kqU1OYf3MnfmBWVbPx/yU9ko= github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= -github.com/docker/docker v28.2.2+incompatible h1:CjwRSksz8Yo4+RmQ339Dp/D2tGO5JxwYeqtMOEe0LDw= -github.com/docker/docker v28.2.2+incompatible/go.mod 
h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/docker v28.3.0+incompatible h1:ffS62aKWupCWdvcee7nBU9fhnmknOqDPaJAMtfK0ImQ= +github.com/docker/docker v28.3.0+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c= github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= @@ -1323,8 +1323,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/nats-io/jwt/v2 v2.7.4 h1:jXFuDDxs/GQjGDZGhNgH4tXzSUK6WQi2rsj4xmsNOtI= github.com/nats-io/jwt/v2 v2.7.4/go.mod h1:me11pOkwObtcBNR8AiMrUbtVOUGkqYjMQZ6jnSdVUIA= -github.com/nats-io/nats-server/v2 v2.11.4 h1:oQhvy6He6ER926sGqIKBKuYHH4BGnUQCNb0Y5Qa+M54= -github.com/nats-io/nats-server/v2 v2.11.4/go.mod h1:jFnKKwbNeq6IfLHq+OMnl7vrFRihQ/MkhRbiWfjLdjU= +github.com/nats-io/nats-server/v2 v2.11.5 h1:yxwFASM5VrbHky6bCCame6g6fXZaayLoh7WFPWU9EEg= +github.com/nats-io/nats-server/v2 v2.11.5/go.mod h1:2xoztlcb4lDL5Blh1/BiukkKELXvKQ5Vy29FPVRBUYs= github.com/nats-io/nats.go v1.43.0 h1:uRFZ2FEoRvP64+UUhaTokyS18XBCR/xM2vQZKO4i8ug= github.com/nats-io/nats.go v1.43.0/go.mod h1:iRWIPokVIFbVijxuMQq4y9ttaBTMe0SFdlZfMDd+33g= github.com/nats-io/nkeys v0.4.11 h1:q44qGV008kYd9W1b1nEBkNzvnWxtRSQ7A8BoqRrcfa0= diff --git a/sdks/java/container/build.gradle b/sdks/java/container/build.gradle index bc9bc45ec371..711b34b38b82 100644 --- a/sdks/java/container/build.gradle +++ b/sdks/java/container/build.gradle @@ -83,4 +83,5 @@ task pushAll { dependsOn ":sdks:java:container:java11:docker" dependsOn ":sdks:java:container:java17:docker" dependsOn ":sdks:java:container:java21:docker" + dependsOn ":sdks:java:container:distroless:pushAll" } diff --git a/sdks/java/container/license_scripts/dep_urls_java.yaml b/sdks/java/container/license_scripts/dep_urls_java.yaml index 89a113718e53..4f9f50725def 100644 --- a/sdks/java/container/license_scripts/dep_urls_java.yaml +++ b/sdks/java/container/license_scripts/dep_urls_java.yaml @@ -46,7 +46,7 @@ jaxen: '1.1.6': type: "3-Clause BSD" libraries-bom: - '26.60.0': + '26.62.0': license: "https://raw.githubusercontent.com/GoogleCloudPlatform/cloud-opensource-java/master/LICENSE" type: "Apache License 2.0" paranamer: diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/ChangeStreamResultSet.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/ChangeStreamResultSet.java index f4ffba598a4b..1268c739164f 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/ChangeStreamResultSet.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/ChangeStreamResultSet.java @@ -108,6 +108,26 @@ public Struct getCurrentRowAsStruct() { return resultSet.getCurrentRowAsStruct(); } + /** + * Returns the only change stream record proto at the current pointer of the result set. It also + * updates the timestamp at which the record was read. This function enhances the getProtoMessage + * function but only focus on the ChangeStreamRecord type. 
+ * + * @return a change stream record as a proto or null + */ + public com.google.spanner.v1.ChangeStreamRecord getProtoChangeStreamRecord() { + recordReadAt = Timestamp.now(); + return resultSet.getProtoMessage( + 0, com.google.spanner.v1.ChangeStreamRecord.getDefaultInstance()); + } + + /** Returns true if the result set at the current pointer contain only one proto change record. */ + public boolean isProtoChangeRecord() { + return resultSet.getColumnCount() == 1 + && !resultSet.isNull(0) + && resultSet.getColumnType(0).getCode() == com.google.cloud.spanner.Type.Code.PROTO; + } + /** * Returns the record at the current pointer as {@link JsonB}. It also updates the timestamp at * which the record was read. diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/mapper/ChangeStreamRecordMapper.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/mapper/ChangeStreamRecordMapper.java index 20314566dcc7..15032ba98a21 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/mapper/ChangeStreamRecordMapper.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/mapper/ChangeStreamRecordMapper.java @@ -23,6 +23,8 @@ import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Value; import com.google.protobuf.util.JsonFormat; +import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.List; @@ -42,7 +44,10 @@ import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.InitialPartition; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.Mod; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.ModType; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.PartitionEndRecord; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.PartitionEventRecord; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.PartitionMetadata; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.PartitionStartRecord; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.TypeCode; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.ValueCaptureType; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; @@ -223,12 +228,218 @@ public List toChangeStreamRecords( return Collections.singletonList( toChangeStreamRecordJson(partition, resultSet.getPgJsonb(0), resultSetMetadata)); } - // In GoogleSQL, change stream records are returned as an array of structs. + + // In GoogleSQL, for `MUTABLE_KEY_RANGE` option, change stream records are returned as Protos. + if (resultSet.isProtoChangeRecord()) { + return Arrays.asList( + toChangeStreamRecord( + partition, resultSet.getProtoChangeStreamRecord(), resultSetMetadata)); + } + + // In GoogleSQL, for `IMMUTABLE_KEY_RANGE` option, change stream records are returned as an array + // of structs. 
return resultSet.getCurrentRowAsStruct().getStructList(0).stream() .flatMap(struct -> toChangeStreamRecord(partition, struct, resultSetMetadata)) .collect(Collectors.toList()); } + ChangeStreamRecord toChangeStreamRecord( + PartitionMetadata partition, + com.google.spanner.v1.ChangeStreamRecord changeStreamRecordProto, + ChangeStreamResultSetMetadata resultSetMetadata) { + if (changeStreamRecordProto.hasPartitionStartRecord()) { + return parseProtoPartitionStartRecord( + partition, resultSetMetadata, changeStreamRecordProto.getPartitionStartRecord()); + } else if (changeStreamRecordProto.hasPartitionEndRecord()) { + return parseProtoPartitionEndRecord( + partition, resultSetMetadata, changeStreamRecordProto.getPartitionEndRecord()); + } else if (changeStreamRecordProto.hasPartitionEventRecord()) { + return parseProtoPartitionEventRecord( + partition, resultSetMetadata, changeStreamRecordProto.getPartitionEventRecord()); + } else if (changeStreamRecordProto.hasHeartbeatRecord()) { + return parseProtoHeartbeatRecord( + partition, resultSetMetadata, changeStreamRecordProto.getHeartbeatRecord()); + } else if (changeStreamRecordProto.hasDataChangeRecord()) { + return parseProtoDataChangeRecord( + partition, resultSetMetadata, changeStreamRecordProto.getDataChangeRecord()); + } else { + throw new IllegalArgumentException( + "Unknown change stream record type " + changeStreamRecordProto.toString()); + } + } + + ChangeStreamRecord parseProtoPartitionStartRecord( + PartitionMetadata partition, + ChangeStreamResultSetMetadata resultSetMetadata, + com.google.spanner.v1.ChangeStreamRecord.PartitionStartRecord partitionStartRecordProto) { + final Timestamp startTimestamp = + Timestamp.fromProto(partitionStartRecordProto.getStartTimestamp()); + return new PartitionStartRecord( + startTimestamp, + partitionStartRecordProto.getRecordSequence(), + partitionStartRecordProto.getPartitionTokensList(), + changeStreamRecordMetadataFrom(partition, startTimestamp, resultSetMetadata)); + } + + ChangeStreamRecord parseProtoPartitionEndRecord( + PartitionMetadata partition, + ChangeStreamResultSetMetadata resultSetMetadata, + com.google.spanner.v1.ChangeStreamRecord.PartitionEndRecord partitionEndRecordProto) { + final Timestamp endTimestamp = Timestamp.fromProto(partitionEndRecordProto.getEndTimestamp()); + return new PartitionEndRecord( + endTimestamp, + partitionEndRecordProto.getRecordSequence(), + changeStreamRecordMetadataFrom(partition, endTimestamp, resultSetMetadata)); + } + + ChangeStreamRecord parseProtoPartitionEventRecord( + PartitionMetadata partition, + ChangeStreamResultSetMetadata resultSetMetadata, + com.google.spanner.v1.ChangeStreamRecord.PartitionEventRecord partitionEventRecordProto) { + final Timestamp commitTimestamp = + Timestamp.fromProto(partitionEventRecordProto.getCommitTimestamp()); + return new PartitionEventRecord( + commitTimestamp, + partitionEventRecordProto.getRecordSequence(), + changeStreamRecordMetadataFrom(partition, commitTimestamp, resultSetMetadata)); + } + + ChangeStreamRecord parseProtoHeartbeatRecord( + PartitionMetadata partition, + ChangeStreamResultSetMetadata resultSetMetadata, + com.google.spanner.v1.ChangeStreamRecord.HeartbeatRecord heartbeatRecordProto) { + final Timestamp heartbeatTimestamp = Timestamp.fromProto(heartbeatRecordProto.getTimestamp()); + return new HeartbeatRecord( + heartbeatTimestamp, + changeStreamRecordMetadataFrom(partition, heartbeatTimestamp, resultSetMetadata)); + } + + ChangeStreamRecord parseProtoDataChangeRecord( + PartitionMetadata 
partition, + ChangeStreamResultSetMetadata resultSetMetadata, + com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord dataChangeRecordProto) { + final Timestamp commitTimestamp = + Timestamp.fromProto(dataChangeRecordProto.getCommitTimestamp()); + return new DataChangeRecord( + partition.getPartitionToken(), + commitTimestamp, + dataChangeRecordProto.getServerTransactionId(), + dataChangeRecordProto.getIsLastRecordInTransactionInPartition(), + dataChangeRecordProto.getRecordSequence(), + dataChangeRecordProto.getTable(), + parseProtoColumnMetadata(dataChangeRecordProto.getColumnMetadataList()), + parseProtoMod( + dataChangeRecordProto.getModsList(), dataChangeRecordProto.getColumnMetadataList()), + parseProtoModType(dataChangeRecordProto.getModType()), + parseProtoValueCaptureType(dataChangeRecordProto.getValueCaptureType()), + dataChangeRecordProto.getNumberOfRecordsInTransaction(), + dataChangeRecordProto.getNumberOfPartitionsInTransaction(), + dataChangeRecordProto.getTransactionTag(), + dataChangeRecordProto.getIsSystemTransaction(), + changeStreamRecordMetadataFrom(partition, commitTimestamp, resultSetMetadata)); + } + + List parseProtoColumnMetadata( + List + columnMetadataProtos) { + List columnTypes = new ArrayList<>(); + for (com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ColumnMetadata + columnMetadataProto : columnMetadataProtos) { + // TypeCode class takes json format argument in its constructor, e.g. `{\"code\":\"INT64\"}`. + String typeCodeJson; + try { + typeCodeJson = this.printer.print(columnMetadataProto.getType()); + } catch (InvalidProtocolBufferException exc) { + throw new IllegalArgumentException( + "Failed to print type: " + columnMetadataProto.getType().toString()); + } + ColumnType columnType = + new ColumnType( + columnMetadataProto.getName(), + new TypeCode(typeCodeJson), + columnMetadataProto.getIsPrimaryKey(), + columnMetadataProto.getOrdinalPosition()); + columnTypes.add(columnType); + } + return columnTypes; + } + + String convertModValueProtosToJson( + List modValueProtos, + List + columnMetadataProtos) { + com.google.protobuf.Struct.Builder modStructValueBuilder = + com.google.protobuf.Struct.newBuilder(); + for (com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ModValue modValueProto : + modValueProtos) { + final String columnName = + columnMetadataProtos.get(modValueProto.getColumnMetadataIndex()).getName(); + final Value columnValue = modValueProto.getValue(); + modStructValueBuilder.putFields(columnName, columnValue); + } + Value modStructValue = Value.newBuilder().setStructValue(modStructValueBuilder.build()).build(); + String modValueJson; + try { + modValueJson = this.printer.print(modStructValue); + } catch (InvalidProtocolBufferException exc) { + throw new IllegalArgumentException("Failed to print type: " + modStructValue); + } + return modValueJson; + } + + List parseProtoMod( + List modProtos, + List + columnMetadataProtos) { + List mods = new ArrayList<>(); + for (com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.Mod modProto : modProtos) { + final String keysJson = + convertModValueProtosToJson(modProto.getKeysList(), columnMetadataProtos); + final String oldValuesJson = + convertModValueProtosToJson(modProto.getOldValuesList(), columnMetadataProtos); + final String newValuesJson = + convertModValueProtosToJson(modProto.getNewValuesList(), columnMetadataProtos); + Mod mod = new Mod(keysJson, oldValuesJson, newValuesJson); + mods.add(mod); + } + return mods; + } + + ModType parseProtoModType( + 
com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ModType modTypeProto) { + if (modTypeProto == com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ModType.INSERT) { + return ModType.INSERT; + } else if (modTypeProto + == com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ModType.UPDATE) { + return ModType.UPDATE; + } else if (modTypeProto + == com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ModType.DELETE) { + return ModType.DELETE; + } + return ModType.UNKNOWN; + } + + ValueCaptureType parseProtoValueCaptureType( + com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ValueCaptureType + valueCaptureTypeProto) { + if (valueCaptureTypeProto + == com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ValueCaptureType.NEW_ROW) { + return ValueCaptureType.NEW_ROW; + } else if (valueCaptureTypeProto + == com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ValueCaptureType.NEW_VALUES) { + return ValueCaptureType.NEW_VALUES; + } else if (valueCaptureTypeProto + == com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ValueCaptureType + .OLD_AND_NEW_VALUES) { + return ValueCaptureType.OLD_AND_NEW_VALUES; + } else if (valueCaptureTypeProto + == com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ValueCaptureType + .NEW_ROW_AND_OLD_VALUES) { + return ValueCaptureType.NEW_ROW_AND_OLD_VALUES; + } + return ValueCaptureType.UNKNOWN; + } + Stream toChangeStreamRecord( PartitionMetadata partition, Struct row, ChangeStreamResultSetMetadata resultSetMetadata) { diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/model/DataChangeRecord.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/model/DataChangeRecord.java index e00ef9c08ca1..ffba4d10e3a2 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/model/DataChangeRecord.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/model/DataChangeRecord.java @@ -288,7 +288,7 @@ public String toString() { + '\'' + ", isSystemTransaction=" + isSystemTransaction - + ", metadata" + + ", metadata=" + metadata + '}'; } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/GcpApiSurfaceTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/GcpApiSurfaceTest.java index 03226c2c77e3..df47964b9c5e 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/GcpApiSurfaceTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/GcpApiSurfaceTest.java @@ -88,6 +88,7 @@ public void testGcpApiSurface() throws Exception { classesInPackage("com.google.pubsub.v1"), classesInPackage("com.google.cloud.pubsublite"), Matchers.equalTo(com.google.api.gax.rpc.ApiException.class), + Matchers.equalTo(com.google.errorprone.annotations.CheckReturnValue.class), Matchers.>equalTo(com.google.api.gax.rpc.StatusCode.class), Matchers.>equalTo(com.google.api.resourcenames.ResourceName.class), Matchers.>equalTo(com.google.common.base.Function.class), diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/mapper/ChangeStreamRecordMapperTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/mapper/ChangeStreamRecordMapperTest.java index a06fb074e637..b3dd1bef049f 100644 --- 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/mapper/ChangeStreamRecordMapperTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/mapper/ChangeStreamRecordMapperTest.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.gcp.spanner.changestreams.mapper; import static org.apache.beam.sdk.io.gcp.spanner.changestreams.util.TestJsonMapper.recordToJson; +import static org.apache.beam.sdk.io.gcp.spanner.changestreams.util.TestProtoMapper.recordToProto; import static org.apache.beam.sdk.io.gcp.spanner.changestreams.util.TestStructMapper.recordsToStructWithJson; import static org.apache.beam.sdk.io.gcp.spanner.changestreams.util.TestStructMapper.recordsToStructWithStrings; import static org.apache.beam.sdk.io.gcp.spanner.changestreams.util.TestStructMapper.recordsWithUnknownModTypeAndValueCaptureType; @@ -41,8 +42,11 @@ import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.InitialPartition; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.Mod; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.ModType; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.PartitionEndRecord; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.PartitionEventRecord; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.PartitionMetadata; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.PartitionMetadata.State; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.PartitionStartRecord; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.TypeCode; import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.ValueCaptureType; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; @@ -930,4 +934,109 @@ public void testMappingJsonRowToChildPartitionRecord() { Collections.singletonList(childPartitionsRecord), mapperPostgres.toChangeStreamRecords(partition, resultSet, resultSetMetadata)); } + + @Test + public void testMappingProtoRowToPartitionStartRecord() { + final PartitionStartRecord partitionStartRecord = + new PartitionStartRecord( + Timestamp.MIN_VALUE, + "fakeRecordSequence", + Arrays.asList("partitionToken1", "partitionToken2"), + null); + com.google.spanner.v1.ChangeStreamRecord changeStreamRecordProto = + recordToProto(partitionStartRecord); + assertNotNull(changeStreamRecordProto); + ChangeStreamResultSet resultSet = mock(ChangeStreamResultSet.class); + + when(resultSet.isProtoChangeRecord()).thenReturn(true); + when(resultSet.getProtoChangeStreamRecord()).thenReturn(changeStreamRecordProto); + assertEquals( + Collections.singletonList(partitionStartRecord), + mapper.toChangeStreamRecords(partition, resultSet, resultSetMetadata)); + } + + @Test + public void testMappingProtoRowToPartitionEndRecord() { + final PartitionEndRecord partitionEndChange = + new PartitionEndRecord(Timestamp.MIN_VALUE, "fakeRecordSequence", null); + com.google.spanner.v1.ChangeStreamRecord changeStreamRecordProto = + recordToProto(partitionEndChange); + assertNotNull(changeStreamRecordProto); + ChangeStreamResultSet resultSet = mock(ChangeStreamResultSet.class); + + when(resultSet.isProtoChangeRecord()).thenReturn(true); + when(resultSet.getProtoChangeStreamRecord()).thenReturn(changeStreamRecordProto); + assertEquals( + Collections.singletonList(partitionEndChange), + mapper.toChangeStreamRecords(partition, resultSet, resultSetMetadata)); + } + + @Test + public void 
testMappingProtoRowToPartitionEventRecord() { + final PartitionEventRecord partitionEventRecord = + new PartitionEventRecord(Timestamp.MIN_VALUE, "fakeRecordSequence", null); + com.google.spanner.v1.ChangeStreamRecord changeStreamRecordProto = + recordToProto(partitionEventRecord); + assertNotNull(changeStreamRecordProto); + ChangeStreamResultSet resultSet = mock(ChangeStreamResultSet.class); + + when(resultSet.isProtoChangeRecord()).thenReturn(true); + when(resultSet.getProtoChangeStreamRecord()).thenReturn(changeStreamRecordProto); + assertEquals( + Collections.singletonList(partitionEventRecord), + mapper.toChangeStreamRecords(partition, resultSet, resultSetMetadata)); + } + + @Test + public void testMappingProtoRowToHeartbeatRecord() { + final HeartbeatRecord heartbeatRecord = + new HeartbeatRecord(Timestamp.ofTimeSecondsAndNanos(10L, 20), null); + com.google.spanner.v1.ChangeStreamRecord changeStreamRecordProto = + recordToProto(heartbeatRecord); + assertNotNull(changeStreamRecordProto); + ChangeStreamResultSet resultSet = mock(ChangeStreamResultSet.class); + + when(resultSet.isProtoChangeRecord()).thenReturn(true); + when(resultSet.getProtoChangeStreamRecord()).thenReturn(changeStreamRecordProto); + assertEquals( + Collections.singletonList(heartbeatRecord), + mapper.toChangeStreamRecords(partition, resultSet, resultSetMetadata)); + } + + @Test + public void testMappingProtoRowToDataChangeRecord() { + final DataChangeRecord dataChangeRecord = + new DataChangeRecord( + "partitionToken", + Timestamp.ofTimeSecondsAndNanos(10L, 20), + "serverTransactionId", + true, + "1", + "tableName", + Arrays.asList( + new ColumnType("column1", new TypeCode("{\"code\":\"INT64\"}"), true, 1L), + new ColumnType("column2", new TypeCode("{\"code\":\"BYTES\"}"), false, 2L)), + Collections.singletonList( + new Mod( + "{\"column1\":\"value1\"}", + "{\"column2\":\"oldValue2\"}", + "{\"column2\":\"newValue2\"}")), + ModType.UPDATE, + ValueCaptureType.OLD_AND_NEW_VALUES, + 10L, + 2L, + "transactionTag", + true, + null); + com.google.spanner.v1.ChangeStreamRecord changeStreamRecordProto = + recordToProto(dataChangeRecord); + assertNotNull(changeStreamRecordProto); + ChangeStreamResultSet resultSet = mock(ChangeStreamResultSet.class); + + when(resultSet.isProtoChangeRecord()).thenReturn(true); + when(resultSet.getProtoChangeStreamRecord()).thenReturn(changeStreamRecordProto); + assertEquals( + Collections.singletonList(dataChangeRecord), + mapper.toChangeStreamRecords(partition, resultSet, resultSetMetadata)); + } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/util/TestProtoMapper.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/util/TestProtoMapper.java new file mode 100644 index 000000000000..6cf5958f03e9 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/util/TestProtoMapper.java @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.spanner.changestreams.util; + +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.ChangeStreamRecord; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.ColumnType; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.DataChangeRecord; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.HeartbeatRecord; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.Mod; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.ModType; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.PartitionEndRecord; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.PartitionEventRecord; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.PartitionStartRecord; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.ValueCaptureType; + +// Test util class to convert ChangeStreamRecord class to proto represenatation. Similar to +// TestJsonMapper and TestStructMapper. +public class TestProtoMapper { + + public static com.google.spanner.v1.ChangeStreamRecord recordToProto(ChangeStreamRecord record) { + if (record instanceof PartitionStartRecord) { + return convertPartitionStartRecordToProto((PartitionStartRecord) record); + } else if (record instanceof PartitionEndRecord) { + return convertPartitionEndRecordToProto((PartitionEndRecord) record); + } else if (record instanceof PartitionEventRecord) { + return convertPartitionEventRecordToProto((PartitionEventRecord) record); + } else if (record instanceof HeartbeatRecord) { + return convertHeartbeatRecordToProto((HeartbeatRecord) record); + } else if (record instanceof DataChangeRecord) { + return convertDataChangeRecordToProto((DataChangeRecord) record); + } else { + throw new UnsupportedOperationException("Unimplemented mapping for " + record.getClass()); + } + } + + private static com.google.spanner.v1.ChangeStreamRecord convertPartitionStartRecordToProto( + PartitionStartRecord partitionStartRecord) { + com.google.spanner.v1.ChangeStreamRecord.PartitionStartRecord partitionStartRecordProto = + com.google.spanner.v1.ChangeStreamRecord.PartitionStartRecord.newBuilder() + .setStartTimestamp(partitionStartRecord.getStartTimestamp().toProto()) + .setRecordSequence(partitionStartRecord.getRecordSequence()) + .addAllPartitionTokens(partitionStartRecord.getPartitionTokens()) + .build(); + com.google.spanner.v1.ChangeStreamRecord changeStreamRecordProto = + com.google.spanner.v1.ChangeStreamRecord.newBuilder() + .setPartitionStartRecord(partitionStartRecordProto) + .build(); + return changeStreamRecordProto; + } + + private static com.google.spanner.v1.ChangeStreamRecord convertPartitionEndRecordToProto( + PartitionEndRecord partitionEndRecord) { + com.google.spanner.v1.ChangeStreamRecord.PartitionEndRecord partitionEndRecordProto = + 
com.google.spanner.v1.ChangeStreamRecord.PartitionEndRecord.newBuilder() + .setEndTimestamp(partitionEndRecord.getEndTimestamp().toProto()) + .setRecordSequence(partitionEndRecord.getRecordSequence()) + .build(); + com.google.spanner.v1.ChangeStreamRecord changeStreamRecordProto = + com.google.spanner.v1.ChangeStreamRecord.newBuilder() + .setPartitionEndRecord(partitionEndRecordProto) + .build(); + return changeStreamRecordProto; + } + + private static com.google.spanner.v1.ChangeStreamRecord convertPartitionEventRecordToProto( + PartitionEventRecord partitionEventRecord) { + com.google.spanner.v1.ChangeStreamRecord.PartitionEventRecord partitionEventRecordProto = + com.google.spanner.v1.ChangeStreamRecord.PartitionEventRecord.newBuilder() + .setCommitTimestamp(partitionEventRecord.getCommitTimestamp().toProto()) + .setRecordSequence(partitionEventRecord.getRecordSequence()) + .build(); + com.google.spanner.v1.ChangeStreamRecord changeStreamRecordProto = + com.google.spanner.v1.ChangeStreamRecord.newBuilder() + .setPartitionEventRecord(partitionEventRecordProto) + .build(); + return changeStreamRecordProto; + } + + private static com.google.spanner.v1.ChangeStreamRecord convertHeartbeatRecordToProto( + HeartbeatRecord heartbeatRecord) { + com.google.spanner.v1.ChangeStreamRecord.HeartbeatRecord heartbeatRecordProto = + com.google.spanner.v1.ChangeStreamRecord.HeartbeatRecord.newBuilder() + .setTimestamp(heartbeatRecord.getTimestamp().toProto()) + .build(); + com.google.spanner.v1.ChangeStreamRecord changeStreamRecordProto = + com.google.spanner.v1.ChangeStreamRecord.newBuilder() + .setHeartbeatRecord(heartbeatRecordProto) + .build(); + return changeStreamRecordProto; + } + + private static com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ModType getProtoModType( + ModType modType) { + if (modType == ModType.INSERT) { + return com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ModType.INSERT; + } else if (modType == ModType.UPDATE) { + return com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ModType.UPDATE; + } else if (modType == ModType.DELETE) { + return com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ModType.DELETE; + } + return com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ModType.MOD_TYPE_UNSPECIFIED; + } + + private static com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ValueCaptureType + getProtoValueCaptureTypeProto(ValueCaptureType valueCaptureType) { + if (valueCaptureType == ValueCaptureType.NEW_ROW) { + return com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ValueCaptureType.NEW_ROW; + } + if (valueCaptureType == ValueCaptureType.NEW_VALUES) { + return com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ValueCaptureType.NEW_VALUES; + } + if (valueCaptureType == ValueCaptureType.OLD_AND_NEW_VALUES) { + return com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ValueCaptureType + .OLD_AND_NEW_VALUES; + } + if (valueCaptureType == ValueCaptureType.NEW_ROW_AND_OLD_VALUES) { + return com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ValueCaptureType + .NEW_ROW_AND_OLD_VALUES; + } + return com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ValueCaptureType + .VALUE_CAPTURE_TYPE_UNSPECIFIED; + } + + private static List + getProtoColumnMetadata(List columnTypes) { + JsonFormat.Parser jsonParser = JsonFormat.parser().ignoringUnknownFields(); + + List + columnMetaDataProtos = new ArrayList<>(); + for (ColumnType columnType : columnTypes) { + // TypeCode class contains json format type code, e.g. 
{\"code\":\"INT64\"}. We need to + // extract "INT64" type code. + Value.Builder typeCodeJson = Value.newBuilder(); + try { + jsonParser.merge(columnType.getType().getCode(), typeCodeJson); + } catch (InvalidProtocolBufferException exc) { + throw new IllegalArgumentException( + "Failed to parse json type code into proto: " + columnType.getType().getCode()); + } + Value typeCode = + Optional.ofNullable(typeCodeJson.build().getStructValue().getFieldsMap().get("code")) + .orElseThrow(IllegalArgumentException::new); + + com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ColumnMetadata columnMetadataProto = + com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ColumnMetadata.newBuilder() + .setName(columnType.getName()) + .setType( + com.google.spanner.v1.Type.newBuilder() + .setCode(com.google.spanner.v1.TypeCode.valueOf(typeCode.getStringValue()))) + .setIsPrimaryKey(columnType.isPrimaryKey()) + .setOrdinalPosition(columnType.getOrdinalPosition()) + .build(); + columnMetaDataProtos.add(columnMetadataProto); + } + return columnMetaDataProtos; + } + + private static List + columnsJsonToProtos(String columnsJson, Map columnNameToIndex) { + List modValueProtos = + new ArrayList<>(); + JsonFormat.Parser jsonParser = JsonFormat.parser().ignoringUnknownFields(); + Value.Builder columnsJsonValue = Value.newBuilder(); + try { + jsonParser.merge(columnsJson, columnsJsonValue); + } catch (InvalidProtocolBufferException exc) { + throw new IllegalArgumentException( + "Failed to parse json type columns into proto: " + columnsJson); + } + Map columns = columnsJsonValue.build().getStructValue().getFieldsMap(); + for (Map.Entry entry : columns.entrySet()) { + final String columnName = entry.getKey(); + final String columnValue = entry.getValue().getStringValue(); + final Integer columnIndex = columnNameToIndex.get(columnName); + + com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ModValue modValueProto = + com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.ModValue.newBuilder() + .setColumnMetadataIndex(columnIndex) + .setValue(Value.newBuilder().setStringValue(columnValue).build()) + .build(); + modValueProtos.add(modValueProto); + } + return modValueProtos; + } + + private static List getProtoMods( + List mods, List columnTypes) { + Map columnNameToIndex = new HashMap<>(); + for (int i = 0; i < columnTypes.size(); ++i) { + columnNameToIndex.put(columnTypes.get(i).getName(), i); + } + List modProtos = + new ArrayList<>(); + for (Mod mod : mods) { + com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.Mod modProto = + com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.Mod.newBuilder() + .addAllKeys(columnsJsonToProtos(mod.getKeysJson(), columnNameToIndex)) + .addAllOldValues(columnsJsonToProtos(mod.getOldValuesJson(), columnNameToIndex)) + .addAllNewValues(columnsJsonToProtos(mod.getNewValuesJson(), columnNameToIndex)) + .build(); + modProtos.add(modProto); + } + return modProtos; + } + + private static com.google.spanner.v1.ChangeStreamRecord convertDataChangeRecordToProto( + DataChangeRecord dataChangeRecord) { + com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord dataChangeRecordProto = + com.google.spanner.v1.ChangeStreamRecord.DataChangeRecord.newBuilder() + .setCommitTimestamp(dataChangeRecord.getCommitTimestamp().toProto()) + .setRecordSequence(dataChangeRecord.getRecordSequence()) + .setServerTransactionId(dataChangeRecord.getServerTransactionId()) + .setIsLastRecordInTransactionInPartition( + dataChangeRecord.isLastRecordInTransactionInPartition()) + 
.setTable(dataChangeRecord.getTableName()) + .addAllColumnMetadata(getProtoColumnMetadata(dataChangeRecord.getRowType())) + .addAllMods(getProtoMods(dataChangeRecord.getMods(), dataChangeRecord.getRowType())) + .setModType(getProtoModType(dataChangeRecord.getModType())) + .setValueCaptureType( + getProtoValueCaptureTypeProto(dataChangeRecord.getValueCaptureType())) + .setNumberOfRecordsInTransaction( + (int) dataChangeRecord.getNumberOfRecordsInTransaction()) + .setNumberOfPartitionsInTransaction( + (int) dataChangeRecord.getNumberOfPartitionsInTransaction()) + .setTransactionTag(dataChangeRecord.getTransactionTag()) + .setIsSystemTransaction(dataChangeRecord.isSystemTransaction()) + .build(); + com.google.spanner.v1.ChangeStreamRecord changeStreamRecordProto = + com.google.spanner.v1.ChangeStreamRecord.newBuilder() + .setDataChangeRecord(dataChangeRecordProto) + .build(); + return changeStreamRecordProto; + } +} diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py index 9f5d49f1fd2a..dbd0a301bb0d 100644 --- a/sdks/python/apache_beam/coders/coders_test_common.py +++ b/sdks/python/apache_beam/coders/coders_test_common.py @@ -24,6 +24,8 @@ import logging import math import pickle +import subprocess +import sys import textwrap import unittest from decimal import Decimal @@ -610,6 +612,10 @@ def test_param_windowed_value_coder(self): def test_cross_process_encoding_of_special_types_is_deterministic(self): """Test cross-process determinism for all special deterministic types""" + + if sys.executable is None: + self.skipTest('No Python interpreter found') + # pylint: disable=line-too-long script = textwrap.dedent( '''\ @@ -711,9 +717,6 @@ def __eq__(self, other): ''') def run_subprocess(): - import subprocess - import sys - result = subprocess.run([sys.executable, '-c', script], capture_output=True, timeout=30, diff --git a/sdks/python/apache_beam/io/gcp/gcsio_test.py b/sdks/python/apache_beam/io/gcp/gcsio_test.py index 0f28425c1c01..4c18647729e3 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio_test.py +++ b/sdks/python/apache_beam/io/gcp/gcsio_test.py @@ -174,7 +174,8 @@ def lookup_blob(self, name): if name in bucket.blobs: return bucket.blobs[name] else: - return bucket.create_blob(name) + new_blob = bucket._create_blob(name) + return bucket.add_blob(new_blob) def set_default_kms_key_name(self, name): self.default_kms_key_name = name diff --git a/sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search.py b/sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search.py index 77eb27ed37ba..f6117a260a34 100644 --- a/sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search.py +++ b/sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search.py @@ -220,11 +220,11 @@ def format_query(self, chunks: List[Chunk]) -> str: # Create embeddings subquery for this group embedding_unions = [] for chunk in group_chunks: - if chunk.embedding is None or chunk.embedding.dense_embedding is None: + if not chunk.dense_embedding: raise ValueError(f"Chunk {chunk.id} missing embedding") embedding_str = ( f"SELECT '{chunk.id}' as id, " - f"{[float(x) for x in chunk.embedding.dense_embedding]} " + f"{[float(x) for x in chunk.dense_embedding]} " f"as embedding") embedding_unions.append(embedding_str) group_embeddings = " UNION ALL ".join(embedding_unions) diff --git a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py new file mode 100644 index 
000000000000..a0f597f5366f --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py @@ -0,0 +1,599 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections.abc import Sequence +from dataclasses import dataclass +from dataclasses import field +from enum import Enum +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +from google.protobuf.json_format import MessageToDict +from pymilvus import AnnSearchRequest +from pymilvus import Hit +from pymilvus import Hits +from pymilvus import MilvusClient +from pymilvus import SearchResult + +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Embedding +from apache_beam.transforms.enrichment import EnrichmentSourceHandler + + +class SearchStrategy(Enum): + """Search strategies for information retrieval. + + Args: + HYBRID: Combines vector and keyword search approaches. Leverages both + semantic understanding and exact matching. Typically provides the most + comprehensive results. Useful for queries with both conceptual and + specific keyword components. + VECTOR: Vector similarity search only. Based on semantic similarity between + query and documents. Effective for conceptual searches and finding related + content. Less sensitive to exact terminology than keyword search. + KEYWORD: Keyword/text search only. Based on exact or fuzzy matching of + specific terms. Effective for precise queries where exact wording matters. + Less effective for conceptual or semantic searches. + """ + HYBRID = "hybrid" + VECTOR = "vector" + KEYWORD = "keyword" + + +class KeywordSearchMetrics(Enum): + """Metrics for keyword search. + + Args: + BM25: Range [0 to ∞), Best Match 25 ranking algorithm for text relevance. + Combines term frequency, inverse document frequency, and document length. + Higher scores indicate greater relevance. Higher scores indicate greater + relevance. Takes into account diminishing returns of term frequency. + Balances between exact matching and semantic relevance. + """ + BM25 = "BM25" + + +class VectorSearchMetrics(Enum): + """Metrics for vector search. + + Args: + COSINE: Range [-1 to 1], higher values indicate greater similarity. Value 1 + means vectors point in identical direction. Value 0 means vectors are + perpendicular to each other (no relationship). Value -1 means vectors + point in exactly opposite directions. + EUCLIDEAN_DISTANCE (L2): Range [0 to ∞), lower values indicate greater + similarity. Value 0 means vectors are identical. Larger values mean more + dissimilarity between vectors. + INNER_PRODUCT (IP): Range varies based on vector magnitudes, higher values + indicate greater similarity. 
Value 0 means vectors are perpendicular to + each other. Positive values mean vectors share some directional component. + Negative values mean vectors point in opposing directions. + """ + COSINE = "COSINE" + EUCLIDEAN_DISTANCE = "L2" + INNER_PRODUCT = "IP" + + +class MilvusBaseRanker: + """Base class for ranking algorithms in Milvus hybrid search strategy.""" + def __int__(self): + return + + def dict(self): + return {} + + def __str__(self): + return self.dict().__str__() + + +@dataclass +class MilvusConnectionParameters: + """Parameters for establishing connections to Milvus servers. + + Args: + uri: URI endpoint for connecting to Milvus server in the format + "http(s)://hostname:port". + user: Username for authentication. Required if authentication is enabled and + not using token authentication. + password: Password for authentication. Required if authentication is enabled + and not using token authentication. + db_id: Database ID to connect to. Specifies which Milvus database to use. + Defaults to 'default'. + token: Authentication token as an alternative to username/password. + timeout: Connection timeout in seconds. Uses client default if None. + kwargs: Optional keyword arguments for additional connection parameters. + Enables forward compatibility. + """ + uri: str + user: str = field(default_factory=str) + password: str = field(default_factory=str) + db_id: str = "default" + token: str = field(default_factory=str) + timeout: Optional[float] = None + kwargs: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + if not self.uri: + raise ValueError("URI must be provided for Milvus connection") + + +@dataclass +class BaseSearchParameters: + """Base parameters for both vector and keyword search operations. + + Args: + anns_field: Approximate nearest neighbor search field indicates field name + containing the embedding to search. Required for both vector and keyword + search. + limit: Maximum number of results to return per query. Must be positive. + Defaults to 3 search results. + filter: Boolean expression string for filtering search results. + Example: 'price <= 1000 AND category == "electronics"'. + search_params: Additional search parameters specific to the search type. + Example: {"metric_type": VectorSearchMetrics.EUCLIDEAN_DISTANCE}. + consistency_level: Consistency level for read operations. + Options: "Strong", "Session", "Bounded", "Eventually". Defaults to + "Bounded" if not specified when creating the collection. + """ + anns_field: str + limit: int = 3 + filter: str = field(default_factory=str) + search_params: Dict[str, Any] = field(default_factory=dict) + consistency_level: Optional[str] = None + + def __post_init__(self): + if not self.anns_field: + raise ValueError( + "Approximate Nearest Neighbor Search (ANNS) field must be provided") + + if self.limit <= 0: + raise ValueError(f"Search limit must be positive, got {self.limit}") + + +@dataclass +class VectorSearchParameters(BaseSearchParameters): + """Parameters for vector similarity search operations. + + Inherits all parameters from BaseSearchParameters with the same semantics. + The anns_field should contain dense vector embeddings for this search type. + + Args: + kwargs: Optional keyword arguments for additional vector search parameters. + Enables forward compatibility. + + Note: + For inherited parameters documentation, see BaseSearchParameters. 
+ """ + kwargs: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class KeywordSearchParameters(BaseSearchParameters): + """Parameters for keyword/text search operations. + + This class inherits all parameters from BaseSearchParameters with the same + semantics. The anns_field should contain sparse vector embeddings content for + this search type. + + Args: + kwargs: Optional keyword arguments for additional keyword search parameters. + Enables forward compatibility. + + Note: + For inherited parameters documentation, see BaseSearchParameters. + """ + kwargs: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class HybridSearchParameters: + """Parameters for hybrid (vector + keyword) search operations. + + Args: + vector: Parameters for the vector search component. + keyword: Parameters for the keyword search component. + ranker: Ranker for combining vector and keyword search results. + Example: RRFRanker(k=100). + limit: Maximum number of results to return per query. Defaults to 3 search + results. + kwargs: Optional keyword arguments for additional hybrid search parameters. + Enables forward compatibility. + """ + vector: VectorSearchParameters + keyword: KeywordSearchParameters + ranker: MilvusBaseRanker + limit: int = 3 + kwargs: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + if not self.vector or not self.keyword: + raise ValueError( + "Vector and keyword search parameters must be provided for " + "hybrid search") + + if not self.ranker: + raise ValueError("Ranker must be provided for hybrid search") + + if self.limit <= 0: + raise ValueError(f"Search limit must be positive, got {self.limit}") + + +SearchStrategyType = Union[VectorSearchParameters, + KeywordSearchParameters, + HybridSearchParameters] + + +@dataclass +class MilvusSearchParameters: + """Parameters configuring Milvus search operations. + + This class encapsulates all parameters needed to execute searches against + Milvus collections, supporting vector, keyword, and hybrid search strategies. + + Args: + collection_name: Name of the collection to search in. + search_strategy: Type of search to perform (VECTOR, KEYWORD, or HYBRID). + partition_names: List of partition names to restrict the search to. If + empty, all partitions will be searched. + output_fields: List of field names to include in search results. If empty, + only primary fields including distances will be returned. + timeout: Search operation timeout in seconds. If not specified, the client's + default timeout is used. + round_decimal: Number of decimal places for distance/similarity scores. + Defaults to -1 means no rounding. + """ + collection_name: str + search_strategy: SearchStrategyType + partition_names: List[str] = field(default_factory=list) + output_fields: List[str] = field(default_factory=list) + timeout: Optional[float] = None + round_decimal: int = -1 + + def __post_init__(self): + if not self.collection_name: + raise ValueError("Collection name must be provided") + + if not self.search_strategy: + raise ValueError("Search strategy must be provided") + + +@dataclass +class MilvusCollectionLoadParameters: + """Parameters that control how Milvus loads a collection into memory. + + This class provides fine-grained control over collection loading, which is + particularly important in resource-constrained environments. Proper + configuration can significantly reduce memory usage and improve query + performance by loading only necessary data. 
+ + Args: + refresh: If True, forces a reload of the collection even if already loaded. + Ensures the most up-to-date data is in memory. + resource_groups: List of resource groups to load the collection into. Can be + used for load balancing across multiple query nodes. + load_fields: Specify which fields to load into memory. Loading only + necessary fields reduces memory usage. If empty, all fields loaded. + skip_load_dynamic_field: If True, dynamic/growing fields will not be loaded + into memory. Saves memory when dynamic fields aren't needed. + kwargs: Optional keyword arguments for additional collection load + parameters. Enables forward compatibility. + """ + refresh: bool = field(default_factory=bool) + resource_groups: List[str] = field(default_factory=list) + load_fields: List[str] = field(default_factory=list) + skip_load_dynamic_field: bool = field(default_factory=bool) + kwargs: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class MilvusSearchResult: + """Search result from Milvus per chunk. + + Args: + id: List of entity IDs returned from the search. Can be either string or + integer IDs. + distance: List of distances/similarity scores for each returned entity. + fields: List of dictionaries containing additional field values for each + entity. Each dictionary corresponds to one returned entity. + """ + id: List[Union[str, int]] = field(default_factory=list) + distance: List[float] = field(default_factory=list) + fields: List[Dict[str, Any]] = field(default_factory=list) + + +InputT, OutputT = Union[Chunk, List[Chunk]], List[Tuple[Chunk, Dict[str, Any]]] + + +class MilvusSearchEnrichmentHandler(EnrichmentSourceHandler[InputT, OutputT]): + """Enrichment handler for Milvus vector database searches. + + This handler is designed to work with the + :class:`apache_beam.transforms.enrichment.EnrichmentSourceHandler` transform. + It enables enriching data through vector similarity, keyword, or hybrid + searches against Milvus collections. + + The handler supports different search strategies: + * Vector search - For finding similar embeddings based on vector similarity + * Keyword search - For text-based retrieval using BM25 or other text metrics + * Hybrid search - For combining vector and keyword search results + + This handler queries the Milvus database per element by default. To enable + batching for improved performance, set the `min_batch_size` and + `max_batch_size` parameters. These control the batching behavior in the + :class:`apache_beam.transforms.utils.BatchElements` transform. + + For memory-intensive operations, the handler allows fine-grained control over + collection loading through the `collection_load_parameters`. 
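+
+  A minimal pipeline sketch (the collection name, field name, and URI below
+  are placeholders, not values defined by this module)::
+
+    with beam.Pipeline() as p:
+      _ = (
+          p
+          | beam.Create([
+              Chunk(
+                  content=Content(),
+                  embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]))
+          ])
+          | Enrichment(
+              MilvusSearchEnrichmentHandler(
+                  MilvusConnectionParameters(uri="http://localhost:19530"),
+                  MilvusSearchParameters(
+                      collection_name="docs",
+                      search_strategy=VectorSearchParameters(
+                          anns_field="embedding")),
+                  collection_load_parameters=MilvusCollectionLoadParameters())))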
+ """ + def __init__( + self, + connection_parameters: MilvusConnectionParameters, + search_parameters: MilvusSearchParameters, + *, + collection_load_parameters: Optional[MilvusCollectionLoadParameters], + min_batch_size: int = 1, + max_batch_size: int = 1000, + **kwargs): + """ + Example Usage: + connection_paramters = MilvusConnectionParameters( + uri="http://localhost:19530") + search_parameters = MilvusSearchParameters( + collection_name="my_collection", + search_strategy=VectorSearchParameters(anns_field="embedding")) + collection_load_parameters = MilvusCollectionLoadParameters( + load_fields=["embedding", "metadata"]), + milvus_handler = MilvusSearchEnrichmentHandler( + connection_paramters, + search_parameters, + collection_load_parameters=collection_load_parameters, + min_batch_size=10, + max_batch_size=100) + + Args: + connection_parameters (MilvusConnectionParameters): Configuration for + connecting to the Milvus server, including URI, credentials, and + connection options. + search_parameters (MilvusSearchParameters): Configuration for search + operations, including collection name, search strategy, and output + fields. + collection_load_parameters (Optional[MilvusCollectionLoadParameters]): + Parameters controlling how collections are loaded into memory, which can + significantly impact resource usage and performance. + min_batch_size (int): Minimum number of elements to batch together when + querying Milvus. Default is 1 (no batching when max_batch_size is 1). + max_batch_size (int): Maximum number of elements to batch together.Default + is 1000. Higher values may improve throughput but increase memory usage. + **kwargs: Additional keyword arguments for Milvus Enrichment Handler. + + Note: + * For large collections, consider setting appropriate values in + collection_load_parameters to reduce memory usage. + * The search_strategy in search_parameters determines the type of search + (vector, keyword, or hybrid) and associated parameters. + * Batching can significantly improve performance but requires more memory. 
+ """ + self._connection_parameters = connection_parameters + self._search_parameters = search_parameters + self._collection_load_parameters = collection_load_parameters + if not self._collection_load_parameters: + self._collection_load_parameters = MilvusCollectionLoadParameters() + self._batching_kwargs = { + 'min_batch_size': min_batch_size, 'max_batch_size': max_batch_size + } + self.kwargs = kwargs + self.join_fn = join_fn + self.use_custom_types = True + + def __enter__(self): + connection_params = unpack_dataclass_with_kwargs( + self._connection_parameters) + collection_load_params = unpack_dataclass_with_kwargs( + self._collection_load_parameters) + self._client = MilvusClient(**connection_params) + self._client.load_collection( + collection_name=self.collection_name, + partition_names=self.partition_names, + **collection_load_params) + + def __call__(self, request: Union[Chunk, List[Chunk]], *args, + **kwargs) -> List[Tuple[Chunk, Dict[str, Any]]]: + reqs = request if isinstance(request, list) else [request] + search_result = self._search_documents(reqs) + return self._get_call_response(reqs, search_result) + + def _search_documents(self, chunks: List[Chunk]): + if isinstance(self.search_strategy, HybridSearchParameters): + data = self._get_hybrid_search_data(chunks) + return self._client.hybrid_search( + collection_name=self.collection_name, + partition_names=self.partition_names, + output_fields=self.output_fields, + timeout=self.timeout, + round_decimal=self.round_decimal, + reqs=data, + ranker=self.search_strategy.ranker, + limit=self.search_strategy.limit, + **self.search_strategy.kwargs) + elif isinstance(self.search_strategy, VectorSearchParameters): + data = list(map(self._get_vector_search_data, chunks)) + vector_search_params = unpack_dataclass_with_kwargs(self.search_strategy) + return self._client.search( + collection_name=self.collection_name, + partition_names=self.partition_names, + output_fields=self.output_fields, + timeout=self.timeout, + round_decimal=self.round_decimal, + data=data, + **vector_search_params) + elif isinstance(self.search_strategy, KeywordSearchParameters): + data = list(map(self._get_keyword_search_data, chunks)) + keyword_search_params = unpack_dataclass_with_kwargs(self.search_strategy) + return self._client.search( + collection_name=self.collection_name, + partition_names=self.partition_names, + output_fields=self.output_fields, + timeout=self.timeout, + round_decimal=self.round_decimal, + data=data, + **keyword_search_params) + else: + raise ValueError( + f"Not supported search strategy yet: {self.search_strategy}") + + def _get_hybrid_search_data(self, chunks: List[Chunk]): + vector_search_data = list(map(self._get_vector_search_data, chunks)) + keyword_search_data = list(map(self._get_keyword_search_data, chunks)) + + vector_search_req = AnnSearchRequest( + data=vector_search_data, + anns_field=self.search_strategy.vector.anns_field, + param=self.search_strategy.vector.search_params, + limit=self.search_strategy.vector.limit, + expr=self.search_strategy.vector.filter) + + keyword_search_req = AnnSearchRequest( + data=keyword_search_data, + anns_field=self.search_strategy.keyword.anns_field, + param=self.search_strategy.keyword.search_params, + limit=self.search_strategy.keyword.limit, + expr=self.search_strategy.keyword.filter) + + reqs = [vector_search_req, keyword_search_req] + return reqs + + def _get_vector_search_data(self, chunk: Chunk): + if not chunk.dense_embedding: + raise ValueError( + f"Chunk {chunk.id} missing dense embedding 
required for vector search" + ) + return chunk.dense_embedding + + def _get_keyword_search_data(self, chunk: Chunk): + if not chunk.content.text and not chunk.sparse_embedding: + raise ValueError( + f"Chunk {chunk.id} missing both text content and sparse embedding " + "required for keyword search") + + sparse_embedding = self.convert_sparse_embedding_to_milvus_format( + chunk.sparse_embedding) + + return chunk.content.text or sparse_embedding + + def _get_call_response( + self, chunks: List[Chunk], search_result: SearchResult[Hits]): + response = [] + for i in range(len(chunks)): + chunk = chunks[i] + hits: Hits = search_result[i] + result = MilvusSearchResult() + for i in range(len(hits)): + hit: Hit = hits[i] + normalized_fields = self._normalize_milvus_fields(hit.fields) + result.id.append(hit.id) + result.distance.append(hit.distance) + result.fields.append(normalized_fields) + response.append((chunk, result.__dict__)) + return response + + def _normalize_milvus_fields(self, fields: Dict[str, Any]): + normalized_fields = {} + for field, value in fields.items(): + value = self._normalize_milvus_value(value) + normalized_fields[field] = value + return normalized_fields + + def _normalize_milvus_value(self, value: Any): + # Convert Milvus-specific types to Python native types. + neither_str_nor_dict_nor_bytes = not isinstance(value, (str, dict, bytes)) + if isinstance(value, Sequence) and neither_str_nor_dict_nor_bytes: + return list(value) + elif hasattr(value, 'DESCRIPTOR'): + # Handle protobuf messages. + return MessageToDict(value) + else: + # Keep other types as they are. + return value + + def convert_sparse_embedding_to_milvus_format( + self, sparse_vector: Tuple[List[int], List[float]]) -> Dict[int, float]: + if not sparse_vector: + return None + # Converts sparse embedding from (indices, values) tuple format to + # Milvus-compatible values dict format {dimension_index: value, ...}. + indices, values = sparse_vector + return {int(idx): float(val) for idx, val in zip(indices, values)} + + @property + def collection_name(self): + """Getter method for collection_name property""" + return self._search_parameters.collection_name + + @property + def search_strategy(self): + """Getter method for search_strategy property""" + return self._search_parameters.search_strategy + + @property + def partition_names(self): + """Getter method for partition_names property""" + return self._search_parameters.partition_names + + @property + def output_fields(self): + """Getter method for output_fields property""" + return self._search_parameters.output_fields + + @property + def timeout(self): + """Getter method for search timeout property""" + return self._search_parameters.timeout + + @property + def round_decimal(self): + """Getter method for search round_decimal property""" + return self._search_parameters.round_decimal + + def __exit__(self, exc_type, exc_val, exc_tb): + self._client.release_collection(self.collection_name) + self._client.close() + self._client = None + + def batch_elements_kwargs(self) -> Dict[str, int]: + """Returns kwargs for beam.BatchElements.""" + return self._batching_kwargs + + +def join_fn(left: Embedding, right: Dict[str, Any]) -> Embedding: + left.metadata['enrichment_data'] = right + return left + + +def unpack_dataclass_with_kwargs(dataclass_instance): + # Create a copy of the dataclass's __dict__. + params_dict: dict = dataclass_instance.__dict__.copy() + + # Extract the nested kwargs dictionary. 
+ nested_kwargs = params_dict.pop('kwargs', {}) + + # Merge the dictionaries, with nested_kwargs taking precedence + # in case of duplicate keys. + return {**params_dict, **nested_kwargs} diff --git a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py new file mode 100644 index 000000000000..ebc05722841c --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py @@ -0,0 +1,1371 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import contextlib +import logging +import os +import platform +import re +import socket +import tempfile +import unittest +from collections import defaultdict +from dataclasses import dataclass +from dataclasses import field +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import cast + +import pytest +import yaml +from pymilvus import CollectionSchema +from pymilvus import DataType +from pymilvus import FieldSchema +from pymilvus import Function +from pymilvus import FunctionType +from pymilvus import MilvusClient +from pymilvus import RRFRanker +from pymilvus.milvus_client import IndexParams +from testcontainers.core.config import MAX_TRIES as TC_MAX_TRIES +from testcontainers.core.config import testcontainers_config +from testcontainers.core.generic import DbContainer +from testcontainers.milvus import MilvusContainer + +import apache_beam as beam +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from apache_beam.ml.rag.types import Embedding +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that + +try: + from apache_beam.transforms.enrichment import Enrichment + from apache_beam.ml.rag.enrichment.milvus_search import ( + MilvusSearchEnrichmentHandler, + MilvusConnectionParameters, + MilvusSearchParameters, + MilvusCollectionLoadParameters, + VectorSearchParameters, + KeywordSearchParameters, + HybridSearchParameters, + VectorSearchMetrics, + KeywordSearchMetrics) +except ImportError as e: + raise unittest.SkipTest(f'Milvus dependencies not installed: {str(e)}') + +_LOGGER = logging.getLogger(__name__) + + +def _construct_index_params(): + index_params = IndexParams() + + # Milvus doesn't support multiple indexes on the same field. This is a + # limitation of Milvus - someone can only create one index per field as yet. 
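+  # This is also why the test schema below declares a separate dense embedding
+  # field per metric (cosine, Euclidean distance, inner product) instead of
+  # indexing a single field with three different metric types.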
+ + # Cosine similarity index on first dense embedding field + index_params.add_index( + field_name="dense_embedding_cosine", + index_name="dense_embedding_cosine_ivf_flat", + index_type="IVF_FLAT", + metric_type=VectorSearchMetrics.COSINE.value, + params={"nlist": 1}) + + # Euclidean distance index on second dense embedding field + index_params.add_index( + field_name="dense_embedding_euclidean", + index_name="dense_embedding_euclidean_ivf_flat", + index_type="IVF_FLAT", + metric_type=VectorSearchMetrics.EUCLIDEAN_DISTANCE.value, + params={"nlist": 1}) + + # Inner product index on third dense embedding field + index_params.add_index( + field_name="dense_embedding_inner_product", + index_name="dense_embedding_inner_product_ivf_flat", + index_type="IVF_FLAT", + metric_type=VectorSearchMetrics.INNER_PRODUCT.value, + params={"nlist": 1}) + + index_params.add_index( + field_name="sparse_embedding_inner_product", + index_name="sparse_embedding_inner_product_inverted_index", + index_type="SPARSE_INVERTED_INDEX", + metric_type=VectorSearchMetrics.INNER_PRODUCT.value, + params={ + "inverted_index_algo": "TAAT_NAIVE", + }) + + # BM25 index on sparse_embedding field. + # + # For deterministic testing results + # 1. Using TAAT_NAIVE: Most predictable algorithm that processes each term + # completely before moving to the next. + # 2. Using k1=1: Moderate term frequency weighting – repeated terms matter + # but with diminishing returns. + # 3. Using b=0: No document length normalization – longer documents not + # penalized. + # This combination provides maximum transparency and predictability for + # test assertions. + index_params.add_index( + field_name="sparse_embedding_bm25", + index_name="sparse_embedding_bm25_inverted_index", + index_type="SPARSE_INVERTED_INDEX", + metric_type=KeywordSearchMetrics.BM25.value, + params={ + "inverted_index_algo": "TAAT_NAIVE", + "bm25_k1": 1, + "bm25_b": 0, + }) + + return index_params + + +@dataclass +class MilvusITDataConstruct: + id: int + content: str + domain: str + cost: int + metadata: dict + tags: list[str] + dense_embedding: list[float] + sparse_embedding: dict + vocabulary: Dict[str, int] = field(default_factory=dict) + + def __getitem__(self, key): + return getattr(self, key) + + +MILVUS_IT_CONFIG = { + "collection_name": "docs_catalog", + "fields": [ + FieldSchema( + name="id", dtype=DataType.INT64, is_primary=True, auto_id=False), + FieldSchema( + name="content", + dtype=DataType.VARCHAR, + max_length=512, + enable_analyzer=True), + FieldSchema(name="domain", dtype=DataType.VARCHAR, max_length=128), + FieldSchema(name="cost", dtype=DataType.INT32), + FieldSchema(name="metadata", dtype=DataType.JSON), + FieldSchema( + name="tags", + dtype=DataType.ARRAY, + element_type=DataType.VARCHAR, + max_length=64, + max_capacity=64), + FieldSchema( + name="dense_embedding_cosine", dtype=DataType.FLOAT_VECTOR, dim=3), + FieldSchema( + name="dense_embedding_euclidean", + dtype=DataType.FLOAT_VECTOR, + dim=3), + FieldSchema( + name="dense_embedding_inner_product", + dtype=DataType.FLOAT_VECTOR, + dim=3), + FieldSchema( + name="sparse_embedding_bm25", dtype=DataType.SPARSE_FLOAT_VECTOR), + FieldSchema( + name="sparse_embedding_inner_product", + dtype=DataType.SPARSE_FLOAT_VECTOR) + ], + "functions": [ + Function( + name="content_bm25_emb", + input_field_names=["content"], + output_field_names=["sparse_embedding_bm25"], + function_type=FunctionType.BM25) + ], + "index": _construct_index_params, + "corpus": [ + MilvusITDataConstruct( + id=1, + content="This is a test 
document", + domain="medical", + cost=49, + metadata={"language": "en"}, + tags=["healthcare", "patient", "clinical"], + dense_embedding=[0.1, 0.2, 0.3], + sparse_embedding={ + 1: 0.05, 2: 0.41, 3: 0.05, 4: 0.41 + }), + MilvusITDataConstruct( + id=2, + content="Another test document", + domain="legal", + cost=75, + metadata={"language": "en"}, + tags=["contract", "law", "regulation"], + dense_embedding=[0.2, 0.3, 0.4], + sparse_embedding={ + 1: 0.07, 3: 3.07, 0: 0.53 + }), + MilvusITDataConstruct( + id=3, + content="وثيقة اختبار", + domain="financial", + cost=149, + metadata={"language": "ar"}, + tags=["banking", "investment", "arabic"], + dense_embedding=[0.3, 0.4, 0.5], + sparse_embedding={ + 6: 0.62, 5: 0.62 + }) + ], + "vocabulary": { + "this": 4, + "is": 2, + "test": 3, + "document": 1, + "another": 0, + "وثيقة": 6, + "اختبار": 5 + } +} + + +@dataclass +class MilvusDBContainerInfo: + container: DbContainer + host: str + port: int + user: Optional[str] = "" + password: Optional[str] = "" + token: Optional[str] = "" + id: Optional[str] = "default" + + @property + def uri(self) -> str: + return f"http://{self.host}:{self.port}" + + +class CustomMilvusContainer(MilvusContainer): + def __init__( + self, + image: str, + service_container_port, + healthcheck_container_port, + **kwargs, + ) -> None: + # Skip the parent class's constructor and go straight to + # GenericContainer. + super(MilvusContainer, self).__init__(image=image, **kwargs) + self.port = service_container_port + self.healthcheck_port = healthcheck_container_port + self.with_exposed_ports(service_container_port, healthcheck_container_port) + + # Get free host ports. + service_host_port = MilvusEnrichmentTestHelper.find_free_port() + healthcheck_host_port = MilvusEnrichmentTestHelper.find_free_port() + + # Bind container and host ports. + self.with_bind_ports(service_container_port, service_host_port) + self.with_bind_ports(healthcheck_container_port, healthcheck_host_port) + self.cmd = "milvus run standalone" + + # Set environment variables needed for Milvus. 
+ envs = { + "ETCD_USE_EMBED": "true", + "ETCD_DATA_DIR": "/var/lib/milvus/etcd", + "COMMON_STORAGETYPE": "local", + "METRICS_PORT": str(healthcheck_container_port) + } + for env, value in envs.items(): + self.with_env(env, value) + + +class MilvusEnrichmentTestHelper: + @staticmethod + def start_db_container( + image="milvusdb/milvus:v2.5.10", + max_vec_fields=5, + vector_client_max_retries=3, + tc_max_retries=TC_MAX_TRIES) -> Optional[MilvusDBContainerInfo]: + service_container_port = MilvusEnrichmentTestHelper.find_free_port() + healthcheck_container_port = MilvusEnrichmentTestHelper.find_free_port() + user_yaml_creator = MilvusEnrichmentTestHelper.create_user_yaml + with user_yaml_creator(service_container_port, max_vec_fields) as cfg: + info = None + testcontainers_config.max_tries = tc_max_retries + for i in range(vector_client_max_retries): + try: + vector_db_container = CustomMilvusContainer( + image=image, + service_container_port=service_container_port, + healthcheck_container_port=healthcheck_container_port) + vector_db_container = vector_db_container.with_volume_mapping( + cfg, "/milvus/configs/user.yaml") + vector_db_container.start() + host = vector_db_container.get_container_host_ip() + port = vector_db_container.get_exposed_port(service_container_port) + info = MilvusDBContainerInfo(vector_db_container, host, port) + testcontainers_config.max_tries = TC_MAX_TRIES + _LOGGER.info( + "milvus db container started successfully on %s.", info.uri) + break + except Exception as e: + stdout_logs, stderr_logs = vector_db_container.get_logs() + stdout_logs = stdout_logs.decode("utf-8") + stderr_logs = stderr_logs.decode("utf-8") + _LOGGER.warning( + "Retry %d/%d: Failed to start Milvus DB container. Reason: %s. " + "STDOUT logs:\n%s\nSTDERR logs:\n%s", + i + 1, + vector_client_max_retries, + e, + stdout_logs, + stderr_logs) + if i == vector_client_max_retries - 1: + _LOGGER.error( + "Unable to start milvus db container for I/O tests after %d " + "retries. Tests cannot proceed. STDOUT logs:\n%s\n" + "STDERR logs:\n%s", + vector_client_max_retries, + stdout_logs, + stderr_logs) + raise e + return info + + @staticmethod + def stop_db_container(db_info: MilvusDBContainerInfo): + if db_info is None: + _LOGGER.warning("Milvus db info is None. Skipping stop operation.") + return + try: + _LOGGER.debug("Stopping milvus db container.") + db_info.container.stop() + _LOGGER.info("milvus db container stopped successfully.") + except Exception as e: + _LOGGER.warning( + "Error encountered while stopping milvus db container: %s", e) + + @staticmethod + def initialize_db_with_data(connc_params: MilvusConnectionParameters): + # Open the connection to the milvus db. + client = MilvusClient(**connc_params.__dict__) + + # Configure schema. + field_schemas: List[FieldSchema] = cast( + List[FieldSchema], MILVUS_IT_CONFIG["fields"]) + schema = CollectionSchema( + fields=field_schemas, functions=MILVUS_IT_CONFIG["functions"]) + + # Create collection with the schema. + collection_name = MILVUS_IT_CONFIG["collection_name"] + index_function: Callable[[], IndexParams] = cast( + Callable[[], IndexParams], MILVUS_IT_CONFIG["index"]) + client.create_collection( + collection_name=collection_name, + schema=schema, + index_params=index_function()) + + # Assert that collection was created. + collection_error = f"Expected collection '{collection_name}' to be created." + assert client.has_collection(collection_name), collection_error + + # Gather all fields we have excluding 'sparse_embedding_bm25' special field. 
+ fields = list(map(lambda field: field.name, field_schemas)) + + # Prep data for indexing. Currently we can't insert sparse vectors for BM25 + # sparse embedding field as it would be automatically generated by Milvus + # through the registered BM25 function. + data_ready_to_index = [] + for doc in MILVUS_IT_CONFIG["corpus"]: + item = {} + for field in fields: + if field.startswith("dense_embedding"): + item[field] = doc["dense_embedding"] + elif field == "sparse_embedding_inner_product": + item[field] = doc["sparse_embedding"] + elif field == "sparse_embedding_bm25": + # It is automatically generated by Milvus from the content field. + continue + else: + item[field] = doc[field] + data_ready_to_index.append(item) + + # Index data. + result = client.insert( + collection_name=collection_name, data=data_ready_to_index) + + # Assert that the intended data has been properly indexed. + insertion_err = f'failed to insert the {result["insert_count"]} data points' + assert result["insert_count"] == len(data_ready_to_index), insertion_err + + # Release the collection from memory. It will be loaded lazily when the + # enrichment handler is invoked. + client.release_collection(collection_name) + + # Close the connection to the Milvus database, as no further preparation + # operations are needed before executing the enrichment handler. + client.close() + + return collection_name + + @staticmethod + def find_free_port(): + """Find a free port on the local machine.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + # Bind to port 0, which asks OS to assign a free port. + s.bind(('', 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + # Return the port number assigned by OS. + return s.getsockname()[1] + + @staticmethod + @contextlib.contextmanager + def create_user_yaml(service_port: int, max_vector_field_num=5): + """Creates a temporary user.yaml file for Milvus configuration. + + This user yaml file overrides Milvus default configurations. It sets + the Milvus service port to the specified container service port. The + default for maxVectorFieldNum is 4, but we need 5 + (one unique field for each metric). + + Args: + service_port: Port number for the Milvus service. + max_vector_field_num: Max number of vec fields allowed per collection. + + Yields: + str: Path to the created temporary yaml file. + """ + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', + delete=False) as temp_file: + # Define the content for user.yaml. + user_config = { + 'proxy': { + 'maxVectorFieldNum': max_vector_field_num, 'port': service_port + } + } + + # Write the content to the file. + yaml.dump(user_config, temp_file, default_flow_style=False) + path = temp_file.name + + try: + yield path + finally: + if os.path.exists(path): + os.remove(path) + + +@pytest.mark.uses_testcontainer +@unittest.skipUnless( + platform.system() == "Linux", + "Test runs only on Linux due to lack of support, as yet, for nested " + "virtualization in CI environments on Windows/macOS. 
Many CI providers run " + "tests in virtualized environments, and nested virtualization " + "(Docker inside a VM) is either unavailable or has several issues on " + "non-Linux platforms.") +class TestMilvusSearchEnrichment(unittest.TestCase): + """Tests for search functionality across all search strategies""" + + _db: MilvusDBContainerInfo + _version = "milvusdb/milvus:v2.5.10" + + @classmethod + def setUpClass(cls): + try: + cls._db = MilvusEnrichmentTestHelper.start_db_container( + cls._version, vector_client_max_retries=1, tc_max_retries=1) + cls._connection_params = MilvusConnectionParameters( + uri=cls._db.uri, + user=cls._db.user, + password=cls._db.password, + db_id=cls._db.id, + token=cls._db.token) + cls._collection_load_params = MilvusCollectionLoadParameters() + cls._collection_name = MilvusEnrichmentTestHelper.initialize_db_with_data( + cls._connection_params) + except Exception as e: + pytest.skip( + f"Skipping all tests in {cls.__name__} due to DB startup failure: {e}" + ) + + @classmethod + def tearDownClass(cls): + MilvusEnrichmentTestHelper.stop_db_container(cls._db) + cls._db = None + + def test_invalid_query_on_non_existent_collection(self): + non_existent_collection = "nonexistent_collection" + existent_field = "dense_embedding_cosine" + + test_chunks = [ + Chunk( + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content()) + ] + + search_parameters = MilvusSearchParameters( + collection_name=non_existent_collection, + search_strategy=VectorSearchParameters(anns_field=existent_field)) + + collection_load_parameters = MilvusCollectionLoadParameters() + + handler = MilvusSearchEnrichmentHandler( + self._connection_params, + search_parameters, + collection_load_parameters=collection_load_parameters) + + with self.assertRaises(Exception) as context: + with TestPipeline() as p: + _ = (p | beam.Create(test_chunks) | Enrichment(handler)) + + expect_err_msg_contains = "collection not found" + self.assertIn(expect_err_msg_contains, str(context.exception)) + + def test_invalid_query_on_non_existent_field(self): + non_existent_field = "nonexistent_column" + existent_collection = MILVUS_IT_CONFIG["collection_name"] + + test_chunks = [ + Chunk( + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content()) + ] + + search_parameters = MilvusSearchParameters( + collection_name=existent_collection, + search_strategy=VectorSearchParameters(anns_field=non_existent_field)) + + collection_load_parameters = MilvusCollectionLoadParameters() + + handler = MilvusSearchEnrichmentHandler( + self._connection_params, + search_parameters, + collection_load_parameters=collection_load_parameters) + + with self.assertRaises(Exception) as context: + with TestPipeline() as p: + _ = (p | beam.Create(test_chunks) | Enrichment(handler)) + + expect_err_msg_contains = f"fieldName({non_existent_field}) not found" + self.assertIn(expect_err_msg_contains, str(context.exception)) + + def test_empty_input_chunks(self): + test_chunks = [] + anns_field = "dense_embedding_cosine" + + search_parameters = MilvusSearchParameters( + collection_name=MILVUS_IT_CONFIG["collection_name"], + search_strategy=VectorSearchParameters(anns_field=anns_field)) + + collection_load_parameters = MilvusCollectionLoadParameters() + + handler = MilvusSearchEnrichmentHandler( + self._connection_params, + search_parameters, + collection_load_parameters=collection_load_parameters) + + expected_chunks = [] + + with TestPipeline(is_integration_test=True) as p: + result = (p | beam.Create(test_chunks) | 
Enrichment(handler)) + assert_that( + result, + lambda actual: assert_chunks_equivalent(actual, expected_chunks)) + + def test_filtered_search_with_cosine_similarity_and_batching(self): + test_chunks = [ + Chunk( + id="query1", + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content()), + Chunk( + id="query2", + embedding=Embedding(dense_embedding=[0.2, 0.3, 0.4]), + content=Content()), + Chunk( + id="query3", + embedding=Embedding(dense_embedding=[0.3, 0.4, 0.5]), + content=Content()) + ] + + filter_condition = 'metadata["language"] == "en"' + + anns_field = "dense_embedding_cosine" + + addition_search_params = { + "metric_type": VectorSearchMetrics.COSINE.value, "nprobe": 1 + } + + vector_search_parameters = VectorSearchParameters( + anns_field=anns_field, + limit=10, + filter=filter_condition, + search_params=addition_search_params) + + search_parameters = MilvusSearchParameters( + collection_name=MILVUS_IT_CONFIG["collection_name"], + search_strategy=vector_search_parameters, + output_fields=["id", "content", "metadata"], + round_decimal=1) + + collection_load_parameters = MilvusCollectionLoadParameters() + + # Force batching. + min_batch_size, max_batch_size = 2, 2 + handler = MilvusSearchEnrichmentHandler( + connection_parameters=self._connection_params, + search_parameters=search_parameters, + collection_load_parameters=collection_load_parameters, + min_batch_size=min_batch_size, + max_batch_size=max_batch_size) + + expected_chunks = [ + Chunk( + id='query1', + content=Content(), + metadata={ + 'enrichment_data': { + 'id': [1, 2], + 'distance': [1.0, 1.0], + 'fields': [{ + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }, + { + 'content': 'Another test document', + 'metadata': { + 'language': 'en' + }, + 'id': 2 + }] + } + }, + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3])), + Chunk( + id='query2', + content=Content(), + metadata={ + 'enrichment_data': { + 'id': [2, 1], + 'distance': [1.0, 1.0], + 'fields': [{ + 'content': 'Another test document', + 'metadata': { + 'language': 'en' + }, + 'id': 2 + }, + { + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }] + } + }, + embedding=Embedding(dense_embedding=[0.2, 0.3, 0.4])), + Chunk( + id='query3', + content=Content(), + metadata={ + 'enrichment_data': { + 'id': [2, 1], + 'distance': [1.0, 1.0], + 'fields': [{ + 'content': 'Another test document', + 'metadata': { + 'language': 'en' + }, + 'id': 2 + }, + { + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }] + } + }, + embedding=Embedding(dense_embedding=[0.3, 0.4, 0.5])) + ] + + with TestPipeline(is_integration_test=True) as p: + result = (p | beam.Create(test_chunks) | Enrichment(handler)) + assert_that( + result, + lambda actual: assert_chunks_equivalent(actual, expected_chunks)) + + def test_filtered_search_with_bm25_full_text_and_batching(self): + test_chunks = [ + Chunk( + id="query1", + embedding=Embedding(sparse_embedding=None), + content=Content(text="This is a test document")), + Chunk( + id="query2", + embedding=Embedding(sparse_embedding=None), + content=Content(text="Another test document")), + Chunk( + id="query3", + embedding=Embedding(sparse_embedding=None), + content=Content(text="وثيقة اختبار")) + ] + + filter_condition = 'ARRAY_CONTAINS_ANY(tags, ["healthcare", "banking"])' + + anns_field = "sparse_embedding_bm25" + + addition_search_params = {"metric_type": KeywordSearchMetrics.BM25.value} + + 
keyword_search_parameters = KeywordSearchParameters( + anns_field=anns_field, + limit=10, + filter=filter_condition, + search_params=addition_search_params) + + search_parameters = MilvusSearchParameters( + collection_name=MILVUS_IT_CONFIG["collection_name"], + search_strategy=keyword_search_parameters, + output_fields=["id", "content", "metadata"], + round_decimal=1) + + collection_load_parameters = MilvusCollectionLoadParameters() + + # Force batching. + min_batch_size, max_batch_size = 2, 2 + handler = MilvusSearchEnrichmentHandler( + connection_parameters=self._connection_params, + search_parameters=search_parameters, + collection_load_parameters=collection_load_parameters, + min_batch_size=min_batch_size, + max_batch_size=max_batch_size) + + expected_chunks = [ + Chunk( + id='query1', + content=Content(text='This is a test document'), + metadata={ + 'enrichment_data': { + 'id': [1], + 'distance': [3.3], + 'fields': [{ + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }] + } + }, + embedding=Embedding()), + Chunk( + id='query2', + content=Content(text='Another test document'), + metadata={ + 'enrichment_data': { + 'id': [1], + 'distance': [0.8], + 'fields': [{ + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }] + } + }, + embedding=Embedding()), + Chunk( + id='query3', + content=Content(text='وثيقة اختبار'), + metadata={ + 'enrichment_data': { + 'id': [3], + 'distance': [2.3], + 'fields': [{ + 'content': 'وثيقة اختبار', + 'metadata': { + 'language': 'ar' + }, + 'id': 3 + }] + } + }, + embedding=Embedding()) + ] + + with TestPipeline(is_integration_test=True) as p: + result = (p | beam.Create(test_chunks) | Enrichment(handler)) + assert_that( + result, + lambda actual: assert_chunks_equivalent(actual, expected_chunks)) + + def test_vector_search_with_euclidean_distance(self): + test_chunks = [ + Chunk( + id="query1", + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content()), + Chunk( + id="query2", + embedding=Embedding(dense_embedding=[0.2, 0.3, 0.4]), + content=Content()), + Chunk( + id="query3", + embedding=Embedding(dense_embedding=[0.3, 0.4, 0.5]), + content=Content()) + ] + + anns_field = "dense_embedding_euclidean" + + addition_search_params = { + "metric_type": VectorSearchMetrics.EUCLIDEAN_DISTANCE.value, + "nprobe": 1 + } + + vector_search_parameters = VectorSearchParameters( + anns_field=anns_field, limit=10, search_params=addition_search_params) + + search_parameters = MilvusSearchParameters( + collection_name=MILVUS_IT_CONFIG["collection_name"], + search_strategy=vector_search_parameters, + output_fields=["id", "content", "metadata"], + round_decimal=1) + + collection_load_parameters = MilvusCollectionLoadParameters() + + handler = MilvusSearchEnrichmentHandler( + connection_parameters=self._connection_params, + search_parameters=search_parameters, + collection_load_parameters=collection_load_parameters) + + expected_chunks = [ + Chunk( + id='query1', + content=Content(), + metadata={ + 'enrichment_data': { + 'id': [1, 2, 3], + 'distance': [0.0, 0.0, 0.1], + 'fields': [{ + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }, + { + 'content': 'Another test document', + 'metadata': { + 'language': 'en' + }, + 'id': 2 + }, + { + 'content': 'وثيقة اختبار', + 'metadata': { + 'language': 'ar' + }, + 'id': 3 + }] + } + }, + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3])), + Chunk( + id='query2', + content=Content(), + metadata={ + 
'enrichment_data': { + 'id': [2, 3, 1], + 'distance': [0.0, 0.0, 0.0], + 'fields': [{ + 'content': 'Another test document', + 'metadata': { + 'language': 'en' + }, + 'id': 2 + }, + { + 'content': 'وثيقة اختبار', + 'metadata': { + 'language': 'ar' + }, + 'id': 3 + }, + { + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }] + } + }, + embedding=Embedding(dense_embedding=[0.2, 0.3, 0.4])), + Chunk( + id='query3', + content=Content(), + metadata={ + 'enrichment_data': { + 'id': [3, 2, 1], + 'distance': [0.0, 0.0, 0.1], + 'fields': [{ + 'content': 'وثيقة اختبار', + 'metadata': { + 'language': 'ar' + }, + 'id': 3 + }, + { + 'content': 'Another test document', + 'metadata': { + 'language': 'en' + }, + 'id': 2 + }, + { + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }] + } + }, + embedding=Embedding(dense_embedding=[0.3, 0.4, 0.5])) + ] + + with TestPipeline(is_integration_test=True) as p: + result = (p | beam.Create(test_chunks) | Enrichment(handler)) + assert_that( + result, + lambda actual: assert_chunks_equivalent(actual, expected_chunks)) + + def test_vector_search_with_inner_product_similarity(self): + test_chunks = [ + Chunk( + id="query1", + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content()), + Chunk( + id="query2", + embedding=Embedding(dense_embedding=[0.2, 0.3, 0.4]), + content=Content()), + Chunk( + id="query3", + embedding=Embedding(dense_embedding=[0.3, 0.4, 0.5]), + content=Content()) + ] + + anns_field = "dense_embedding_inner_product" + + addition_search_params = { + "metric_type": VectorSearchMetrics.INNER_PRODUCT.value, "nprobe": 1 + } + + vector_search_parameters = VectorSearchParameters( + anns_field=anns_field, limit=10, search_params=addition_search_params) + + search_parameters = MilvusSearchParameters( + collection_name=MILVUS_IT_CONFIG["collection_name"], + search_strategy=vector_search_parameters, + output_fields=["id", "content", "metadata"], + round_decimal=1) + + collection_load_parameters = MilvusCollectionLoadParameters() + + handler = MilvusSearchEnrichmentHandler( + connection_parameters=self._connection_params, + search_parameters=search_parameters, + collection_load_parameters=collection_load_parameters) + + expected_chunks = [ + Chunk( + id='query1', + content=Content(), + metadata={ + 'enrichment_data': { + 'id': [3, 2, 1], + 'distance': [0.3, 0.2, 0.1], + 'fields': [{ + 'content': 'وثيقة اختبار', + 'metadata': { + 'language': 'ar' + }, + 'id': 3 + }, + { + 'content': 'Another test document', + 'metadata': { + 'language': 'en' + }, + 'id': 2 + }, + { + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }] + } + }, + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3])), + Chunk( + id='query2', + content=Content(), + metadata={ + 'enrichment_data': { + 'id': [3, 2, 1], + 'distance': [0.4, 0.3, 0.2], + 'fields': [{ + 'content': 'وثيقة اختبار', + 'metadata': { + 'language': 'ar' + }, + 'id': 3 + }, + { + 'content': 'Another test document', + 'metadata': { + 'language': 'en' + }, + 'id': 2 + }, + { + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }] + } + }, + embedding=Embedding(dense_embedding=[0.2, 0.3, 0.4])), + Chunk( + id='query3', + content=Content(), + metadata={ + 'enrichment_data': { + 'id': [3, 2, 1], + 'distance': [0.5, 0.4, 0.3], + 'fields': [{ + 'content': 'وثيقة اختبار', + 'metadata': { + 'language': 'ar' + }, + 'id': 3 + }, + { + 'content': 'Another test 
document', + 'metadata': { + 'language': 'en' + }, + 'id': 2 + }, + { + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }] + } + }, + embedding=Embedding(dense_embedding=[0.3, 0.4, 0.5])) + ] + + with TestPipeline(is_integration_test=True) as p: + result = (p | beam.Create(test_chunks) | Enrichment(handler)) + assert_that( + result, + lambda actual: assert_chunks_equivalent(actual, expected_chunks)) + + def test_keyword_search_with_inner_product_sparse_embedding(self): + test_chunks = [ + Chunk( + id="query1", + embedding=Embedding( + sparse_embedding=([1, 2, 3, 4], [0.05, 0.41, 0.05, 0.41])), + content=Content()) + ] + + anns_field = "sparse_embedding_inner_product" + + addition_search_params = { + "metric_type": VectorSearchMetrics.INNER_PRODUCT.value, + } + + keyword_search_parameters = KeywordSearchParameters( + anns_field=anns_field, limit=3, search_params=addition_search_params) + + search_parameters = MilvusSearchParameters( + collection_name=MILVUS_IT_CONFIG["collection_name"], + search_strategy=keyword_search_parameters, + output_fields=["id", "content", "metadata"], + round_decimal=1) + + collection_load_parameters = MilvusCollectionLoadParameters() + + handler = MilvusSearchEnrichmentHandler( + connection_parameters=self._connection_params, + search_parameters=search_parameters, + collection_load_parameters=collection_load_parameters) + + expected_chunks = [ + Chunk( + id='query1', + content=Content(), + metadata={ + 'enrichment_data': { + 'id': [1, 2], + 'distance': [0.3, 0.2], + 'fields': [{ + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }, + { + 'content': 'Another test document', + 'metadata': { + 'language': 'en' + }, + 'id': 2 + }] + } + }, + embedding=Embedding( + sparse_embedding=([1, 2, 3, 4], [0.05, 0.41, 0.05, 0.41]))) + ] + + with TestPipeline(is_integration_test=True) as p: + result = (p | beam.Create(test_chunks) | Enrichment(handler)) + assert_that( + result, + lambda actual: assert_chunks_equivalent(actual, expected_chunks)) + + def test_hybrid_search(self): + test_chunks = [ + Chunk( + id="query1", + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content(text="This is a test document")) + ] + + anns_vector_field = "dense_embedding_cosine" + addition_vector_search_params = { + "metric_type": VectorSearchMetrics.COSINE.value, "nprobe": 1 + } + + vector_search_parameters = VectorSearchParameters( + anns_field=anns_vector_field, + limit=10, + search_params=addition_vector_search_params) + + anns_keyword_field = "sparse_embedding_bm25" + addition_keyword_search_params = { + "metric_type": KeywordSearchMetrics.BM25.value + } + + keyword_search_parameters = KeywordSearchParameters( + anns_field=anns_keyword_field, + limit=10, + search_params=addition_keyword_search_params) + + hybrid_search_parameters = HybridSearchParameters( + vector=vector_search_parameters, + keyword=keyword_search_parameters, + ranker=RRFRanker(1), + limit=1) + + search_parameters = MilvusSearchParameters( + collection_name=MILVUS_IT_CONFIG["collection_name"], + search_strategy=hybrid_search_parameters, + output_fields=["id", "content", "metadata"], + round_decimal=1) + + collection_load_parameters = MilvusCollectionLoadParameters() + + handler = MilvusSearchEnrichmentHandler( + connection_parameters=self._connection_params, + search_parameters=search_parameters, + collection_load_parameters=collection_load_parameters) + + expected_chunks = [ + Chunk( + content=Content(text='This is a test 
document'),
+            id='query1',
+            metadata={
+                'enrichment_data': {
+                    'id': [1],
+                    'distance': [1.0],
+                    'fields': [{
+                        'content': 'This is a test document',
+                        'metadata': {
+                            'language': 'en'
+                        },
+                        'id': 1
+                    }]
+                }
+            },
+            embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]))
+    ]
+
+    with TestPipeline(is_integration_test=True) as p:
+      result = (p | beam.Create(test_chunks) | Enrichment(handler))
+      assert_that(
+          result,
+          lambda actual: assert_chunks_equivalent(actual, expected_chunks))
+
+
+def parse_chunk_strings(chunk_str_list: List[str]) -> List[Chunk]:
+  parsed_chunks = []
+
+  # Define safe globals and disable built-in functions for safety.
+  safe_globals = {
+      'Chunk': Chunk,
+      'Content': Content,
+      'Embedding': Embedding,
+      'defaultdict': defaultdict,
+      'list': list,
+      '__builtins__': {}
+  }
+
+  for raw_str in chunk_str_list:
+    try:
+      # Replace the "<class 'list'>" repr with an actual list reference so the
+      # string can be evaluated.
+      cleaned_str = re.sub(
+          r"defaultdict\(<class 'list'>", "defaultdict(list", raw_str)
+
+      # Evaluate string in restricted environment.
+      chunk = eval(cleaned_str, safe_globals)  # pylint: disable=eval-used
+      if isinstance(chunk, Chunk):
+        parsed_chunks.append(chunk)
+      else:
+        raise ValueError("Parsed object is not a Chunk instance")
+    except Exception as e:
+      raise ValueError(f"Error parsing string:\n{raw_str}\n{e}")
+
+  return parsed_chunks
+
+
+def assert_chunks_equivalent(
+    actual_chunks: List[Chunk], expected_chunks: List[Chunk]):
+  """assert_chunks_equivalent checks for presence rather than exact match."""
+  # Sort both lists by ID to ensure consistent ordering.
+  actual_sorted = sorted(actual_chunks, key=lambda c: c.id)
+  expected_sorted = sorted(expected_chunks, key=lambda c: c.id)
+
+  actual_len = len(actual_sorted)
+  expected_len = len(expected_sorted)
+  err_msg = (
+      f"Different number of chunks, actual: {actual_len}, "
+      f"expected: {expected_len}")
+  assert actual_len == expected_len, err_msg
+
+  for actual, expected in zip(actual_sorted, expected_sorted):
+    # Assert that IDs match.
+    assert actual.id == expected.id
+
+    # Assert that dense embeddings match.
+    err_msg = f"Dense embedding mismatch for chunk {actual.id}"
+    assert actual.dense_embedding == expected.dense_embedding, err_msg
+
+    # Assert that sparse embeddings match.
+    err_msg = f"Sparse embedding mismatch for chunk {actual.id}"
+    assert actual.sparse_embedding == expected.sparse_embedding, err_msg
+
+    # Assert that text content matches.
+    err_msg = f"Text content mismatch for chunk {actual.id}"
+    assert actual.content.text == expected.content.text, err_msg
+
+    # For enrichment_data, be more flexible. If "expected" has values for
+    # enrichment_data but actual doesn't, that's acceptable since vector
+    # search results can vary based on many factors including implementation
+    # details, vector database state, and slight variations in similarity
+    # calculations.
+
+    # First ensure the enrichment data key exists.
+    err_msg = f"Missing enrichment_data key in chunk {actual.id}"
+    assert 'enrichment_data' in actual.metadata, err_msg
+
+    # For enrichment_data, ensure consistent ordering of results.
+    actual_data = actual.metadata['enrichment_data']
+    expected_data = expected.metadata['enrichment_data']
+
+    # If actual has enrichment data, then perform detailed validation.
+    if actual_data:
+      # Ensure the id key exists.
+      err_msg = f"Missing id key in metadata {actual.id}"
+      assert 'id' in actual_data, err_msg
+
+      # Validate IDs have consistent ordering.
+ actual_ids = sorted(actual_data['id']) + expected_ids = sorted(expected_data['id']) + err_msg = f"IDs in enrichment_data don't match for chunk {actual.id}" + assert actual_ids == expected_ids, err_msg + + # Ensure the distance key exist. + err_msg = f"Missing distance key in metadata {actual.id}" + assert 'distance' in actual_data, err_msg + + # Validate distances exist and have same length as IDs. + actual_distances = actual_data['distance'] + expected_distances = expected_data['distance'] + err_msg = ( + "Number of distances doesn't match number of IDs for " + f"chunk {actual.id}") + assert len(actual_distances) == len(expected_distances), err_msg + + # Ensure the fields key exist. + err_msg = f"Missing fields key in metadata {actual.id}" + assert 'fields' in actual_data, err_msg + + # Validate fields have consistent content. + # Sort fields by 'id' to ensure consistent ordering. + actual_fields_sorted = sorted( + actual_data['fields'], key=lambda f: f.get('id', 0)) + expected_fields_sorted = sorted( + expected_data['fields'], key=lambda f: f.get('id', 0)) + + # Compare field IDs. + actual_field_ids = [f.get('id') for f in actual_fields_sorted] + expected_field_ids = [f.get('id') for f in expected_fields_sorted] + err_msg = f"Field IDs don't match for chunk {actual.id}" + assert actual_field_ids == expected_field_ids, err_msg + + # Compare field content. + for a_f, e_f in zip(actual_fields_sorted, expected_fields_sorted): + # Ensure the id key exist. + err_msg = f"Missing id key in metadata.fields {actual.id}" + assert 'id' in a_f + + err_msg = f"Field ID mismatch chunk {actual.id}" + assert a_f['id'] == e_f['id'], err_msg + + # Validate field metadata. + err_msg = f"Field Metadata doesn't match for chunk {actual.id}" + assert a_f['metadata'] == e_f['metadata'], err_msg + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_test.py b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_test.py new file mode 100644 index 000000000000..e69915cb3e9b --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_test.py @@ -0,0 +1,343 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import unittest + +from parameterized import parameterized + +try: + from apache_beam.ml.rag.types import Chunk + from apache_beam.ml.rag.types import Embedding + from apache_beam.ml.rag.types import Content + from apache_beam.ml.rag.enrichment.milvus_search import ( + MilvusSearchEnrichmentHandler, + MilvusConnectionParameters, + MilvusSearchParameters, + MilvusCollectionLoadParameters, + VectorSearchParameters, + KeywordSearchParameters, + HybridSearchParameters, + MilvusBaseRanker, + unpack_dataclass_with_kwargs) +except ImportError as e: + raise unittest.SkipTest(f'Milvus dependencies not installed: {str(e)}') + + +class MockRanker(MilvusBaseRanker): + def dict(self): + return {"name": "mock_ranker"} + + +class TestMilvusSearchEnrichment(unittest.TestCase): + """Unit tests for general search functionality in the Enrichment Handler.""" + def test_invalid_connection_parameters(self): + """Test validation errors for invalid connection parameters.""" + # Empty URI in connection parameters. + with self.assertRaises(ValueError) as context: + connection_params = MilvusConnectionParameters(uri="") + search_params = MilvusSearchParameters( + collection_name="test_collection", + search_strategy=VectorSearchParameters(anns_field="embedding")) + collection_load_params = MilvusCollectionLoadParameters() + + _ = MilvusSearchEnrichmentHandler( + connection_parameters=connection_params, + search_parameters=search_params, + collection_load_parameters=collection_load_params) + + self.assertIn( + "URI must be provided for Milvus connection", str(context.exception)) + + @parameterized.expand([ + # Empty collection name. + ( + lambda: MilvusSearchParameters( + collection_name="", + search_strategy=VectorSearchParameters(anns_field="embedding")), + "Collection name must be provided" + ), + # Missing search strategy. + ( + lambda: MilvusSearchParameters( + collection_name="test_collection", + search_strategy=None), # type: ignore[arg-type] + "Search strategy must be provided" + ), + ]) + def test_invalid_search_parameters(self, create_params, expected_error_msg): + """Test validation errors for invalid general search parameters.""" + with self.assertRaises(ValueError) as context: + connection_params = MilvusConnectionParameters( + uri="http://localhost:19530") + search_params = create_params() + collection_load_params = MilvusCollectionLoadParameters() + + _ = MilvusSearchEnrichmentHandler( + connection_parameters=connection_params, + search_parameters=search_params, + collection_load_parameters=collection_load_params) + + self.assertIn(expected_error_msg, str(context.exception)) + + def test_unpack_dataclass_with_kwargs(self): + """Test the unpack_dataclass_with_kwargs function.""" + # Create a test dataclass instance. + connection_params = MilvusConnectionParameters( + uri="http://localhost:19530", + user="test_user", + kwargs={"custom_param": "value"}) + + # Call the actual function. + result = unpack_dataclass_with_kwargs(connection_params) + + # Verify the function correctly unpacks the dataclass and merges kwargs. + self.assertEqual(result["uri"], "http://localhost:19530") + self.assertEqual(result["user"], "test_user") + self.assertEqual(result["custom_param"], "value") + + # Verify that kwargs take precedence over existing attributes. 
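+    # (unpack_dataclass_with_kwargs merges as {**fields, **kwargs}, so a key
+    # supplied in kwargs shadows the dataclass field of the same name.)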
+ connection_params_with_override = MilvusConnectionParameters( + uri="http://localhost:19530", + user="test_user", + kwargs={"user": "override_user"}) + + result_with_override = unpack_dataclass_with_kwargs( + connection_params_with_override) + self.assertEqual(result_with_override["user"], "override_user") + + +class TestMilvusVectorSearchEnrichment(unittest.TestCase): + """Unit tests specific to vector search functionality""" + + @parameterized.expand([ + # Negative limit in vector search parameters. + ( + lambda: VectorSearchParameters(anns_field="embedding", limit=-1), + "Search limit must be positive, got -1" + ), + # Missing anns_field in vector search parameters. + ( + lambda: VectorSearchParameters(anns_field=None), # type: ignore[arg-type] + "Approximate Nearest Neighbor Search (ANNS) field must be provided" + ), + ]) + def test_invalid_search_parameters(self, create_params, expected_error_msg): + """Test validation errors for invalid vector search parameters.""" + with self.assertRaises(ValueError) as context: + connection_params = MilvusConnectionParameters( + uri="http://localhost:19530") + vector_search_params = create_params() + search_params = MilvusSearchParameters( + collection_name="test_collection", + search_strategy=vector_search_params) + collection_load_params = MilvusCollectionLoadParameters() + + _ = MilvusSearchEnrichmentHandler( + connection_parameters=connection_params, + search_parameters=search_params, + collection_load_parameters=collection_load_params) + + self.assertIn(expected_error_msg, str(context.exception)) + + def test_missing_dense_embedding(self): + with self.assertRaises(ValueError) as context: + chunk = Chunk( + id=1, content=None, embedding=Embedding(dense_embedding=None)) + connection_params = MilvusConnectionParameters( + uri="http://localhost:19530") + vector_search_params = VectorSearchParameters(anns_field="embedding") + search_params = MilvusSearchParameters( + collection_name="test_collection", + search_strategy=vector_search_params) + collection_load_params = MilvusCollectionLoadParameters() + handler = MilvusSearchEnrichmentHandler( + connection_parameters=connection_params, + search_parameters=search_params, + collection_load_parameters=collection_load_params) + + _ = handler._get_vector_search_data(chunk) + + err_msg = "Chunk 1 missing dense embedding required for vector search" + self.assertIn(err_msg, str(context.exception)) + + +class TestMilvusKeywordSearchEnrichment(unittest.TestCase): + """Unit tests specific to keyword search functionality""" + + @parameterized.expand([ + # Negative limit in keyword search parameters. + ( + lambda: KeywordSearchParameters( + anns_field="sparse_embedding", limit=-1), + "Search limit must be positive, got -1" + ), + # Missing anns_field in keyword search parameters. 
+      (
+          lambda: KeywordSearchParameters(anns_field=None),  # type: ignore[arg-type]
+          "Approximate Nearest Neighbor Search (ANNS) field must be provided"
+      ),
+  ])
+  def test_invalid_search_parameters(self, create_params, expected_error_msg):
+    """Test validation errors for invalid keyword search parameters."""
+    with self.assertRaises(ValueError) as context:
+      connection_params = MilvusConnectionParameters(
+          uri="http://localhost:19530")
+      keyword_search_params = create_params()
+      search_params = MilvusSearchParameters(
+          collection_name="test_collection",
+          search_strategy=keyword_search_params)
+      collection_load_params = MilvusCollectionLoadParameters()
+
+      _ = MilvusSearchEnrichmentHandler(
+          connection_parameters=connection_params,
+          search_parameters=search_params,
+          collection_load_parameters=collection_load_params)
+
+    self.assertIn(expected_error_msg, str(context.exception))
+
+  def test_missing_text_content_and_sparse_embedding(self):
+    with self.assertRaises(ValueError) as context:
+      chunk = Chunk(
+          id=1,
+          content=Content(text=None),
+          embedding=Embedding(sparse_embedding=None))
+      connection_params = MilvusConnectionParameters(
+          uri="http://localhost:19530")
+      vector_search_params = VectorSearchParameters(anns_field="embedding")
+      search_params = MilvusSearchParameters(
+          collection_name="test_collection",
+          search_strategy=vector_search_params)
+      collection_load_params = MilvusCollectionLoadParameters()
+      handler = MilvusSearchEnrichmentHandler(
+          connection_parameters=connection_params,
+          search_parameters=search_params,
+          collection_load_parameters=collection_load_params)
+
+      _ = handler._get_keyword_search_data(chunk)
+
+    err_msg = (
+        "Chunk 1 missing both text content and sparse embedding "
+        "required for keyword search")
+    self.assertIn(err_msg, str(context.exception))
+
+  def test_missing_text_content_only(self):
+    try:
+      chunk = Chunk(
+          id=1,
+          content=Content(text=None),
+          embedding=Embedding(
+              sparse_embedding=([1, 2, 3, 4], [0.05, 0.41, 0.05, 0.41])))
+      connection_params = MilvusConnectionParameters(
+          uri="http://localhost:19530")
+      vector_search_params = VectorSearchParameters(anns_field="embedding")
+      search_params = MilvusSearchParameters(
+          collection_name="test_collection",
+          search_strategy=vector_search_params)
+      collection_load_params = MilvusCollectionLoadParameters()
+      handler = MilvusSearchEnrichmentHandler(
+          connection_parameters=connection_params,
+          search_parameters=search_params,
+          collection_load_parameters=collection_load_params)
+
+      _ = handler._get_keyword_search_data(chunk)
+    except Exception as e:
+      self.fail(f"raised an unexpected exception: {e}")
+
+  def test_missing_sparse_embedding_only(self):
+    try:
+      chunk = Chunk(
+          id=1,
+          content=Content(text="what is apache beam?"),
+          embedding=Embedding(sparse_embedding=None))
+      connection_params = MilvusConnectionParameters(
+          uri="http://localhost:19530")
+      vector_search_params = VectorSearchParameters(anns_field="embedding")
+      search_params = MilvusSearchParameters(
+          collection_name="test_collection",
+          search_strategy=vector_search_params)
+      collection_load_params = MilvusCollectionLoadParameters()
+      handler = MilvusSearchEnrichmentHandler(
+          connection_parameters=connection_params,
+          search_parameters=search_params,
+          collection_load_parameters=collection_load_params)
+
+      _ = handler._get_keyword_search_data(chunk)
+    except Exception as e:
+      self.fail(f"raised an unexpected exception: {e}")
+
+
+class TestMilvusHybridSearchEnrichment(unittest.TestCase):
+  """Tests specific to hybrid
search functionality""" + + @parameterized.expand([ + # Missing vector in hybrid search parameters. + ( + lambda: HybridSearchParameters( + vector=None, # type: ignore[arg-type] + keyword=KeywordSearchParameters(anns_field="sparse_embedding"), + ranker=MockRanker()), + "Vector and keyword search parameters must be provided for hybrid " + "search" + ), + # Missing keyword in hybrid search parameters. + ( + lambda: HybridSearchParameters( + vector=VectorSearchParameters(anns_field="embedding"), + keyword=None, # type: ignore[arg-type] + ranker=MockRanker()), + "Vector and keyword search parameters must be provided for hybrid " + "search" + ), + # Missing ranker in hybrid search parameters. + ( + lambda: HybridSearchParameters( + vector=VectorSearchParameters(anns_field="embedding"), + keyword=KeywordSearchParameters(anns_field="sparse_embedding"), + ranker=None), # type: ignore[arg-type] + "Ranker must be provided for hybrid search" + ), + # Negative limit in hybrid search parameters. + ( + lambda: HybridSearchParameters( + vector=VectorSearchParameters(anns_field="embedding"), + keyword=KeywordSearchParameters(anns_field="sparse_embedding"), + ranker=MockRanker(), + limit=-1), + "Search limit must be positive, got -1" + ), + ]) + def test_invalid_search_parameters(self, create_params, expected_error_msg): + """Test validation errors for invalid hybrid search parameters.""" + with self.assertRaises(ValueError) as context: + connection_params = MilvusConnectionParameters( + uri="http://localhost:19530") + hybrid_search_params = create_params() + search_params = MilvusSearchParameters( + collection_name="test_collection", + search_strategy=hybrid_search_params) + collection_load_params = MilvusCollectionLoadParameters() + + _ = MilvusSearchEnrichmentHandler( + connection_parameters=connection_params, + search_parameters=search_params, + collection_load_parameters=collection_load_params) + + self.assertIn(expected_error_msg, str(context.exception)) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/ml/rag/types.py b/sdks/python/apache_beam/ml/rag/types.py index 79429899e4c1..3bb0e01b68cc 100644 --- a/sdks/python/apache_beam/ml/rag/types.py +++ b/sdks/python/apache_beam/ml/rag/types.py @@ -44,7 +44,7 @@ class Content: @dataclass class Embedding: """Represents vector embeddings. - + Args: dense_embedding: Dense vector representation sparse_embedding: Optional sparse vector representation for hybrid @@ -58,16 +58,24 @@ class Embedding: @dataclass class Chunk: """Represents a chunk of embeddable content with metadata. 
- + Args: content: The actual content of the chunk id: Unique identifier for the chunk index: Index of this chunk within the original document metadata: Additional metadata about the chunk (e.g., document source) - embedding: Vector embeddings of the content + embedding: Vector embeddings of the content """ content: Content id: str = field(default_factory=lambda: str(uuid.uuid4())) index: int = 0 metadata: Dict[str, Any] = field(default_factory=dict) embedding: Optional[Embedding] = None + + @property + def dense_embedding(self): + return self.embedding.dense_embedding if self.embedding else None + + @property + def sparse_embedding(self): + return self.embedding.sparse_embedding if self.embedding else None diff --git a/sdks/python/apache_beam/transforms/managed_iceberg_it_test.py b/sdks/python/apache_beam/transforms/managed_iceberg_it_test.py index 4d9790f2dd5c..b1e53a79bd41 100644 --- a/sdks/python/apache_beam/transforms/managed_iceberg_it_test.py +++ b/sdks/python/apache_beam/transforms/managed_iceberg_it_test.py @@ -16,8 +16,8 @@ # import os -import time import unittest +import uuid import pytest @@ -53,7 +53,7 @@ def _create_row(self, num: int): def test_write_read_pipeline(self): iceberg_config = { - "table": "test_iceberg_write_read.test_" + str(int(time.time())), + "table": "test_iceberg_write_read.test_" + uuid.uuid4().hex, "catalog_name": "default", "catalog_properties": { "type": "hadoop", diff --git a/sdks/python/apache_beam/transforms/periodicsequence.py b/sdks/python/apache_beam/transforms/periodicsequence.py index daab9d42387b..8916de0fa58a 100644 --- a/sdks/python/apache_beam/transforms/periodicsequence.py +++ b/sdks/python/apache_beam/transforms/periodicsequence.py @@ -33,28 +33,27 @@ from apache_beam.transforms.window import TimestampedValue from apache_beam.utils import timestamp from apache_beam.utils.timestamp import MAX_TIMESTAMP +from apache_beam.utils.timestamp import Duration from apache_beam.utils.timestamp import Timestamp +from apache_beam.utils.timestamp import TimestampTypes class ImpulseSeqGenRestrictionProvider(core.RestrictionProvider): def initial_restriction(self, element): start, end, interval = element - if isinstance(start, Timestamp): - start_micros = start.micros - else: - start_micros = round(start * 1000000) + if not isinstance(start, Timestamp): + start = Timestamp.of(start) - if isinstance(end, Timestamp): - end_micros = end.micros - else: - end_micros = round(end * 1000000) + if not isinstance(end, Timestamp): + end = Timestamp.of(end) - interval_micros = round(interval * 1000000) + interval_duration = Duration(interval) - assert start_micros <= end_micros + assert start <= end assert interval > 0 - delta_micros: int = end_micros - start_micros - total_outputs = math.ceil(delta_micros / interval_micros) + total_duration = end - start + total_outputs = math.ceil(total_duration.micros / interval_duration.micros) + return OffsetRange(0, total_outputs) def create_tracker(self, restriction): @@ -230,38 +229,31 @@ def _validate_and_adjust_duration(self): assert self.data # The total time we need to impulse all the data. 
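+    # For example, 5 pre-timestamped elements fired every 2 seconds span
+    # (5 - 1) * 2 = 8 seconds of impulses.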
-    data_duration = (len(self.data) - 1) * self.interval
+    data_duration = (len(self.data) - 1) * Duration(self.interval)
 
     is_pre_timestamped = isinstance(self.data[0], tuple) and \
       isinstance(self.data[0][0], timestamp.Timestamp)
 
-    if isinstance(self.start_ts, Timestamp):
-      start = self.start_ts.micros / 1000000
-    else:
-      start = self.start_ts
-
-    if isinstance(self.stop_ts, Timestamp):
-      if self.stop_ts == MAX_TIMESTAMP:
-        # When the stop timestamp is unbounded (MAX_TIMESTAMP), set it to the
-        # data's actual end time plus an extra fire interval, because the
-        # impulse duration's upper bound is exclusive.
-        end = start + data_duration + self.interval
-        self.stop_ts = Timestamp(micros=end * 1000000)
-      else:
-        end = self.stop_ts.micros / 1000000
-    else:
-      end = self.stop_ts
+    start_ts = Timestamp.of(self.start_ts)
+    stop_ts = Timestamp.of(self.stop_ts)
+
+    if stop_ts == MAX_TIMESTAMP:
+      # When the stop timestamp is unbounded (MAX_TIMESTAMP), set it to the
+      # data's actual end time plus an extra fire interval, because the
+      # impulse duration's upper bound is exclusive.
+      self.stop_ts = start_ts + data_duration + Duration(self.interval)
+      stop_ts = self.stop_ts
 
     # The total time for the impulse signal which occurs in [start, end).
-    impulse_duration = end - start
-    if round(data_duration + self.interval, 6) < round(impulse_duration, 6):
+    impulse_duration = stop_ts - start_ts
+    if data_duration + Duration(self.interval) < impulse_duration:
       # We don't have enough data for the impulse.
       # If we can fit at least one more data point in the impulse duration,
       # then we will be in the repeat mode.
       message = 'The number of elements in the provided pre-timestamped ' \
         'data sequence is not enough to span the full impulse duration. ' \
-        f'Expected duration: {impulse_duration:.6f}, ' \
-        f'actual data duration: {data_duration:.6f}.'
+        f'Expected duration: {impulse_duration}, ' \
+        f'actual data duration: {data_duration}.'
 
       if is_pre_timestamped:
         raise ValueError(
@@ -274,8 +266,8 @@ def _validate_and_adjust_duration(self):
 
   def __init__(
       self,
-      start_timestamp: Timestamp = Timestamp.now(),
-      stop_timestamp: Timestamp = MAX_TIMESTAMP,
+      start_timestamp: TimestampTypes = Timestamp.now(),
+      stop_timestamp: TimestampTypes = MAX_TIMESTAMP,
       fire_interval: float = 360.0,
       apply_windowing: bool = False,
       data: Optional[Sequence[Any]] = None):
@@ -327,11 +319,11 @@ def expand(self, pbegin):
         | 'GenSequence' >> beam.ParDo(ImpulseSeqGenDoFn(self.data)))
 
     if not self.data:
-      # This step is only to ensure the current PTransform expansion is
-      # compatible with the previous Beam versions.
-      result = (
-          result
-          | 'MapToTimestamped' >> beam.Map(lambda tt: TimestampedValue(tt, tt)))
+      # This step is effectively an identity transform, because the timestamped
+      # values have already been generated in `ImpulseSeqGenDoFn`.
+      # We keep this step so the PeriodicImpulse expansion stays compatible
+      # with previous Beam versions.
+      result = (result | 'MapToTimestamped' >> beam.Map(lambda tt: tt))
 
     if self.apply_windowing:
       result = result | 'ApplyWindowing' >> beam.WindowInto(
diff --git a/sdks/python/apache_beam/transforms/periodicsequence_test.py b/sdks/python/apache_beam/transforms/periodicsequence_test.py
index fdf0995f8e5a..fce2061614af 100644
--- a/sdks/python/apache_beam/transforms/periodicsequence_test.py
+++ b/sdks/python/apache_beam/transforms/periodicsequence_test.py
@@ -261,22 +261,78 @@ def test_not_enough_timestamped_value(self):
             data=data, fire_interval=0.5))
 
-  def test_fuzzy_interval(self):
-    seed = int(time.time() * 1000)
+  def test_fuzzy_length_and_interval(self):
     times = 30
-    logging.warning("random seed=%d", seed)
-    random.seed(seed)
     for _ in range(times):
+      seed = int(time.time() * 1000)
+      random.seed(seed)
       n = int(random.randint(1, 100))
       data = list(range(n))
       m = random.randint(1, 1000)
       interval = m / 1e6
       now = Timestamp.now()
-      with TestPipeline() as p:
-        ret = (
-            p | PeriodicImpulse(
-                start_timestamp=now, data=data, fire_interval=interval))
-        assert_that(ret, equal_to(data))
+      try:
+        with TestPipeline() as p:
+          ret = (
+              p | PeriodicImpulse(
+                  start_timestamp=now, data=data, fire_interval=interval))
+          assert_that(ret, equal_to(data))
+      except Exception as e:  # pylint: disable=broad-except
+        logging.error("Error occurred at random seed=%d", seed)
+        raise e
+
+  def test_fuzzy_length_at_minimal_interval(self):
+    times = 30
+    for _ in range(times):
+      seed = int(time.time() * 1000)
+      random.seed(seed)
+      n = int(random.randint(1, 100))
+      data = list(range(n))
+      interval = 1e-6
+      now = Timestamp.now()
+      try:
+        with TestPipeline() as p:
+          ret = (
+              p | PeriodicImpulse(
+                  start_timestamp=now, data=data, fire_interval=interval))
+          assert_that(ret, equal_to(data))
+      except Exception as e:  # pylint: disable=broad-except
+        logging.error("Error occurred at random seed=%d", seed)
+        raise e
+
+  def test_int_type_input(self):
+    # This test is to verify that if input timestamps and interval are integers,
+    # the generated timestamped values are also integers.
+ # This is necessary for the following test to pass: + # apache_beam.examples.snippets.snippets_test.SlowlyChangingSideInputsTest + with TestPipeline() as p: + ret = ( + p | PeriodicImpulse( + start_timestamp=1, stop_timestamp=5, fire_interval=1)) + expected = [1, 2, 3, 4] + assert_that( + ret, equal_to(expected, lambda x, y: type(x) is type(y) and x == y)) + + def test_float_type_input(self): + with TestPipeline() as p: + ret = ( + p | PeriodicImpulse( + start_timestamp=1.0, stop_timestamp=5.0, fire_interval=1)) + expected = [1.0, 2.0, 3.0, 4.0] + assert_that( + ret, equal_to(expected, lambda x, y: type(x) is type(y) and x == y)) + + def test_timestamp_type_input(self): + with TestPipeline() as p: + ret = ( + p | PeriodicImpulse( + start_timestamp=Timestamp.of(1), + stop_timestamp=Timestamp.of(5), + fire_interval=1)) + expected = [1.0, 2.0, 3.0, 4.0] + assert_that( + ret, equal_to(expected, lambda x, y: type(x) is type(y) and x == y)) if __name__ == '__main__': diff --git a/sdks/python/apache_beam/yaml/tests/filter.yaml b/sdks/python/apache_beam/yaml/tests/filter.yaml index 40f0307770d5..23352c9378bc 100644 --- a/sdks/python/apache_beam/yaml/tests/filter.yaml +++ b/sdks/python/apache_beam/yaml/tests/filter.yaml @@ -116,6 +116,29 @@ pipelines: - {transaction_id: "T0302", product_name: "Monitor", category: "Electronics", price: 249.99} + # Simple Filter using SQL + - pipeline: + type: chain + transforms: + - type: Create + config: + elements: + - {transaction_id: "T0012", product_name: "Headphones", category: "Electronics", price: 59.99} + - {transaction_id: "T5034", product_name: "Leather Jacket", category: "Apparel", price: 109.99} + - {transaction_id: "T0024", product_name: "Aluminum Mug", category: "Kitchen", price: 29.9} + - {transaction_id: "T0104", product_name: "Headphones", category: "Electronics", price: 59.99} + - {transaction_id: "T0302", product_name: "Monitor", category: "Electronics", price: 249.99} + - type: Filter + config: + language: sql + keep: category = 'Electronics' and price < 100 + - type: AssertEqual + config: + elements: + - {transaction_id: "T0012", product_name: "Headphones", category: "Electronics", price: 59.99} + - {transaction_id: "T0104", product_name: "Headphones", category: "Electronics", price: 59.99} + + # Simple Filter with error handling - pipeline: type: composite diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py b/sdks/python/apache_beam/yaml/yaml_mapping.py index 1011d45383dc..4433a0503b9a 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping.py @@ -724,7 +724,7 @@ def _PyJsMapToFields( @beam.ptransform.ptransform_fn def _SqlFilterTransform(pcoll, sql_transform_constructor, keep, language): return pcoll | sql_transform_constructor( - f'SELECT * FROM PCOLLECTION WHERE {keep}') + f"SELECT * FROM PCOLLECTION WHERE {keep.get('expression')}") @beam.ptransform.ptransform_fn diff --git a/sdks/python/container/build.gradle b/sdks/python/container/build.gradle index 861d90b6e6be..f4de804b80b7 100644 --- a/sdks/python/container/build.gradle +++ b/sdks/python/container/build.gradle @@ -70,6 +70,7 @@ for(int i=min_python_version; i<=max_python_version; ++i) { } tasks.register("pushAll") { + dependsOn ':sdks:python:container:distroless:pushAll' for(int ver=min_python_version; ver<=max_python_version; ++ver) { if (!project.hasProperty("skip-python-3" + ver + "-images")) { dependsOn ':sdks:python:container:push3' + ver diff --git a/sdks/python/container/distroless/build.gradle 
b/sdks/python/container/distroless/build.gradle index a967a80f4fc0..314484ade61a 100644 --- a/sdks/python/container/distroless/build.gradle +++ b/sdks/python/container/distroless/build.gradle @@ -38,7 +38,9 @@ for(int i=min_python_version; i<=max_python_version; ++i) { if (cur != min_version) { // Enforce ordering to allow the prune step to happen between runs. // This will ensure we don't use up too much space (especially in CI environments) - mustRunAfter(":sdks:python:container:distroless:push" + prev) + if (!project.hasProperty("skip-python-3" + prev + "-images")) { + mustRunAfter(":sdks:python:container:distroless:push" + prev) + } } dependsOn ':sdks:python:container:distroless:py' + cur + ':docker' @@ -55,6 +57,8 @@ for(int i=min_python_version; i<=max_python_version; ++i) { tasks.register("pushAll") { for(int ver=min_python_version; ver<=max_python_version; ++ver) { - dependsOn ':sdks:python:container:distroless:push3' + ver + if (!project.hasProperty("skip-python-3" + ver + "-images")) { + dependsOn ':sdks:python:container:distroless:push3' + ver + } } } diff --git a/sdks/python/container/license_scripts/dep_urls_py.yaml b/sdks/python/container/license_scripts/dep_urls_py.yaml index da10163fdb4f..b46fc10adf13 100644 --- a/sdks/python/container/license_scripts/dep_urls_py.yaml +++ b/sdks/python/container/license_scripts/dep_urls_py.yaml @@ -135,6 +135,8 @@ pip_dependencies: license: "https://github.com/PiotrDabkowski/pyjsparser/blob/master/LICENSE" pymongo: license: "https://raw.githubusercontent.com/mongodb/mongo-python-driver/master/LICENSE" + milvus-lite: + license: "https://raw.githubusercontent.com/milvus-io/milvus-lite/refs/heads/main/LICENSE" pyproject_hooks: license: "https://raw.githubusercontent.com/pypa/pyproject-hooks/main/LICENSE" python-gflags: diff --git a/sdks/python/container/py310/base_image_requirements.txt b/sdks/python/container/py310/base_image_requirements.txt index 5f69f6b11928..e9b4f1905399 100644 --- a/sdks/python/container/py310/base_image_requirements.txt +++ b/sdks/python/container/py310/base_image_requirements.txt @@ -43,7 +43,6 @@ cloud-sql-python-connector==1.18.2 crcmod==1.7 cryptography==45.0.4 Cython==3.1.2 -deprecation==2.1.0 dill==0.3.1.1 dnspython==2.7.0 docker==7.1.0 @@ -57,17 +56,17 @@ freezegun==1.5.2 frozenlist==1.7.0 future==1.0.0 google-api-core==2.25.1 -google-api-python-client==2.172.0 +google-api-python-client==2.174.0 google-apitools==0.5.31 google-auth==2.40.3 google-auth-httplib2==0.2.0 -google-cloud-aiplatform==1.97.0 +google-cloud-aiplatform==1.100.0 google-cloud-bigquery==3.34.0 google-cloud-bigquery-storage==2.32.0 google-cloud-bigtable==2.31.0 google-cloud-core==2.4.3 google-cloud-datastore==2.21.0 -google-cloud-dlp==3.30.0 +google-cloud-dlp==3.31.0 google-cloud-language==2.17.2 google-cloud-profiler==4.1.0 google-cloud-pubsub==2.30.0 @@ -79,7 +78,7 @@ google-cloud-storage==2.19.0 google-cloud-videointelligence==2.16.2 google-cloud-vision==3.10.2 google-crc32c==1.7.1 -google-genai==1.20.0 +google-genai==1.23.0 google-resumable-media==2.7.2 googleapis-common-protos==1.70.0 greenlet==3.2.3 @@ -93,27 +92,27 @@ hdfs==2.7.3 httpcore==1.0.9 httplib2==0.22.0 httpx==0.28.1 -hypothesis==6.135.10 +hypothesis==6.135.17 idna==3.10 importlib_metadata==8.7.0 iniconfig==2.1.0 jaraco.classes==3.4.0 jaraco.context==6.0.1 -jaraco.functools==4.1.0 +jaraco.functools==4.2.1 jeepney==0.9.0 Jinja2==3.1.6 joblib==1.5.1 jsonpickle==3.4.2 jsonschema==4.24.0 jsonschema-specifications==2025.4.1 -kafka-python==2.2.11 keyring==25.6.0 
keyrings.google-artifactregistry-auth==1.1.2 MarkupSafe==3.0.2 +milvus-lite==2.5.1 mmh3==5.1.0 mock==5.2.0 more-itertools==10.7.0 -multidict==6.4.4 +multidict==6.6.2 mysql-connector-python==9.3.0 nltk==3.9.1 numpy==2.2.6 @@ -122,13 +121,14 @@ objsize==0.7.1 opentelemetry-api==1.34.1 opentelemetry-sdk==1.34.1 opentelemetry-semantic-conventions==0.55b1 -oracledb==3.1.1 +oracledb==3.2.0 orjson==3.10.18 overrides==7.7.0 packaging==25.0 pandas==2.2.3 parameterized==0.9.0 pg8000==1.31.2 +pip==25.1.1 pluggy==1.6.0 propcache==0.3.2 proto-plus==1.26.1 @@ -144,7 +144,8 @@ pydantic_core==2.33.2 pydot==1.4.2 PyHamcrest==2.1.0 PyJWT==2.9.0 -pymongo==4.13.1 +pymilvus==2.5.11 +pymongo==4.13.2 PyMySQL==1.1.1 pyparsing==3.2.3 pyproject_hooks==1.2.0 @@ -152,6 +153,7 @@ pytest==7.4.4 pytest-timeout==2.4.0 pytest-xdist==3.7.0 python-dateutil==2.9.0.post0 +python-dotenv==1.1.1 python-tds==1.16.1 pytz==2025.2 PyYAML==6.0.2 @@ -166,6 +168,7 @@ scikit-learn==1.7.0 scipy==1.15.3 scramp==1.4.5 SecretStorage==3.3.3 +setuptools==80.9.0 shapely==2.1.1 six==1.17.0 sniffio==1.3.1 @@ -175,17 +178,19 @@ SQLAlchemy==2.0.41 sqlalchemy_pytds==1.0.2 sqlparse==0.5.3 tenacity==8.5.0 -testcontainers==3.7.1 +testcontainers==4.10.0 threadpoolctl==3.6.0 tomli==2.2.1 tqdm==4.67.1 typing-inspection==0.4.1 typing_extensions==4.14.0 tzdata==2025.2 +ujson==5.10.0 uritemplate==4.2.0 -urllib3==2.4.0 +urllib3==2.5.0 virtualenv-clone==0.5.7 websockets==15.0.1 +wheel==0.45.1 wrapt==1.17.2 yarl==1.20.1 zipp==3.23.0 diff --git a/sdks/python/container/py311/base_image_requirements.txt b/sdks/python/container/py311/base_image_requirements.txt index 10d55d17f409..af2e75a54b8f 100644 --- a/sdks/python/container/py311/base_image_requirements.txt +++ b/sdks/python/container/py311/base_image_requirements.txt @@ -42,7 +42,6 @@ cloud-sql-python-connector==1.18.2 crcmod==1.7 cryptography==45.0.4 Cython==3.1.2 -deprecation==2.1.0 dill==0.3.1.1 dnspython==2.7.0 docker==7.1.0 @@ -55,17 +54,17 @@ freezegun==1.5.2 frozenlist==1.7.0 future==1.0.0 google-api-core==2.25.1 -google-api-python-client==2.172.0 +google-api-python-client==2.174.0 google-apitools==0.5.31 google-auth==2.40.3 google-auth-httplib2==0.2.0 -google-cloud-aiplatform==1.97.0 +google-cloud-aiplatform==1.100.0 google-cloud-bigquery==3.34.0 google-cloud-bigquery-storage==2.32.0 google-cloud-bigtable==2.31.0 google-cloud-core==2.4.3 google-cloud-datastore==2.21.0 -google-cloud-dlp==3.30.0 +google-cloud-dlp==3.31.0 google-cloud-language==2.17.2 google-cloud-profiler==4.1.0 google-cloud-pubsub==2.30.0 @@ -77,7 +76,7 @@ google-cloud-storage==2.19.0 google-cloud-videointelligence==2.16.2 google-cloud-vision==3.10.2 google-crc32c==1.7.1 -google-genai==1.20.0 +google-genai==1.23.0 google-resumable-media==2.7.2 googleapis-common-protos==1.70.0 greenlet==3.2.3 @@ -91,27 +90,27 @@ hdfs==2.7.3 httpcore==1.0.9 httplib2==0.22.0 httpx==0.28.1 -hypothesis==6.135.10 +hypothesis==6.135.17 idna==3.10 importlib_metadata==8.7.0 iniconfig==2.1.0 jaraco.classes==3.4.0 jaraco.context==6.0.1 -jaraco.functools==4.1.0 +jaraco.functools==4.2.1 jeepney==0.9.0 Jinja2==3.1.6 joblib==1.5.1 jsonpickle==3.4.2 jsonschema==4.24.0 jsonschema-specifications==2025.4.1 -kafka-python==2.2.11 keyring==25.6.0 keyrings.google-artifactregistry-auth==1.1.2 MarkupSafe==3.0.2 +milvus-lite==2.5.1 mmh3==5.1.0 mock==5.2.0 more-itertools==10.7.0 -multidict==6.4.4 +multidict==6.6.2 mysql-connector-python==9.3.0 nltk==3.9.1 numpy==2.2.6 @@ -120,13 +119,14 @@ objsize==0.7.1 opentelemetry-api==1.34.1 opentelemetry-sdk==1.34.1 
opentelemetry-semantic-conventions==0.55b1 -oracledb==3.1.1 +oracledb==3.2.0 orjson==3.10.18 overrides==7.7.0 packaging==25.0 pandas==2.2.3 parameterized==0.9.0 pg8000==1.31.2 +pip==25.1.1 pluggy==1.6.0 propcache==0.3.2 proto-plus==1.26.1 @@ -142,7 +142,8 @@ pydantic_core==2.33.2 pydot==1.4.2 PyHamcrest==2.1.0 PyJWT==2.9.0 -pymongo==4.13.1 +pymilvus==2.5.11 +pymongo==4.13.2 PyMySQL==1.1.1 pyparsing==3.2.3 pyproject_hooks==1.2.0 @@ -150,6 +151,7 @@ pytest==7.4.4 pytest-timeout==2.4.0 pytest-xdist==3.7.0 python-dateutil==2.9.0.post0 +python-dotenv==1.1.1 python-tds==1.16.1 pytz==2025.2 PyYAML==6.0.2 @@ -161,9 +163,10 @@ requests-mock==1.12.1 rpds-py==0.25.1 rsa==4.9.1 scikit-learn==1.7.0 -scipy==1.15.3 +scipy==1.16.0 scramp==1.4.5 SecretStorage==3.3.3 +setuptools==80.9.0 shapely==2.1.1 six==1.17.0 sniffio==1.3.1 @@ -173,16 +176,18 @@ SQLAlchemy==2.0.41 sqlalchemy_pytds==1.0.2 sqlparse==0.5.3 tenacity==8.5.0 -testcontainers==3.7.1 +testcontainers==4.10.0 threadpoolctl==3.6.0 tqdm==4.67.1 typing-inspection==0.4.1 typing_extensions==4.14.0 tzdata==2025.2 +ujson==5.10.0 uritemplate==4.2.0 -urllib3==2.4.0 +urllib3==2.5.0 virtualenv-clone==0.5.7 websockets==15.0.1 +wheel==0.45.1 wrapt==1.17.2 yarl==1.20.1 zipp==3.23.0 diff --git a/sdks/python/container/py312/base_image_requirements.txt b/sdks/python/container/py312/base_image_requirements.txt index d4b9c8751dca..f48d350e01d3 100644 --- a/sdks/python/container/py312/base_image_requirements.txt +++ b/sdks/python/container/py312/base_image_requirements.txt @@ -41,7 +41,6 @@ cloud-sql-python-connector==1.18.2 crcmod==1.7 cryptography==45.0.4 Cython==3.1.2 -deprecation==2.1.0 dill==0.3.1.1 dnspython==2.7.0 docker==7.1.0 @@ -54,17 +53,17 @@ freezegun==1.5.2 frozenlist==1.7.0 future==1.0.0 google-api-core==2.25.1 -google-api-python-client==2.172.0 +google-api-python-client==2.174.0 google-apitools==0.5.31 google-auth==2.40.3 google-auth-httplib2==0.2.0 -google-cloud-aiplatform==1.97.0 +google-cloud-aiplatform==1.100.0 google-cloud-bigquery==3.34.0 google-cloud-bigquery-storage==2.32.0 google-cloud-bigtable==2.31.0 google-cloud-core==2.4.3 google-cloud-datastore==2.21.0 -google-cloud-dlp==3.30.0 +google-cloud-dlp==3.31.0 google-cloud-language==2.17.2 google-cloud-profiler==4.1.0 google-cloud-pubsub==2.30.0 @@ -76,7 +75,7 @@ google-cloud-storage==2.19.0 google-cloud-videointelligence==2.16.2 google-cloud-vision==3.10.2 google-crc32c==1.7.1 -google-genai==1.20.0 +google-genai==1.23.0 google-resumable-media==2.7.2 googleapis-common-protos==1.70.0 greenlet==3.2.3 @@ -90,27 +89,27 @@ hdfs==2.7.3 httpcore==1.0.9 httplib2==0.22.0 httpx==0.28.1 -hypothesis==6.135.10 +hypothesis==6.135.17 idna==3.10 importlib_metadata==8.7.0 iniconfig==2.1.0 jaraco.classes==3.4.0 jaraco.context==6.0.1 -jaraco.functools==4.1.0 +jaraco.functools==4.2.1 jeepney==0.9.0 Jinja2==3.1.6 joblib==1.5.1 jsonpickle==3.4.2 jsonschema==4.24.0 jsonschema-specifications==2025.4.1 -kafka-python==2.2.11 keyring==25.6.0 keyrings.google-artifactregistry-auth==1.1.2 MarkupSafe==3.0.2 +milvus-lite==2.5.1 mmh3==5.1.0 mock==5.2.0 more-itertools==10.7.0 -multidict==6.4.4 +multidict==6.6.2 mysql-connector-python==9.3.0 nltk==3.9.1 numpy==2.2.6 @@ -119,13 +118,14 @@ objsize==0.7.1 opentelemetry-api==1.34.1 opentelemetry-sdk==1.34.1 opentelemetry-semantic-conventions==0.55b1 -oracledb==3.1.1 +oracledb==3.2.0 orjson==3.10.18 overrides==7.7.0 packaging==25.0 pandas==2.2.3 parameterized==0.9.0 pg8000==1.31.2 +pip==25.1.1 pluggy==1.6.0 propcache==0.3.2 proto-plus==1.26.1 @@ -141,7 +141,8 @@ pydantic_core==2.33.2 
pydot==1.4.2 PyHamcrest==2.1.0 PyJWT==2.9.0 -pymongo==4.13.1 +pymilvus==2.5.11 +pymongo==4.13.2 PyMySQL==1.1.1 pyparsing==3.2.3 pyproject_hooks==1.2.0 @@ -149,6 +150,7 @@ pytest==7.4.4 pytest-timeout==2.4.0 pytest-xdist==3.7.0 python-dateutil==2.9.0.post0 +python-dotenv==1.1.1 python-tds==1.16.1 pytz==2025.2 PyYAML==6.0.2 @@ -160,7 +162,7 @@ requests-mock==1.12.1 rpds-py==0.25.1 rsa==4.9.1 scikit-learn==1.7.0 -scipy==1.15.3 +scipy==1.16.0 scramp==1.4.5 SecretStorage==3.3.3 setuptools==80.9.0 @@ -173,14 +175,15 @@ SQLAlchemy==2.0.41 sqlalchemy_pytds==1.0.2 sqlparse==0.5.3 tenacity==8.5.0 -testcontainers==3.7.1 +testcontainers==4.10.0 threadpoolctl==3.6.0 tqdm==4.67.1 typing-inspection==0.4.1 typing_extensions==4.14.0 tzdata==2025.2 +ujson==5.10.0 uritemplate==4.2.0 -urllib3==2.4.0 +urllib3==2.5.0 virtualenv-clone==0.5.7 websockets==15.0.1 wheel==0.45.1 diff --git a/sdks/python/container/py39/base_image_requirements.txt b/sdks/python/container/py39/base_image_requirements.txt index 849786b95756..1c2ebc4c7a4c 100644 --- a/sdks/python/container/py39/base_image_requirements.txt +++ b/sdks/python/container/py39/base_image_requirements.txt @@ -43,7 +43,6 @@ cloud-sql-python-connector==1.18.2 crcmod==1.7 cryptography==45.0.4 Cython==3.1.2 -deprecation==2.1.0 dill==0.3.1.1 dnspython==2.7.0 docker==7.1.0 @@ -57,17 +56,17 @@ freezegun==1.5.2 frozenlist==1.7.0 future==1.0.0 google-api-core==2.25.1 -google-api-python-client==2.172.0 +google-api-python-client==2.174.0 google-apitools==0.5.31 google-auth==2.40.3 google-auth-httplib2==0.2.0 -google-cloud-aiplatform==1.97.0 +google-cloud-aiplatform==1.100.0 google-cloud-bigquery==3.34.0 google-cloud-bigquery-storage==2.32.0 google-cloud-bigtable==2.31.0 google-cloud-core==2.4.3 google-cloud-datastore==2.21.0 -google-cloud-dlp==3.30.0 +google-cloud-dlp==3.31.0 google-cloud-language==2.17.2 google-cloud-profiler==4.1.0 google-cloud-pubsub==2.30.0 @@ -79,7 +78,7 @@ google-cloud-storage==2.19.0 google-cloud-videointelligence==2.16.2 google-cloud-vision==3.10.2 google-crc32c==1.7.1 -google-genai==1.20.0 +google-genai==1.23.0 google-resumable-media==2.7.2 googleapis-common-protos==1.70.0 greenlet==3.2.3 @@ -93,27 +92,27 @@ hdfs==2.7.3 httpcore==1.0.9 httplib2==0.22.0 httpx==0.28.1 -hypothesis==6.135.10 +hypothesis==6.135.17 idna==3.10 importlib_metadata==8.7.0 iniconfig==2.1.0 jaraco.classes==3.4.0 jaraco.context==6.0.1 -jaraco.functools==4.1.0 +jaraco.functools==4.2.1 jeepney==0.9.0 Jinja2==3.1.6 joblib==1.5.1 jsonpickle==3.4.2 jsonschema==4.24.0 jsonschema-specifications==2025.4.1 -kafka-python==2.2.11 keyring==25.6.0 keyrings.google-artifactregistry-auth==1.1.2 MarkupSafe==3.0.2 +milvus-lite==2.5.1 mmh3==5.1.0 mock==5.2.0 more-itertools==10.7.0 -multidict==6.4.4 +multidict==6.6.2 mysql-connector-python==9.3.0 nltk==3.9.1 numpy==2.0.2 @@ -122,13 +121,14 @@ objsize==0.7.1 opentelemetry-api==1.34.1 opentelemetry-sdk==1.34.1 opentelemetry-semantic-conventions==0.55b1 -oracledb==3.1.1 +oracledb==3.2.0 orjson==3.10.18 overrides==7.7.0 packaging==25.0 pandas==2.2.3 parameterized==0.9.0 pg8000==1.31.2 +pip==25.1.1 pluggy==1.6.0 propcache==0.3.2 proto-plus==1.26.1 @@ -144,7 +144,8 @@ pydantic_core==2.33.2 pydot==1.4.2 PyHamcrest==2.1.0 PyJWT==2.9.0 -pymongo==4.13.1 +pymilvus==2.5.11 +pymongo==4.13.2 PyMySQL==1.1.1 pyparsing==3.2.3 pyproject_hooks==1.2.0 @@ -152,6 +153,7 @@ pytest==7.4.4 pytest-timeout==2.4.0 pytest-xdist==3.7.0 python-dateutil==2.9.0.post0 +python-dotenv==1.1.1 python-tds==1.16.1 pytz==2025.2 PyYAML==6.0.2 @@ -166,6 +168,7 @@ scikit-learn==1.6.1 
scipy==1.13.1 scramp==1.4.5 SecretStorage==3.3.3 +setuptools==80.9.0 shapely==2.0.7 six==1.17.0 sniffio==1.3.1 @@ -175,17 +178,19 @@ SQLAlchemy==2.0.41 sqlalchemy_pytds==1.0.2 sqlparse==0.5.3 tenacity==8.5.0 -testcontainers==3.7.1 +testcontainers==4.10.0 threadpoolctl==3.6.0 tomli==2.2.1 tqdm==4.67.1 typing-inspection==0.4.1 typing_extensions==4.14.0 tzdata==2025.2 +ujson==5.10.0 uritemplate==4.2.0 -urllib3==2.4.0 +urllib3==2.5.0 virtualenv-clone==0.5.7 websockets==15.0.1 +wheel==0.45.1 wrapt==1.17.2 yarl==1.20.1 zipp==3.23.0 diff --git a/sdks/python/container/run_generate_requirements.sh b/sdks/python/container/run_generate_requirements.sh index 6c160bc6ac9e..23964d10e7b4 100755 --- a/sdks/python/container/run_generate_requirements.sh +++ b/sdks/python/container/run_generate_requirements.sh @@ -72,7 +72,7 @@ pip uninstall -y apache-beam echo "Checking for broken dependencies:" pip check echo "Installed dependencies:" -pip freeze +pip freeze --all PY_IMAGE="py${PY_VERSION//.}" REQUIREMENTS_FILE=$PWD/sdks/python/container/$PY_IMAGE/base_image_requirements.txt @@ -103,7 +103,7 @@ cat < "$REQUIREMENTS_FILE" EOT # Remove pkg_resources to guard against # https://stackoverflow.com/questions/39577984/what-is-pkg-resources-0-0-0-in-output-of-pip-freeze-command -pip freeze | grep -v pkg_resources >> "$REQUIREMENTS_FILE" +pip freeze --all | grep -v pkg_resources >> "$REQUIREMENTS_FILE" if grep -q "tensorflow==" "$REQUIREMENTS_FILE"; then # Get the version of tensorflow from the .txt file. diff --git a/sdks/python/setup.py b/sdks/python/setup.py index a0bbc301435b..d309a7ea4a64 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -401,6 +401,7 @@ def get_portability_package_data(): 'typing-extensions>=3.7.0', 'zstandard>=0.18.0,<1', 'pyyaml>=3.12,<7.0.0', + 'pymilvus>=2.5.10,<3.0.0', # Dynamic dependencies must be specified in a separate list, otherwise # Dependabot won't be able to parse the main list. Any dynamic # dependencies will not receive updates from Dependabot. @@ -434,11 +435,10 @@ def get_portability_package_data(): 'pytest-xdist>=2.5.0,<4', 'pytest-timeout>=2.1.0,<3', 'scikit-learn>=0.20.0', - 'setuptools', 'sqlalchemy>=1.3,<3.0', 'psycopg2-binary>=2.8.5,<2.9.10; python_version <= "3.9"', 'psycopg2-binary>=2.8.5,<3.0; python_version >= "3.10"', - 'testcontainers[mysql,kafka]>=3.0.3,<4.0.0', + 'testcontainers[mysql,kafka,milvus]>=4.0.0,<5.0.0', 'cryptography>=41.0.2', 'hypothesis>5.0.0,<7.0.0', 'virtualenv-clone>=0.5,<1.0',