diff --git a/v2/datastream-to-spanner/src/main/java/com/google/cloud/teleport/v2/templates/DataStreamToSpanner.java b/v2/datastream-to-spanner/src/main/java/com/google/cloud/teleport/v2/templates/DataStreamToSpanner.java index 11b8ffdee4..6126d70118 100644 --- a/v2/datastream-to-spanner/src/main/java/com/google/cloud/teleport/v2/templates/DataStreamToSpanner.java +++ b/v2/datastream-to-spanner/src/main/java/com/google/cloud/teleport/v2/templates/DataStreamToSpanner.java @@ -38,7 +38,7 @@ import com.google.cloud.teleport.v2.spanner.migrations.shard.ShardingContext; import com.google.cloud.teleport.v2.spanner.migrations.transformation.CustomTransformation; import com.google.cloud.teleport.v2.spanner.migrations.transformation.TransformationContext; -import com.google.cloud.teleport.v2.spanner.migrations.utils.DataflowWorkerMachineTypeValidator; +import com.google.cloud.teleport.v2.spanner.migrations.utils.DataflowWorkerMachineTypeUtils; import com.google.cloud.teleport.v2.spanner.migrations.utils.SessionFileReader; import com.google.cloud.teleport.v2.spanner.migrations.utils.ShardingContextReader; import com.google.cloud.teleport.v2.spanner.migrations.utils.TransformationContextReader; @@ -649,7 +649,7 @@ public static PipelineResult run(Options options) { Pipeline pipeline = Pipeline.create(options); String workerMachineType = pipeline.getOptions().as(DataflowPipelineWorkerPoolOptions.class).getWorkerMachineType(); - DataflowWorkerMachineTypeValidator.validateMachineSpecs(workerMachineType, 4); + DataflowWorkerMachineTypeUtils.validateMachineSpecs(workerMachineType, 4); DeadLetterQueueManager dlqManager = buildDlqManager(options); // Ingest session file into schema object. Schema schema = SessionFileReader.read(options.getSessionFilePath()); diff --git a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/options/OptionsToConfigBuilder.java b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/options/OptionsToConfigBuilder.java index 8632b951be..dcf3346fe8 100644 --- a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/options/OptionsToConfigBuilder.java +++ b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/options/OptionsToConfigBuilder.java @@ -24,6 +24,7 @@ import com.google.cloud.teleport.v2.source.reader.io.jdbc.iowrapper.config.SQLDialect; import com.google.cloud.teleport.v2.source.reader.io.jdbc.iowrapper.config.defaults.MySqlConfigDefaults; import com.google.cloud.teleport.v2.source.reader.io.schema.SourceSchemaReference; +import com.google.cloud.teleport.v2.spanner.migrations.utils.DataflowWorkerMachineTypeUtils; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.re2j.Matcher; @@ -33,6 +34,8 @@ import java.util.List; import java.util.Map.Entry; import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.options.DataflowPipelineWorkerPoolOptions; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.Wait; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; @@ -43,6 +46,21 @@ public final class OptionsToConfigBuilder { private static final Logger LOG = LoggerFactory.getLogger(OptionsToConfigBuilder.class); public static final String DEFAULT_POSTGRESQL_NAMESPACE = "public"; + /** + * Extracts the worker zone from the options. + * + * @param options Pipeline options. + * @return The worker zone or null if not found. 
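+   * <p>For example (value hypothetical), a pipeline launched with {@code --workerZone=us-central1-a} returns {@code "us-central1-a"}.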
+ */ + public static String extractWorkerZone(PipelineOptions options) { + try { + return options.as(DataflowPipelineWorkerPoolOptions.class).getWorkerZone(); + } catch (Exception e) { + LOG.warn("Could not extract worker zone from options. Defaulting to null.", e); + return null; + } + } + public static JdbcIOWrapperConfig getJdbcIOWrapperConfigWithDefaults( SourceDbToSpannerOptions options, List tables, @@ -60,6 +78,12 @@ public static JdbcIOWrapperConfig getJdbcIOWrapperConfigWithDefaults( long maxConnections = options.getMaxConnections() > 0 ? (long) (options.getMaxConnections()) : 0; Integer numPartitions = options.getNumPartitions(); + String workerZone = extractWorkerZone(options); + + Integer fetchSize = options.getFetchSize(); + if (fetchSize != null && fetchSize < 0) { + fetchSize = null; + } return getJdbcIOWrapperConfig( sqlDialect, @@ -78,8 +102,11 @@ public static JdbcIOWrapperConfig getJdbcIOWrapperConfigWithDefaults( maxConnections, numPartitions, waitOn, - options.getFetchSize(), - options.getUniformizationStageCountHint()); + fetchSize, + options.getUniformizationStageCountHint(), + options.getProjectId(), + workerZone, + options.as(DataflowPipelineWorkerPoolOptions.class).getWorkerMachineType()); } public static JdbcIOWrapperConfig getJdbcIOWrapperConfig( @@ -100,7 +127,10 @@ public static JdbcIOWrapperConfig getJdbcIOWrapperConfig( Integer numPartitions, Wait.OnSignal waitOn, Integer fetchSize, - Long uniformizationStageCountHint) { + Long uniformizationStageCountHint, + String projectId, + String workerZone, + String workerMachineType) { JdbcIOWrapperConfig.Builder builder = builderWithDefaultsFor(sqlDialect); SourceSchemaReference sourceSchemaReference = sourceSchemaReferenceFrom(sqlDialect, dbName, namespace); @@ -115,6 +145,14 @@ public static JdbcIOWrapperConfig getJdbcIOWrapperConfig( .build()) .setJdbcDriverClassName(jdbcDriverClassName) .setJdbcDriverJars(jdbcDriverJars); + + if (workerMachineType != null && !workerMachineType.isEmpty()) { + builder.setWorkerMemoryGB( + DataflowWorkerMachineTypeUtils.getWorkerMemoryGB( + projectId, workerZone, workerMachineType)); + builder.setWorkerCores( + DataflowWorkerMachineTypeUtils.getWorkerCores(projectId, workerZone, workerMachineType)); + } if (maxConnections != 0) { builder = builder.setMaxConnections(maxConnections); } @@ -161,27 +199,34 @@ public static JdbcIOWrapperConfig getJdbcIOWrapperConfig( } /** - * For MySQL Dialect, if Fetchsize is expecitly set by the user, enables `useCursorFetch`. + * For MySQL Dialect, if Fetchsize is explicitly set by the user or if it's auto-inferred (null), + * enables `useCursorFetch`. It is disabled only if user explicitly sets FetchSize to 0. * * @param sqlDialect Sql Dialect. * @param url DB Url from passed configs. * @param fetchSize FetchSize Setting (Null if user has not explicitly set) - * @return Updated URL with `useCursorFetch` only if dialect is MySql and Fetchsize is not null. - * Same as input URL in all other cases. + * @return Updated URL with `useCursorFetch` only if dialect is MySql and Fetchsize is not 0. Same + * as input URL in all other cases. */ @VisibleForTesting @Nullable protected static String mysqlSetCursorModeIfNeeded( SQLDialect sqlDialect, String url, @Nullable Integer fetchSize) { - if (fetchSize == null) { - LOG.info( - "FetchSize is not explicitly configured. 
In case of out of memory errors, please set `FetchSize` according to the available memory and maximum size of a row."); + if (sqlDialect != SQLDialect.MYSQL) { return url; } - if (sqlDialect != SQLDialect.MYSQL) { + // For MySQL, to enable streaming/cursor mode, useCursorFetch must be true. + // We enable it if fetchSize is NULL (Auto-infer) or > 0. + // We only disable it if fetchSize is explicitly 0 (Fetch All). + if (fetchSize != null && fetchSize == 0) { + LOG.info( + "FetchSize is explicitly 0. MySQL cursor mode (useCursorFetch) will not be enabled explicitly."); return url; } - LOG.info("For Mysql, Fetchsize is explicitly configured. So setting `useCursorMode=true`."); + + LOG.info( + "FetchSize is {}. Setting MySQL `useCursorFetch=true`.", + fetchSize == null ? "Auto" : fetchSize); String updatedUrl = addParamToJdbcUrl(url, "useCursorFetch", "true"); return updatedUrl; } diff --git a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/dialectadapter/DialectAdapter.java b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/dialectadapter/DialectAdapter.java index 6f24ebe5e8..deb736051f 100644 --- a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/dialectadapter/DialectAdapter.java +++ b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/dialectadapter/DialectAdapter.java @@ -19,15 +19,18 @@ import com.google.cloud.teleport.v2.source.reader.io.exception.RetriableSchemaDiscoveryException; import com.google.cloud.teleport.v2.source.reader.io.exception.SchemaDiscoveryException; import com.google.cloud.teleport.v2.source.reader.io.jdbc.JdbcSchemaReference; +import com.google.cloud.teleport.v2.source.reader.io.jdbc.rowmapper.JdbcValueMappingsProvider; import com.google.cloud.teleport.v2.source.reader.io.jdbc.uniformsplitter.UniformSplitterDBAdapter; import com.google.cloud.teleport.v2.source.reader.io.schema.RetriableSchemaDiscovery; import com.google.cloud.teleport.v2.source.reader.io.schema.SourceColumnIndexInfo; import com.google.cloud.teleport.v2.source.reader.io.schema.SourceSchemaReference; import com.google.cloud.teleport.v2.source.reader.io.schema.SourceSchemaReference.Kind; +import com.google.cloud.teleport.v2.source.reader.io.schema.SourceTableSchema; import com.google.cloud.teleport.v2.spanner.migrations.schema.SourceColumnType; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import java.util.Map; /** * Interface to support various dialects of JDBC databases. 
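The second hunk below adds default row-size estimation to this interface: the estimate for a table is the sum of the per-column estimates supplied by the dialect's JdbcValueMappingsProvider. A minimal sketch of the arithmetic, assuming the MySQL estimators introduced later in this patch (`adapter` stands in for any DialectAdapter implementation):

  JdbcValueMappingsProvider mappings = new MysqlJdbcValueMappings();
  Map<String, SourceColumnType> columns =
      Map.of(
          "id", new SourceColumnType("BIGINT", new Long[] {}, null), // 8 bytes
          "name", new SourceColumnType("VARCHAR", new Long[] {50L}, null), // 50 * 4 = 200 bytes
          "payload", new SourceColumnType("TEXT", new Long[] {}, null)); // 65,535 bytes
  long rowSize = adapter.estimateRowSize(columns, mappings); // 8 + 200 + 65_535 = 65_743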
@@ -111,4 +114,21 @@ ImmutableMap> discoverTableIndexes( JdbcSchemaReference sourceSchemaReference, ImmutableList tables) throws SchemaDiscoveryException, RetriableSchemaDiscoveryException; + + default long estimateRowSize( + SourceTableSchema sourceTableSchema, JdbcValueMappingsProvider jdbcValueMappingsProvider) { + return estimateRowSize( + sourceTableSchema.sourceColumnNameToSourceColumnType(), jdbcValueMappingsProvider); + } + + default long estimateRowSize( + Map sourceColumnNameToSourceColumnType, + JdbcValueMappingsProvider jdbcValueMappingsProvider) { + long estimatedRowSize = 0; + for (Map.Entry entry : + sourceColumnNameToSourceColumnType.entrySet()) { + estimatedRowSize += jdbcValueMappingsProvider.estimateColumnSize(entry.getValue()); + } + return estimatedRowSize; + } } diff --git a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/iowrapper/FetchSizeCalculator.java b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/iowrapper/FetchSizeCalculator.java new file mode 100644 index 0000000000..55869ccf64 --- /dev/null +++ b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/iowrapper/FetchSizeCalculator.java @@ -0,0 +1,105 @@ +/* + * Copyright (C) 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.source.reader.io.jdbc.iowrapper; + +import com.google.cloud.teleport.v2.source.reader.io.jdbc.iowrapper.config.TableConfig; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Calculates the fetch size for JDBC readers based on worker resources and row size estimation. + * Formula: FetchSize = (WorkerMemory) / (4 * WorkerCores * MaxRowSize) + */ +public final class FetchSizeCalculator { + private static final Logger LOG = LoggerFactory.getLogger(FetchSizeCalculator.class); + + private static final int MIN_FETCH_SIZE = 1; + private static final int MAX_FETCH_SIZE = Integer.MAX_VALUE; + + private FetchSizeCalculator() {} + + /** + * @param estimatedRowSize Estimated size of a row in bytes. + * @param workerMemoryGB The Dataflow worker memory in GB. + * @param workerCores The Dataflow worker cores. + * @return The calculated fetch size, or 0 if it cannot be calculated. + */ + public static Integer getFetchSize( + TableConfig tableConfig, long estimatedRowSize, Double workerMemoryGB, Integer workerCores) { + if (tableConfig.fetchSize() != null) { + LOG.info( + "Explicitly configured fetch size for table {}: {}", + tableConfig.tableName(), + tableConfig.fetchSize()); + return tableConfig.fetchSize(); + } + + try { + if (estimatedRowSize == 0) { + LOG.warn( + "Estimated row size is 0 for table {}. FetchSize cannot be calculated. Cursor mode will not be enabled.", + tableConfig.tableName()); + return 0; + } + + if (workerMemoryGB == null || workerCores == null) { + LOG.warn( + "Worker memory or cores unavailable. FetchSize cannot be calculated. 
Cursor mode will not be enabled."); + return 0; + } + + long workerMemoryBytes = (long) (workerMemoryGB * 1024 * 1024 * 1024); + + // Formula: (Memory of Dataflow worker VM) / (2 * 2 * (Number of cores on the + // Dataflow worker VM) * (Maximum row size)) + // 2 * 2 = 4 (Safety factor) + long denominator = 4L * workerCores * estimatedRowSize; + + if (denominator == 0) { // Should not happen given estimatedRowSize check and cores >= 1 + LOG.warn( + "Denominator for fetch size calculation is zero for table {}. FetchSize cannot be calculated. Cursor mode will not be enabled.", + tableConfig.tableName()); + return 0; + } + + long calculatedFetchSize = workerMemoryBytes / denominator; + + LOG.info( + "Auto-inferred fetchSize for table {}: {} (Memory: {} bytes, Cores: {}, RowSize: {} bytes)", + tableConfig.tableName(), + calculatedFetchSize, + workerMemoryBytes, + workerCores, + estimatedRowSize); + + if (calculatedFetchSize < MIN_FETCH_SIZE) { + return MIN_FETCH_SIZE; + } + if (calculatedFetchSize > MAX_FETCH_SIZE) { + return MAX_FETCH_SIZE; + } + + return (int) calculatedFetchSize; + + } catch (Exception e) { + LOG.warn( + "Failed to auto-infer fetch size for table {}, error: {}. Cursor mode will not be enabled.", + tableConfig.tableName(), + e.getMessage()); + return 0; + } + } +} diff --git a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/iowrapper/JdbcIoWrapper.java b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/iowrapper/JdbcIoWrapper.java index d20275489f..d1ec985ce3 100644 --- a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/iowrapper/JdbcIoWrapper.java +++ b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/iowrapper/JdbcIoWrapper.java @@ -181,6 +181,11 @@ public SourceSchema discoverTableSchema() { tableConfig -> { SourceTableSchema sourceTableSchema = findSourceTableSchema(sourceSchema, tableConfig); + long estimatedRowSize = sourceTableSchema.estimatedRowSize(); + Integer calculatedFetchSize = + FetchSizeCalculator.getFetchSize( + tableConfig, estimatedRowSize, config.workerMemoryGB(), config.workerCores()); + int fetchSize = calculatedFetchSize; return Map.entry( SourceTableReference.builder() .setSourceSchemaReference(sourceSchema.schemaReference()) @@ -193,13 +198,15 @@ public SourceSchema discoverTableSchema() { dataSourceConfiguration, sourceSchema.schemaReference(), tableConfig, - sourceTableSchema) + sourceTableSchema, + fetchSize) : getJdbcIO( config, dataSourceConfiguration, sourceSchema.schemaReference(), tableConfig, - sourceTableSchema)); + sourceTableSchema, + fetchSize)); }) .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); } @@ -237,6 +244,11 @@ static SourceSchema getSourceSchema( colEntry -> sourceTableSchemaBuilder.addSourceColumnNameToSourceColumnType( colEntry.getKey(), colEntry.getValue())); + long estimatedRowSize = + config + .dialectAdapter() + .estimateRowSize(tableEntry.getValue(), config.valueMappingsProvider()); + sourceTableSchemaBuilder.setEstimatedRowSize(estimatedRowSize); return sourceTableSchemaBuilder.build(); }) .forEach(sourceSchemaBuilder::addTableSchema); @@ -280,6 +292,10 @@ private static TableConfig getTableConfig( if (config.maxPartitions() != null && config.maxPartitions() != 0) { tableConfigBuilder.setMaxPartitions(config.maxPartitions()); } + // Set fetch size for the table from global fetch size if configured + if (config.maxFetchSize() != null) { + 
tableConfigBuilder.setFetchSize(config.maxFetchSize()); + } /* * TODO(vardhanvthigle): Add optional support for non-primary indexes. * Note: most of the implementation is generic for any unique index. @@ -413,7 +429,8 @@ private static PTransform> getJdbcIO( DataSourceConfiguration dataSourceConfiguration, SourceSchemaReference sourceSchemaReference, TableConfig tableConfig, - SourceTableSchema sourceTableSchema) { + SourceTableSchema sourceTableSchema, + int fetchSize) { ReadWithPartitions jdbcIO = JdbcIO.readWithPartitions() .withTable(delimitIdentifier(tableConfig.tableName())) @@ -428,9 +445,6 @@ private static PTransform> getJdbcIO( if (tableConfig.maxPartitions() != null) { jdbcIO = jdbcIO.withNumPartitions(tableConfig.maxPartitions()); } - if (config.maxFetchSize() != null) { - jdbcIO = jdbcIO.withFetchSize(config.maxFetchSize()); - } return jdbcIO; } @@ -449,7 +463,8 @@ private static PTransform> getReadWithUniformPart DataSourceConfiguration dataSourceConfiguration, SourceSchemaReference sourceSchemaReference, TableConfig tableConfig, - SourceTableSchema sourceTableSchema) { + SourceTableSchema sourceTableSchema, + int fetchSize) { ReadWithUniformPartitions.Builder readWithUniformPartitionsBuilder = ReadWithUniformPartitions.builder() @@ -458,7 +473,7 @@ private static PTransform> getReadWithUniformPart .setDataSourceProviderFn(JdbcIO.PoolableDataSourceProvider.of(dataSourceConfiguration)) .setDbAdapter(config.dialectAdapter()) .setApproxTotalRowCount(tableConfig.approxRowCount()) - .setFetchSize(config.maxFetchSize()) + .setFetchSize(fetchSize) .setRowMapper( new JdbcSourceRowMapper( config.valueMappingsProvider(), diff --git a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/iowrapper/config/JdbcIOWrapperConfig.java b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/iowrapper/config/JdbcIOWrapperConfig.java index b43d076a5f..a2193e0ba1 100644 --- a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/iowrapper/config/JdbcIOWrapperConfig.java +++ b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/iowrapper/config/JdbcIOWrapperConfig.java @@ -102,9 +102,12 @@ public JdbcSchemaReference jdbcSourceSchemaReference() { public abstract Integer maxPartitions(); /** - * Configures the size of data read in db, per db read call. Defaults to beam's DEFAULT_FETCH_SIZE - * of 50_000. For manually fine-tuning this, take into account the read ahead buffer pool settings - * (innodb_read_ahead_threshold) and the worker memory. + * Configures the size of data read in db, per db read call. + * + *
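<p>Explicit values have two special cases: a negative value is normalized to unset
+   * (auto-infer) before the config is built, and 0 disables MySQL cursor mode entirely.
+   *
+   * <p>Worked example of the auto-inference, assuming a hypothetical n1-standard-4 worker
+   * (15 GB memory, 4 cores) and a 1 KiB estimated row:
+   * fetchSize = (15 * 2^30) / (4 * 4 * 1024) = 983,040 rows per fetch call.
+   *
+   * 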
<p>
If explicitly set, this value overrides the auto-inferred fetch size. + * + *
<p>
If not set (null), the fetch size is auto-calculated based on the worker memory and + * estimated row size to optimize for the available resources. */ @Nullable public abstract Integer maxFetchSize(); @@ -253,6 +256,14 @@ public JdbcSchemaReference jdbcSourceSchemaReference() { private static final Integer DEFAULT_MIN_EVICTABLE_IDLE_TIME_MILLIS = 8 * 3600 * 1000; + /** Worker Memory in GB. */ + @Nullable + public abstract Double workerMemoryGB(); + + /** Worker Cores. */ + @Nullable + public abstract Integer workerCores(); + public abstract Builder toBuilder(); public static Builder builderWithMySqlDefaults() { @@ -268,7 +279,6 @@ public static Builder builderWithMySqlDefaults() { .setTableVsPartitionColumns(ImmutableMap.of()) .setMaxPartitions(null) .setWaitOn(null) - .setMaxFetchSize(null) .setDbParallelizationForReads(null) .setDbParallelizationForSplitProcess(DEFAULT_PARALLELIZATION_FOR_SLIT_PROCESS) .setReadWithUniformPartitionsFeatureEnabled(true) @@ -281,7 +291,9 @@ public static Builder builderWithMySqlDefaults() { .setMinEvictableIdleTimeMillis(DEFAULT_MIN_EVICTABLE_IDLE_TIME_MILLIS) .setSchemaDiscoveryConnectivityTimeoutMilliSeconds( DEFAULT_SCHEMA_DISCOVERY_CONNECTIVITY_TIMEOUT_MILLISECONDS) - .setSplitStageCountHint(-1L); + .setSplitStageCountHint(-1L) + .setWorkerMemoryGB(null) + .setWorkerCores(null); } public static Builder builderWithPostgreSQLDefaults() { @@ -312,7 +324,9 @@ public static Builder builderWithPostgreSQLDefaults() { .setMinEvictableIdleTimeMillis(DEFAULT_MIN_EVICTABLE_IDLE_TIME_MILLIS) .setSchemaDiscoveryConnectivityTimeoutMilliSeconds( DEFAULT_SCHEMA_DISCOVERY_CONNECTIVITY_TIMEOUT_MILLISECONDS) - .setSplitStageCountHint(-1L); + .setSplitStageCountHint(-1L) + .setWorkerMemoryGB(null) + .setWorkerCores(null); } @AutoValue.Builder @@ -385,6 +399,10 @@ public abstract Builder setAdditionalOperationsOnRanges( public abstract Builder setSplitStageCountHint(Long value); + public abstract Builder setWorkerMemoryGB(Double value); + + public abstract Builder setWorkerCores(Integer value); + public abstract JdbcIOWrapperConfig autoBuild(); public JdbcIOWrapperConfig build() { diff --git a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/iowrapper/config/TableConfig.java b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/iowrapper/config/TableConfig.java index 2058af20a5..0aae0cd36e 100644 --- a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/iowrapper/config/TableConfig.java +++ b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/iowrapper/config/TableConfig.java @@ -40,11 +40,16 @@ public abstract class TableConfig { /** Approximate count of the rows in the table. */ public abstract Long approxRowCount(); + /** Fetch Size for the table. 
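Populated from the global maxFetchSize when one is configured; when null, FetchSizeCalculator auto-infers a value from worker memory, cores, and the estimated row size.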
*/ + @Nullable + public abstract Integer fetchSize(); + public static Builder builder(String tableName) { return new AutoValue_TableConfig.Builder() .setTableName(tableName) .setMaxPartitions(null) - .setApproxRowCount(0L); + .setApproxRowCount(0L) + .setFetchSize(null); } @AutoValue.Builder @@ -58,6 +63,8 @@ public abstract static class Builder { public abstract Builder setApproxRowCount(Long value); + public abstract Builder setFetchSize(Integer value); + public Builder withPartitionColum(PartitionColumn column) { this.partitionColumnsBuilder().add(column); return this; diff --git a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/rowmapper/JdbcMappings.java b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/rowmapper/JdbcMappings.java new file mode 100644 index 0000000000..73c2e3c2d2 --- /dev/null +++ b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/rowmapper/JdbcMappings.java @@ -0,0 +1,79 @@ +/* + * Copyright (C) 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.source.reader.io.jdbc.rowmapper; + +import com.google.auto.value.AutoValue; +import com.google.cloud.teleport.v2.spanner.migrations.schema.SourceColumnType; +import com.google.common.collect.ImmutableMap; +import java.io.Serializable; +import java.util.function.Function; + +/** Registry for JDBC type mappings, including value extraction, mapping, and size estimation. */ +@AutoValue +public abstract class JdbcMappings implements Serializable { + + public abstract ImmutableMap> mappings(); + + public abstract ImmutableMap> sizeEstimators(); + + public static Builder builder() { + return new AutoValue_JdbcMappings.Builder(); + } + + @AutoValue.Builder + public abstract static class Builder { + abstract ImmutableMap.Builder> mappingsBuilder(); + + abstract ImmutableMap.Builder> + sizeEstimatorsBuilder(); + + /** + * Register a mapping with a constant size estimate. + * + * @param typeName The JDBC type name (e.g., "VARCHAR"). + * @param extractor The extractor to get value from ResultSet. + * @param mapper The mapper to convert value to Avro. + * @param constantSize The constant size estimate in bytes. + */ + public Builder put( + String typeName, + ResultSetValueExtractor extractor, + ResultSetValueMapper mapper, + int constantSize) { + return put(typeName, extractor, mapper, (ignore) -> constantSize); + } + + /** + * Register a mapping with a dynamic size estimator. + * + * @param typeName The JDBC type name (e.g., "VARCHAR"). + * @param extractor The extractor to get value from ResultSet. + * @param mapper The mapper to convert value to Avro. + * @param sizeEstimator The function to estimate size based on SourceColumnType. 
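For example, the MySQL VARCHAR estimator registered later in this patch returns {@code length * 4} bytes (four bytes per character in utf8mb4).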
+ */ + public Builder put( + String typeName, + ResultSetValueExtractor extractor, + ResultSetValueMapper mapper, + Function sizeEstimator) { + mappingsBuilder().put(typeName, new JdbcValueMapper<>(extractor, mapper)); + sizeEstimatorsBuilder().put(typeName, sizeEstimator); + return this; + } + + public abstract JdbcMappings build(); + } +} diff --git a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/rowmapper/JdbcValueMappingsProvider.java b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/rowmapper/JdbcValueMappingsProvider.java index 8dbf46b0b5..193a560b36 100644 --- a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/rowmapper/JdbcValueMappingsProvider.java +++ b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/rowmapper/JdbcValueMappingsProvider.java @@ -15,6 +15,7 @@ */ package com.google.cloud.teleport.v2.source.reader.io.jdbc.rowmapper; +import com.google.cloud.teleport.v2.spanner.migrations.schema.SourceColumnType; import com.google.common.collect.ImmutableMap; import java.io.Serializable; @@ -24,6 +25,14 @@ */ public interface JdbcValueMappingsProvider extends Serializable { + /** + * Estimate the column size in bytes for a given column type. + * + * @param sourceColumnType The column type to estimate size for. + * @return Estimated size in bytes. + */ + int estimateColumnSize(SourceColumnType sourceColumnType); + /** * Get Mapping of source types to {@link JdbcValueMapper}. * diff --git a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/rowmapper/provider/MysqlJdbcValueMappings.java b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/rowmapper/provider/MysqlJdbcValueMappings.java index ec8c89cac5..72ad6ce1be 100644 --- a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/rowmapper/provider/MysqlJdbcValueMappings.java +++ b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/rowmapper/provider/MysqlJdbcValueMappings.java @@ -15,6 +15,7 @@ */ package com.google.cloud.teleport.v2.source.reader.io.jdbc.rowmapper.provider; +import com.google.cloud.teleport.v2.source.reader.io.jdbc.rowmapper.JdbcMappings; import com.google.cloud.teleport.v2.source.reader.io.jdbc.rowmapper.JdbcValueMapper; import com.google.cloud.teleport.v2.source.reader.io.jdbc.rowmapper.JdbcValueMappingsProvider; import com.google.cloud.teleport.v2.source.reader.io.jdbc.rowmapper.ResultSetValueExtractor; @@ -22,6 +23,7 @@ import com.google.cloud.teleport.v2.source.reader.io.schema.typemapping.provider.unified.CustomLogical.TimeIntervalMicros; import com.google.cloud.teleport.v2.source.reader.io.schema.typemapping.provider.unified.CustomSchema.DateTime; import com.google.cloud.teleport.v2.source.reader.io.schema.typemapping.provider.unified.CustomSchema.Interval; +import com.google.cloud.teleport.v2.spanner.migrations.schema.SourceColumnType; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; import com.google.re2j.Matcher; @@ -33,18 +35,19 @@ import java.sql.ResultSet; import java.time.Instant; import java.util.Calendar; -import java.util.Map; -import java.util.Map.Entry; import java.util.TimeZone; import java.util.concurrent.TimeUnit; import org.apache.avro.Schema.Field; import org.apache.avro.generic.GenericRecordBuilder; import org.apache.commons.codec.binary.Hex; import 
org.apache.commons.lang3.StringUtils; -import org.apache.commons.lang3.tuple.Pair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class MysqlJdbcValueMappings implements JdbcValueMappingsProvider { + private static final Logger LOG = LoggerFactory.getLogger(MysqlJdbcValueMappings.class); + /** * Pass the value extracted from {@link ResultSet} to {@link GenericRecordBuilder#set(Field, * Object)}. Most of the values, like basic types don't need any marshalling and can be directly @@ -150,6 +153,11 @@ private static long instantToMicro(Instant instant) { + TimeUnit.NANOSECONDS.toMicros(instant.getNano()); } + private static long getLengthOrPrecision(SourceColumnType sourceColumnType) { + Long[] mods = sourceColumnType.getMods(); + return (mods != null && mods.length > 0 && mods[0] != null) ? mods[0] : 0; + } + /** Map {@link java.sql.Timestamp} to timestampMicros LogicalType. */ private static final ResultSetValueMapper sqlTimestampToAvroTimestampMicros = (value, schema) -> instantToMicro(value.toInstant()); @@ -183,58 +191,149 @@ private static long instantToMicro(Instant instant) { return isNegative ? -1 * micros : micros; }; - /** - * Static mapping of SourceColumnType to {@link ResultSetValueExtractor} and {@link - * ResultSetValueMapper}. - */ - private static final ImmutableMap> SCHEMA_MAPPINGS = - ImmutableMap., ResultSetValueMapper>>builder() - .put("BIGINT", Pair.of(ResultSet::getLong, valuePassThrough)) - .put("BIGINT UNSIGNED", Pair.of(ResultSet::getBigDecimal, bigDecimalToAvroNumber)) - .put("BINARY", Pair.of(ResultSet::getBytes, bytesToHexString)) - .put("BIT", Pair.of(ResultSet::getLong, valuePassThrough)) - .put("BLOB", Pair.of(ResultSet::getBlob, blobToHexString)) - .put("BOOL", Pair.of(ResultSet::getInt, valuePassThrough)) - .put("CHAR", Pair.of(ResultSet::getString, valuePassThrough)) - .put("DATE", Pair.of(utcDateExtractor, sqlDateToAvroTimestampMicros)) - .put("DATETIME", Pair.of(utcTimeStampExtractor, sqlTimestampToAvroDateTime)) - .put("DECIMAL", Pair.of(ResultSet::getBigDecimal, bigDecimalToByteArray)) - .put("DOUBLE", Pair.of(ResultSet::getDouble, valuePassThrough)) - .put("ENUM", Pair.of(ResultSet::getString, valuePassThrough)) - .put("FLOAT", Pair.of(ResultSet::getFloat, valuePassThrough)) - .put("INTEGER", Pair.of(ResultSet::getInt, valuePassThrough)) - .put("INTEGER UNSIGNED", Pair.of(ResultSet::getLong, valuePassThrough)) - .put("JSON", Pair.of(ResultSet::getString, valuePassThrough)) - .put("LONGBLOB", Pair.of(ResultSet::getBlob, blobToHexString)) - .put("LONGTEXT", Pair.of(ResultSet::getString, valuePassThrough)) - .put("MEDIUMBLOB", Pair.of(ResultSet::getBlob, blobToHexString)) - .put("MEDIUMINT", Pair.of(ResultSet::getInt, valuePassThrough)) - .put("MEDIUMTEXT", Pair.of(ResultSet::getString, valuePassThrough)) - .put("SET", Pair.of(ResultSet::getString, valuePassThrough)) - .put("SMALLINT", Pair.of(ResultSet::getInt, valuePassThrough)) - .put("TEXT", Pair.of(ResultSet::getString, valuePassThrough)) - .put("TIME", Pair.of(ResultSet::getString, timeStringToAvroTimeInterval)) - .put("TIMESTAMP", Pair.of(utcTimeStampExtractor, sqlTimestampToAvroTimestampMicros)) - .put("TINYBLOB", Pair.of(ResultSet::getBlob, blobToHexString)) - .put("TINYINT", Pair.of(ResultSet::getInt, valuePassThrough)) - .put("TINYTEXT", Pair.of(ResultSet::getString, valuePassThrough)) - .put("VARBINARY", Pair.of(ResultSet::getBytes, bytesToHexString)) - .put("VARCHAR", Pair.of(ResultSet::getString, valuePassThrough)) - .put("YEAR", Pair.of(ResultSet::getInt, 
valuePassThrough)) - .build() - .entrySet() - .stream() - .map( - entry -> - Map.entry( - entry.getKey(), - new JdbcValueMapper<>( - entry.getValue().getLeft(), entry.getValue().getRight()))) - .collect(ImmutableMap.toImmutableMap(Entry::getKey, Entry::getValue)); + private static final JdbcMappings JDBC_MAPPINGS = + JdbcMappings.builder() + .put("BIGINT", ResultSet::getLong, valuePassThrough, 8) + .put("BIGINT UNSIGNED", ResultSet::getBigDecimal, bigDecimalToAvroNumber, 8) + .put( + "BINARY", + ResultSet::getBytes, + bytesToHexString, + sourceColumnType -> { + long n = getLengthOrPrecision(sourceColumnType); + // in BINARY length is measured in bytes. ref: + // https://dev.mysql.com/doc/refman/8.4/en/binary-varbinary.html + return (int) (n > 0 ? n : 255); + }) + .put( + "BIT", + ResultSet::getLong, + valuePassThrough, + sourceColumnType -> { + long n = getLengthOrPrecision(sourceColumnType); + // BIT(N) -> (N+7)/8 since it is stored in bytes + return (int) ((n > 0 ? n : 1) + 7) / 8; + }) + .put("BLOB", ResultSet::getBlob, blobToHexString, 65_535) // BLOB -> 65,535 bytes + .put("BOOL", ResultSet::getInt, valuePassThrough, 1) + .put( + "CHAR", + ResultSet::getString, + valuePassThrough, + sourceColumnType -> { + long n = getLengthOrPrecision(sourceColumnType); + // CHAR -> N * 4 since it takes 4 bytes per char in utf8mb4 format. Max length + // is 255. + return (int) ((n > 0 ? n : 255) * 4); + }) + /* + * Time related type sizes are inferred from the way the JDBC driver decodes the + * binary data ref: + * https://github.com/mysql/mysql-connector-j/blob/release/8.0/src/main/protocol-impl/java/com/mysql/cj/protocol/a/MysqlBinaryValueDecoder.java + */ + .put("DATE", utcDateExtractor, sqlDateToAvroTimestampMicros, 4) + .put("DATETIME", utcTimeStampExtractor, sqlTimestampToAvroDateTime, 11) + .put( + "DECIMAL", + ResultSet::getBigDecimal, + bigDecimalToByteArray, + sourceColumnType -> { + long m = getLengthOrPrecision(sourceColumnType); + // DECIMAL(M,D) -> M + 2 bytes since it is internally stored as a byte encoded + // string (+2 for sign and decimal point) + // Max number of digits in decimal is 65. Ref: + // https://dev.mysql.com/doc/refman/8.4/en/fixed-point-types.html + return (int) ((m > 0 ? m : 65) + 2); + }) + .put("DOUBLE", ResultSet::getDouble, valuePassThrough, 8) + .put( + "ENUM", + ResultSet::getString, + valuePassThrough, + 1020) // The maximum supported length of an individual ENUM element is M <= 255 + // and (M x w) <= 1020, where M is the element literal length and w is the + // number of bytes required for the maximum-length character in the character + // set. 
https://dev.mysql.com/doc/refman/8.0/en/string-type-syntax.html + .put("FLOAT", ResultSet::getFloat, valuePassThrough, 4) + .put("INTEGER", ResultSet::getInt, valuePassThrough, 4) + .put("INTEGER UNSIGNED", ResultSet::getLong, valuePassThrough, 4) + .put( + "JSON", ResultSet::getString, valuePassThrough, Integer.MAX_VALUE) // JSON -> Long Max + .put( + "LONGBLOB", + ResultSet::getBlob, + blobToHexString, + Integer.MAX_VALUE) // LONGBLOB -> Long Max + .put( + "LONGTEXT", + ResultSet::getString, + valuePassThrough, + Integer.MAX_VALUE) // LONGTEXT -> Long Max + .put("MEDIUMBLOB", ResultSet::getBlob, blobToHexString, 16_777_215) // MEDIUMBLOB -> 16MB + .put("MEDIUMINT", ResultSet::getInt, valuePassThrough, 4) + .put( + "MEDIUMTEXT", + ResultSet::getString, + valuePassThrough, + 16_777_215) // MEDIUMTEXT -> 16MB + .put( + "SET", + ResultSet::getString, + valuePassThrough, + 1020 * 64) // Number of elements in a SET can be up to 64. The maximum supported + // length of an individual SET element is M <= 255 and (M x w) <= 1020, where M + // is the element literal length and w is the number of bytes required for the + // maximum-length character in the character set. + // https://dev.mysql.com/doc/refman/8.0/en/string-type-syntax.html + .put("SMALLINT", ResultSet::getInt, valuePassThrough, 2) + .put("TEXT", ResultSet::getString, valuePassThrough, 65_535) + .put("TIME", ResultSet::getString, timeStringToAvroTimeInterval, 12) + .put("TIMESTAMP", utcTimeStampExtractor, sqlTimestampToAvroTimestampMicros, 11) + .put("TINYBLOB", ResultSet::getBlob, blobToHexString, 255) + .put("TINYINT", ResultSet::getInt, valuePassThrough, 1) + .put("TINYTEXT", ResultSet::getString, valuePassThrough, 255) + .put( + "VARBINARY", + ResultSet::getBytes, + bytesToHexString, + sourceColumnType -> { + long n = getLengthOrPrecision(sourceColumnType); + // in VARBINARY length is measured in bytes. ref: + // https://dev.mysql.com/doc/refman/8.4/en/binary-varbinary.html + return (int) (n > 0 ? n : 65535); + }) + .put( + "VARCHAR", + ResultSet::getString, + valuePassThrough, + sourceColumnType -> { + long n = getLengthOrPrecision(sourceColumnType); + // VARCHAR -> N * 4 since it takes 4 bytes per char in utf8mb4 format. Max bytes + // allowed is 65535. ref: https://dev.mysql.com/doc/refman/8.4/en/char.html + return (int) (n > 0 ? n * 4 : 65535); + }) + .put("YEAR", ResultSet::getInt, valuePassThrough, 2) + .build(); /** Get static mapping of SourceColumnType to {@link JdbcValueMapper}. */ @Override public ImmutableMap> getMappings() { - return SCHEMA_MAPPINGS; + return JDBC_MAPPINGS.mappings(); + } + + /** + * estimate the column size in bytes for a given column type. + * + *
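<p>Worked example: a {@code VARCHAR(100)} column is estimated at 100 * 4 = 400 bytes, and a
+   * {@code DECIMAL(10,2)} column at 10 + 2 = 12 bytes.
+   *
+   * 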
<p>
Ref: MySQL + * Storage Requirements + */ + @Override + public int estimateColumnSize(SourceColumnType sourceColumnType) { + String typeName = sourceColumnType.getName().toUpperCase(); + if (JDBC_MAPPINGS.sizeEstimators().containsKey(typeName)) { + return JDBC_MAPPINGS.sizeEstimators().get(typeName).apply(sourceColumnType); + } + LOG.warn("Unknown column type: {}. Defaulting to size: 65,535.", sourceColumnType); + return 65_535; } } diff --git a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/rowmapper/provider/PostgreSQLJdbcValueMappings.java b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/rowmapper/provider/PostgreSQLJdbcValueMappings.java index 2680ec227a..1237e84575 100644 --- a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/rowmapper/provider/PostgreSQLJdbcValueMappings.java +++ b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/rowmapper/provider/PostgreSQLJdbcValueMappings.java @@ -15,6 +15,7 @@ */ package com.google.cloud.teleport.v2.source.reader.io.jdbc.rowmapper.provider; +import com.google.cloud.teleport.v2.source.reader.io.jdbc.rowmapper.JdbcMappings; import com.google.cloud.teleport.v2.source.reader.io.jdbc.rowmapper.JdbcValueMapper; import com.google.cloud.teleport.v2.source.reader.io.jdbc.rowmapper.JdbcValueMappingsProvider; import com.google.cloud.teleport.v2.source.reader.io.jdbc.rowmapper.ResultSetValueExtractor; @@ -25,21 +26,22 @@ import java.sql.ResultSet; import java.time.Instant; import java.time.OffsetDateTime; -import java.time.OffsetTime; import java.time.ZoneOffset; import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatterBuilder; import java.time.temporal.ChronoField; import java.util.Calendar; -import java.util.Map; import java.util.TimeZone; import java.util.concurrent.TimeUnit; import org.apache.avro.generic.GenericRecordBuilder; -import org.apache.commons.lang3.tuple.Pair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** PostgreSQL data type mapping to AVRO types. 
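Size estimates for the fixed-width types mirror the binary wire encoding used by the PostgreSQL JDBC driver.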
*/ public class PostgreSQLJdbcValueMappings implements JdbcValueMappingsProvider { + private static final Logger LOG = LoggerFactory.getLogger(PostgreSQLJdbcValueMappings.class); + private static final Calendar UTC_CALENDAR = Calendar.getInstance(TimeZone.getTimeZone(ZoneOffset.UTC)); @@ -57,13 +59,6 @@ private static long toMicros(Instant instant) { + TimeUnit.NANOSECONDS.toMicros(instant.getNano()); } - private static long toMicros(OffsetTime offsetTime) { - return TimeUnit.HOURS.toMicros(offsetTime.getHour()) - + TimeUnit.MINUTES.toMicros(offsetTime.getMinute()) - + TimeUnit.SECONDS.toMicros(offsetTime.getSecond()) - + TimeUnit.NANOSECONDS.toMicros(offsetTime.getNano()); - } - private static final ResultSetValueMapper valuePassThrough = (value, schema) -> value; private static final ResultSetValueExtractor bytesExtractor = @@ -109,63 +104,187 @@ private static long toMicros(OffsetTime offsetTime) { TimeUnit.SECONDS.toMillis(value.getOffset().getTotalSeconds())) .build(); - private static final ImmutableMap> SCHEMA_MAPPINGS = - ImmutableMap., ResultSetValueMapper>>builder() - .put("BIGINT", Pair.of(ResultSet::getLong, valuePassThrough)) - .put("BIGSERIAL", Pair.of(ResultSet::getLong, valuePassThrough)) - .put("BIT", Pair.of(bytesExtractor, valuePassThrough)) - .put("BIT VARYING", Pair.of(bytesExtractor, valuePassThrough)) - .put("BOOL", Pair.of(ResultSet::getBoolean, valuePassThrough)) - .put("BOOLEAN", Pair.of(ResultSet::getBoolean, valuePassThrough)) - .put("BYTEA", Pair.of(bytesExtractor, valuePassThrough)) - .put("CHAR", Pair.of(ResultSet::getString, valuePassThrough)) - .put("CHARACTER", Pair.of(ResultSet::getString, valuePassThrough)) - .put("CHARACTER VARYING", Pair.of(ResultSet::getString, valuePassThrough)) - .put("CITEXT", Pair.of(ResultSet::getString, valuePassThrough)) - .put("DATE", Pair.of(dateExtractor, dateToAvro)) - .put("DECIMAL", Pair.of(ResultSet::getObject, numericToAvro)) - .put("DOUBLE PRECISION", Pair.of(ResultSet::getDouble, valuePassThrough)) - .put("FLOAT4", Pair.of(ResultSet::getFloat, valuePassThrough)) - .put("FLOAT8", Pair.of(ResultSet::getDouble, valuePassThrough)) - .put("INT", Pair.of(ResultSet::getInt, valuePassThrough)) - .put("INTEGER", Pair.of(ResultSet::getInt, valuePassThrough)) - .put("INT2", Pair.of(ResultSet::getInt, valuePassThrough)) - .put("INT4", Pair.of(ResultSet::getInt, valuePassThrough)) - .put("INT8", Pair.of(ResultSet::getLong, valuePassThrough)) - .put("JSON", Pair.of(ResultSet::getString, valuePassThrough)) - .put("JSONB", Pair.of(ResultSet::getString, valuePassThrough)) - .put("MONEY", Pair.of(ResultSet::getDouble, valuePassThrough)) - .put("NUMERIC", Pair.of(ResultSet::getObject, numericToAvro)) - .put("OID", Pair.of(ResultSet::getLong, valuePassThrough)) - .put("REAL", Pair.of(ResultSet::getFloat, valuePassThrough)) - .put("SERIAL", Pair.of(ResultSet::getInt, valuePassThrough)) - .put("SERIAL2", Pair.of(ResultSet::getInt, valuePassThrough)) - .put("SERIAL4", Pair.of(ResultSet::getInt, valuePassThrough)) - .put("SERIAL8", Pair.of(ResultSet::getLong, valuePassThrough)) - .put("SMALLINT", Pair.of(ResultSet::getInt, valuePassThrough)) - .put("SMALLSERIAL", Pair.of(ResultSet::getInt, valuePassThrough)) - .put("TEXT", Pair.of(ResultSet::getString, valuePassThrough)) - .put("TIMESTAMP", Pair.of(timestampExtractor, timestampToAvro)) - .put("TIMESTAMPTZ", Pair.of(timestamptzExtractor, timestamptzToAvro)) - .put("TIMESTAMP WITH TIME ZONE", Pair.of(timestamptzExtractor, timestamptzToAvro)) - .put("TIMESTAMP WITHOUT TIME ZONE", 
Pair.of(timestampExtractor, timestampToAvro)) - .put("UUID", Pair.of(ResultSet::getString, valuePassThrough)) - .put("VARBIT", Pair.of(bytesExtractor, valuePassThrough)) - .put("VARCHAR", Pair.of(ResultSet::getString, valuePassThrough)) - .build() - .entrySet() - .stream() - .map( - entry -> - Map.entry( - entry.getKey(), - new JdbcValueMapper<>( - entry.getValue().getLeft(), entry.getValue().getRight()))) - .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); - - /** Get static mapping of SourceColumnType to {@link JdbcValueMapper}. */ + private static final JdbcMappings JDBC_MAPPINGS = + /* + Postgres JDBC uses binary encoding for most types ref:org.postgresql.jdbc.PgConnection.getSupportedBinaryOids() + */ + JdbcMappings.builder() + .put("BIGINT", ResultSet::getLong, valuePassThrough, 8) // - + .put("BIGSERIAL", ResultSet::getLong, valuePassThrough, 8) // - + .put( + "BIT", + bytesExtractor, + valuePassThrough, + sourceColumnType -> { + long n = getLengthOrPrecision(sourceColumnType); + return (int) ((n > 0 ? n : 1)); // bit uses text protocol. + }) + .put( + "BIT VARYING", + bytesExtractor, + valuePassThrough, + sourceColumnType -> { + long n = getLengthOrPrecision(sourceColumnType); + return (int) + ((n > 0 + ? n + : 10 * 1024 + * 1024)); // bit varying without a length specification means unlimited + // length. ref: + // https://www.postgresql.org/docs/current/datatype-bit.html + }) + .put("BOOL", ResultSet::getBoolean, valuePassThrough, 1) + .put("BOOLEAN", ResultSet::getBoolean, valuePassThrough, 1) + .put( + "BYTEA", + bytesExtractor, + valuePassThrough, + sourceColumnType -> { + long n = getLengthOrPrecision(sourceColumnType); + long length = n > 0 ? n : 10 * 1024 * 1024; + return (int) Math.min(length, Integer.MAX_VALUE); + }) + .put( + "CHAR", + ResultSet::getString, + valuePassThrough, + sourceColumnType -> { + long n = getLengthOrPrecision(sourceColumnType); + // CHAR(N) -> N * 4 bytes (UTF-8 max) + overhead. + return (int) Math.min(((n > 0 ? n : 255) * 4), Integer.MAX_VALUE); + }) + .put( + "CHARACTER", + ResultSet::getString, + valuePassThrough, + sourceColumnType -> { + long n = getLengthOrPrecision(sourceColumnType); + return (int) Math.min(((n > 0 ? n : 255) * 4) + 24, Integer.MAX_VALUE); + }) + .put( + "CHARACTER VARYING", + ResultSet::getString, + valuePassThrough, + sourceColumnType -> { + long n = getLengthOrPrecision(sourceColumnType); + return (int) Math.min(((n > 0 ? n : 255) * 4) + 24, Integer.MAX_VALUE); + }) + .put( + "CITEXT", + ResultSet::getString, + valuePassThrough, + sourceColumnType -> { + long n = getLengthOrPrecision(sourceColumnType); + long length = n > 0 ? 
n : 10 * 1024 * 1024; + return (int) Math.min((length * 4) + 24, Integer.MAX_VALUE); + }) + .put("DATE", dateExtractor, dateToAvro, 4) + .put( + "DECIMAL", + ResultSet::getObject, + numericToAvro, + sourceColumnType -> { + long n = getLengthOrPrecision(sourceColumnType); + return (int) (n / 2 + 8); + }) + .put("DOUBLE PRECISION", ResultSet::getDouble, valuePassThrough, 8) + .put("FLOAT4", ResultSet::getFloat, valuePassThrough, 4) + .put("FLOAT8", ResultSet::getDouble, valuePassThrough, 8) + .put("INT", ResultSet::getInt, valuePassThrough, 4) + .put("INTEGER", ResultSet::getInt, valuePassThrough, 4) + .put("INT2", ResultSet::getInt, valuePassThrough, 2) + .put("INT4", ResultSet::getInt, valuePassThrough, 4) + .put("INT8", ResultSet::getLong, valuePassThrough, 8) + .put( + "JSON", + ResultSet::getString, + valuePassThrough, + sourceColumnType -> { + long n = getLengthOrPrecision(sourceColumnType); + long length = n > 0 ? n : 10 * 1024 * 1024; + return (int) Math.min((length * 4) + 24, Integer.MAX_VALUE); + }) + .put( + "JSONB", + ResultSet::getString, + valuePassThrough, + sourceColumnType -> { + long n = getLengthOrPrecision(sourceColumnType); + long length = n > 0 ? n : 10 * 1024 * 1024; + return (int) Math.min((length * 4) + 24, Integer.MAX_VALUE); + }) + .put("MONEY", ResultSet::getDouble, valuePassThrough, 8) + .put( + "NUMERIC", + ResultSet::getObject, + numericToAvro, + sourceColumnType -> { + long n = getLengthOrPrecision(sourceColumnType); + return (int) (n / 2 + 8); + }) + .put( + "OID", + ResultSet::getLong, + valuePassThrough, + 4) // Usually unsigned int, mapped to long for safety + .put("REAL", ResultSet::getFloat, valuePassThrough, 4) + .put("SERIAL", ResultSet::getInt, valuePassThrough, 4) + .put("SERIAL2", ResultSet::getInt, valuePassThrough, 2) + .put("SERIAL4", ResultSet::getInt, valuePassThrough, 4) + .put("SERIAL8", ResultSet::getLong, valuePassThrough, 8) + .put("SMALLINT", ResultSet::getInt, valuePassThrough, 2) + .put("SMALLSERIAL", ResultSet::getInt, valuePassThrough, 2) + .put( + "TEXT", + ResultSet::getString, + valuePassThrough, + sourceColumnType -> { + long n = getLengthOrPrecision(sourceColumnType); + long length = n > 0 ? n : 10 * 1024 * 1024; + return (int) Math.min((length * 4) + 24, Integer.MAX_VALUE); + }) + .put("TIMESTAMP", timestampExtractor, timestampToAvro, 8) + .put("TIMESTAMPTZ", timestamptzExtractor, timestamptzToAvro, 8) + .put("TIMESTAMP WITH TIME ZONE", timestamptzExtractor, timestamptzToAvro, 8) + .put("TIMESTAMP WITHOUT TIME ZONE", timestampExtractor, timestampToAvro, 8) + .put("UUID", ResultSet::getString, valuePassThrough, 16) + .put( + "VARBIT", + bytesExtractor, + valuePassThrough, + sourceColumnType -> { + long n = getLengthOrPrecision(sourceColumnType); + return (int) ((n > 0 ? n : 10 * 1024 * 1024) + 24); + }) + .put( + "VARCHAR", + ResultSet::getString, + valuePassThrough, + sourceColumnType -> { + long n = getLengthOrPrecision(sourceColumnType); + return (int) Math.min(((n > 0 ? n : 10 * 1024 * 1024) * 4) + 24, Integer.MAX_VALUE); + }) + .build(); + + @Override + public int estimateColumnSize( + com.google.cloud.teleport.v2.spanner.migrations.schema.SourceColumnType sourceColumnType) { + String typeName = sourceColumnType.getName().toUpperCase(); + if (JDBC_MAPPINGS.sizeEstimators().containsKey(typeName)) { + return JDBC_MAPPINGS.sizeEstimators().get(typeName).apply(sourceColumnType); + } + LOG.warn("Unknown column type: {}. 
Defaulting to size: 65,535.", sourceColumnType); + return 65_535; + } + + private static long getLengthOrPrecision( + com.google.cloud.teleport.v2.spanner.migrations.schema.SourceColumnType sourceColumnType) { + Long[] mods = sourceColumnType.getMods(); + return (mods != null && mods.length > 0 && mods[0] != null) ? mods[0] : 0; + } + @Override public ImmutableMap> getMappings() { - return SCHEMA_MAPPINGS; + return JDBC_MAPPINGS.mappings(); } } diff --git a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/schema/SourceTableSchema.java b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/schema/SourceTableSchema.java index 498f506f5a..9bf1b026d2 100644 --- a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/schema/SourceTableSchema.java +++ b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/source/reader/io/schema/SourceTableSchema.java @@ -24,6 +24,7 @@ import com.google.common.collect.ImmutableMap; import java.io.Serializable; import java.util.UUID; +import javax.annotation.Nullable; import org.apache.avro.LogicalTypes; import org.apache.avro.Schema; import org.apache.avro.SchemaBuilder.FieldAssembler; @@ -47,6 +48,9 @@ public abstract class SourceTableSchema implements Serializable { public abstract String tableName(); + @Nullable + public abstract Long estimatedRowSize(); + // Source Schema from metadata tables. SourceColumnType is similar to // com.google.cloud.teleport.v2.spanner.migrations.schema /* TODO(vardhanvthigle): @@ -86,6 +90,8 @@ public abstract static class Builder { public abstract Builder setTableName(String value); + public abstract Builder setEstimatedRowSize(@Nullable Long value); + @VisibleForTesting protected UnifiedTypeMapper.MapperType mapperType; abstract ImmutableMap.Builder diff --git a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/templates/DbConfigContainer.java b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/templates/DbConfigContainer.java index 2cde0f6543..4d38ad9d23 100644 --- a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/templates/DbConfigContainer.java +++ b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/templates/DbConfigContainer.java @@ -53,6 +53,7 @@ Map getSrcTableToShardIdColumnMap( * * @param sourceTables * @param waitOnSignal + * @param schemaMapper * @return */ IoWrapper getIOWrapper(List sourceTables, Wait.OnSignal waitOnSignal); diff --git a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/templates/PipelineController.java b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/templates/PipelineController.java index d9ef53744c..7b5863c89c 100644 --- a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/templates/PipelineController.java +++ b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/templates/PipelineController.java @@ -38,6 +38,7 @@ import java.util.Optional; import java.util.stream.Collectors; import org.apache.beam.repackaged.core.org.apache.commons.lang3.StringUtils; +import org.apache.beam.runners.dataflow.options.DataflowPipelineWorkerPoolOptions; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.PipelineResult; import org.apache.beam.sdk.io.gcp.spanner.SpannerConfig; @@ -348,6 +349,8 @@ public ShardedJdbcDbConfigContainer( public JdbcIOWrapperConfig getJDBCIOWrapperConfig( List sourceTables, Wait.OnSignal waitOnSignal) { + String workerZone = 
OptionsToConfigBuilder.extractWorkerZone(options); + return OptionsToConfigBuilder.getJdbcIOWrapperConfig( sqlDialect, sourceTables, @@ -366,7 +369,10 @@ public JdbcIOWrapperConfig getJDBCIOWrapperConfig( options.getNumPartitions(), waitOnSignal, options.getFetchSize(), - options.getUniformizationStageCountHint()); + options.getUniformizationStageCountHint(), + options.getProjectId(), + workerZone, + options.as(DataflowPipelineWorkerPoolOptions.class).getWorkerMachineType()); } @Override diff --git a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/templates/SourceDbToSpanner.java b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/templates/SourceDbToSpanner.java index 24272eadad..cb309aa4d8 100644 --- a/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/templates/SourceDbToSpanner.java +++ b/v2/sourcedb-to-spanner/src/main/java/com/google/cloud/teleport/v2/templates/SourceDbToSpanner.java @@ -21,7 +21,7 @@ import com.google.cloud.teleport.v2.common.UncaughtExceptionLogger; import com.google.cloud.teleport.v2.options.SourceDbToSpannerOptions; import com.google.cloud.teleport.v2.spanner.migrations.shard.Shard; -import com.google.cloud.teleport.v2.spanner.migrations.utils.DataflowWorkerMachineTypeValidator; +import com.google.cloud.teleport.v2.spanner.migrations.utils.DataflowWorkerMachineTypeUtils; import com.google.cloud.teleport.v2.spanner.migrations.utils.SecretManagerAccessorImpl; import com.google.cloud.teleport.v2.spanner.migrations.utils.ShardFileReader; import com.google.common.annotations.VisibleForTesting; @@ -104,7 +104,7 @@ static PipelineResult run(SourceDbToSpannerOptions options) { Pipeline pipeline = Pipeline.create(options); String workerMachineType = pipeline.getOptions().as(DataflowPipelineWorkerPoolOptions.class).getWorkerMachineType(); - DataflowWorkerMachineTypeValidator.validateMachineSpecs(workerMachineType, 4); + DataflowWorkerMachineTypeUtils.validateMachineSpecs(workerMachineType, 4); SpannerConfig spannerConfig = createSpannerConfig(options); diff --git a/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/options/OptionsToConfigBuilderTest.java b/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/options/OptionsToConfigBuilderTest.java index 3ca2fc170e..ce667adf8f 100644 --- a/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/options/OptionsToConfigBuilderTest.java +++ b/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/options/OptionsToConfigBuilderTest.java @@ -23,6 +23,7 @@ import java.net.URISyntaxException; import java.util.ArrayList; import java.util.List; +import org.apache.beam.runners.dataflow.options.DataflowPipelineWorkerPoolOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; @@ -31,6 +32,7 @@ import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; /** Test class for {@link OptionsToConfigBuilder}. 
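Covers JDBC URL construction, MySQL cursor-mode toggling, worker-zone extraction, and fetch-size normalization.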
*/ @@ -61,7 +63,9 @@ public void testConfigWithMySqlDefaultsFromOptions() { sourceDbToSpannerOptions, List.of("table1", "table2"), null, Wait.on(dummyPCollection)); assertThat(config.jdbcDriverClassName()).isEqualTo(testDriverClassName); assertThat(config.sourceDbURL()) - .isEqualTo(testUrl + "?allowMultiQueries=true&autoReconnect=true&maxReconnects=10"); + .isEqualTo( + testUrl + + "?allowMultiQueries=true&autoReconnect=true&maxReconnects=10&useCursorFetch=true"); assertThat(config.tables()).containsExactlyElementsIn(new String[] {"table1", "table2"}); assertThat(config.dbAuth().getUserName().get()).isEqualTo(testUser); assertThat(config.dbAuth().getPassword().get()).isEqualTo(testPassword); @@ -101,7 +105,10 @@ public void testConfigWithMySqlUrlFromOptions() { 0, Wait.on(dummyPCollection), null, - 0L); + 0L, + null, + null, + null); JdbcIOWrapperConfig configWithoutConnectionProperties = OptionsToConfigBuilder.getJdbcIOWrapperConfig( @@ -122,14 +129,17 @@ public void testConfigWithMySqlUrlFromOptions() { 0, Wait.on(dummyPCollection), null, - 0L); + 0L, + null, + null, + null); assertThat(configWithConnectionProperties.sourceDbURL()) .isEqualTo( - "jdbc:mysql://myhost:3306/mydb?testParam=testValue&allowMultiQueries=true&autoReconnect=true&maxReconnects=10"); + "jdbc:mysql://myhost:3306/mydb?testParam=testValue&allowMultiQueries=true&autoReconnect=true&maxReconnects=10&useCursorFetch=true"); assertThat(configWithoutConnectionProperties.sourceDbURL()) .isEqualTo( - "jdbc:mysql://myhost:3306/mydb?allowMultiQueries=true&autoReconnect=true&maxReconnects=10"); + "jdbc:mysql://myhost:3306/mydb?allowMultiQueries=true&autoReconnect=true&maxReconnects=10&useCursorFetch=true"); } @Test @@ -188,7 +198,10 @@ public void testConfigWithPostgreSqlUrlFromOptions() { 0, Wait.on(dummyPCollection), null, - 0L); + 0L, + null, + null, + null); JdbcIOWrapperConfig configWithoutConnectionParameters = OptionsToConfigBuilder.getJdbcIOWrapperConfig( SQLDialect.POSTGRESQL, @@ -208,7 +221,10 @@ public void testConfigWithPostgreSqlUrlFromOptions() { 0, Wait.on(dummyPCollection), null, - -1L); + -1L, + null, + null, + null); assertThat(configWithoutConnectionParameters.sourceDbURL()) .isEqualTo("jdbc:postgresql://myhost:5432/mydb?currentSchema=public"); assertThat(configWithConnectionParameters.sourceDbURL()) @@ -240,7 +256,10 @@ public void testConfigWithPostgreSqlUrlWithNamespace() { 0, Wait.on(dummyPCollection), null, - 0L); + 0L, + null, + null, + null); assertThat(configWithNamespace.sourceDbURL()) .isEqualTo("jdbc:postgresql://myhost:5432/mydb?currentSchema=mynamespace"); } @@ -304,10 +323,85 @@ public void testMySqlSetCursorModeIfNeeded() { assertThat( OptionsToConfigBuilder.mysqlSetCursorModeIfNeeded( SQLDialect.MYSQL, "jdbc:mysql://localhost:3306/testDB?useSSL=true", null)) - .isEqualTo("jdbc:mysql://localhost:3306/testDB?useSSL=true"); + .isEqualTo("jdbc:mysql://localhost:3306/testDB?useSSL=true&useCursorFetch=true"); assertThat( OptionsToConfigBuilder.mysqlSetCursorModeIfNeeded( SQLDialect.POSTGRESQL, "jdbc:mysql://localhost:3306/testDB?useSSL=true", 42)) .isEqualTo("jdbc:mysql://localhost:3306/testDB?useSSL=true"); } + + @Test + public void testExtractWorkerZone() { + DataflowPipelineWorkerPoolOptions mockOptions = + Mockito.mock(DataflowPipelineWorkerPoolOptions.class); + Mockito.when(mockOptions.getWorkerZone()).thenReturn("us-central1-a"); + Mockito.when(mockOptions.as(DataflowPipelineWorkerPoolOptions.class)).thenReturn(mockOptions); + + String workerZone = 
OptionsToConfigBuilder.extractWorkerZone(mockOptions); + assertThat(workerZone).isEqualTo("us-central1-a"); + } + + @Test + public void testExtractWorkerZoneException() { + DataflowPipelineWorkerPoolOptions mockOptions = + Mockito.mock(DataflowPipelineWorkerPoolOptions.class); + Mockito.when(mockOptions.as(DataflowPipelineWorkerPoolOptions.class)) + .thenThrow(new RuntimeException("Test Exception")); + + String workerZone = OptionsToConfigBuilder.extractWorkerZone(mockOptions); + assertThat(workerZone).isNull(); + } + + @Test + public void testFetchSizeMinusOneBehavesLikeNull() { + SourceDbToSpannerOptions options = PipelineOptionsFactory.as(SourceDbToSpannerOptions.class); + options.setSourceDbDialect(SQLDialect.MYSQL.name()); + options.setSourceConfigURL("jdbc:mysql://localhost:3306/testDB"); + options.setJdbcDriverClassName("com.mysql.jdbc.Driver"); + options.setFetchSize(-1); // Should be normalized to null + + JdbcIOWrapperConfig config = + OptionsToConfigBuilder.getJdbcIOWrapperConfigWithDefaults( + options, List.of("table1"), null, null); + + assertThat(config.maxFetchSize()).isNull(); + } + + @Test + public void testMySqlCursorModeEnabledForNullFetchSize() { + String url = "jdbc:mysql://localhost:3306/testDB"; + String updatedUrl = + OptionsToConfigBuilder.mysqlSetCursorModeIfNeeded(SQLDialect.MYSQL, url, null); + assertThat(updatedUrl).isEqualTo(url + "?useCursorFetch=true"); + } + + @Test + public void testMySqlCursorModeEnabledForMinusOneFetchSize() { + // Note: the builder normalizes -1 to null before mysqlSetCursorModeIfNeeded is ever called; + // here we invoke the method directly. A fetch size of -1 is simply non-zero, so cursor mode + // is enabled. The normalization itself is covered by testFetchSizeMinusOneBehavesLikeNull.
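+ // With no existing query string, the first parameter is expected to be appended with '?', + // matching the null-fetch-size case above.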
+ String url = "jdbc:mysql://localhost:3306/testDB"; + String updatedUrl = + OptionsToConfigBuilder.mysqlSetCursorModeIfNeeded(SQLDialect.MYSQL, url, -1); + assertThat(updatedUrl).isEqualTo(url + "?useCursorFetch=true"); + } + + @Test + public void testMySqlCursorModeDisabledForZeroFetchSize() { + String url = "jdbc:mysql://localhost:3306/testDB"; + String updatedUrl = OptionsToConfigBuilder.mysqlSetCursorModeIfNeeded(SQLDialect.MYSQL, url, 0); + assertThat(updatedUrl).isEqualTo(url); // No change + } + + @Test + public void testMySqlCursorModeEnabledForPositiveFetchSize() { + String url = "jdbc:mysql://localhost:3306/testDB"; + String updatedUrl = + OptionsToConfigBuilder.mysqlSetCursorModeIfNeeded(SQLDialect.MYSQL, url, 100); + assertThat(updatedUrl).isEqualTo(url + "?useCursorFetch=true"); + } } diff --git a/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/dialectadapter/DialectAdapterTest.java b/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/dialectadapter/DialectAdapterTest.java index 8a5180f0cb..be3d65cda9 100644 --- a/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/dialectadapter/DialectAdapterTest.java +++ b/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/dialectadapter/DialectAdapterTest.java @@ -21,11 +21,17 @@ import com.google.cloud.teleport.v2.source.reader.io.cassandra.schema.CassandraSchemaReference; import com.google.cloud.teleport.v2.source.reader.io.jdbc.dialectadapter.mysql.MysqlDialectAdapter; import com.google.cloud.teleport.v2.source.reader.io.jdbc.dialectadapter.mysql.MysqlDialectAdapter.MySqlVersion; +import com.google.cloud.teleport.v2.source.reader.io.jdbc.rowmapper.JdbcValueMappingsProvider; import com.google.cloud.teleport.v2.source.reader.io.schema.SourceSchemaReference; +import com.google.cloud.teleport.v2.source.reader.io.schema.SourceTableSchema; +import com.google.cloud.teleport.v2.spanner.migrations.schema.SourceColumnType; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.ArgumentMatchers; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; @RunWith(MockitoJUnitRunner.class) @@ -63,4 +69,38 @@ public void testMismatchedSource() { SourceSchemaReference.ofCassandra(mockCassandraSchemaReference), ImmutableList.of())); } + + @Test + public void testEstimateRowSize() { + // Create a mock for the interface and call real methods for default + // implementations + DialectAdapter dialectAdapter = Mockito.mock(DialectAdapter.class); + Mockito.doCallRealMethod() + .when(dialectAdapter) + .estimateRowSize( + ArgumentMatchers.any(SourceTableSchema.class), + ArgumentMatchers.any(JdbcValueMappingsProvider.class)); + Mockito.doCallRealMethod() + .when(dialectAdapter) + .estimateRowSize( + ArgumentMatchers.anyMap(), ArgumentMatchers.any(JdbcValueMappingsProvider.class)); + + // Mock SourceTableSchema + SourceTableSchema mockSourceTableSchema = Mockito.mock(SourceTableSchema.class); + SourceColumnType col1 = new SourceColumnType("col1", new Long[] {10L}, null); + SourceColumnType col2 = new SourceColumnType("col2", new Long[] {20L}, null); + + Mockito.when(mockSourceTableSchema.sourceColumnNameToSourceColumnType()) + .thenReturn(ImmutableMap.of("col1", col1, "col2", col2)); + + // Mock JdbcValueMappingsProvider + JdbcValueMappingsProvider 
mockProvider = Mockito.mock(JdbcValueMappingsProvider.class); + Mockito.when(mockProvider.estimateColumnSize(col1)).thenReturn(100); + Mockito.when(mockProvider.estimateColumnSize(col2)).thenReturn(200); + + // Verify + long expectedSize = 300L; + org.junit.Assert.assertEquals( + expectedSize, dialectAdapter.estimateRowSize(mockSourceTableSchema, mockProvider)); + } } diff --git a/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/iowrapper/FetchSizeCalculatorTest.java b/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/iowrapper/FetchSizeCalculatorTest.java new file mode 100644 index 0000000000..c0a9f7c55e --- /dev/null +++ b/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/iowrapper/FetchSizeCalculatorTest.java @@ -0,0 +1,116 @@ +/* + * Copyright (C) 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.source.reader.io.jdbc.iowrapper; + +import static org.junit.Assert.assertEquals; + +import com.google.cloud.teleport.v2.source.reader.io.jdbc.iowrapper.config.TableConfig; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class FetchSizeCalculatorTest { + + private TableConfig tableConfig; + + @Before + public void setUp() { + tableConfig = TableConfig.builder("t1").setApproxRowCount(100L).build(); + } + + @Test + public void testGetFetchSize_NoMachineType() { + // Test when machine type (memory/cores) is not provided. + int fetchSize = FetchSizeCalculator.getFetchSize(tableConfig, 100L, null, null); + assertEquals(0, fetchSize); + } + + @Test + public void testGetFetchSize_ExplicitFetchSize() { + // Test when fetch size is explicitly configured in TableConfig. + TableConfig configWithFetchSize = TableConfig.builder("t1").setFetchSize(12345).build(); + int fetchSize = FetchSizeCalculator.getFetchSize(configWithFetchSize, 100L, null, null); + assertEquals(12345, fetchSize); + } + + @Test + public void testGetFetchSize_ZeroRowSize() { + // Test when estimated row size is 0. + int fetchSize = FetchSizeCalculator.getFetchSize(tableConfig, 0L, 16.0, 4); + assertEquals(0, fetchSize); + } + + @Test + public void testGetFetchSize_StandardCalculation() { + // Test standard calculation. + // Memory: 16 GB = 17,179,869,184 bytes + // Cores: 4 + // Row Size: 1000 bytes + // Denominator = 4 * 4 * 1000 = 16,000 + // Expected Fetch Size = 17,179,869,184 / 16,000 = 1,073,741 + + int fetchSize = FetchSizeCalculator.getFetchSize(tableConfig, 1000L, 16.0, 4); + assertEquals(1073741, fetchSize); + } + + @Test + public void testGetFetchSize_SmallRowSize() { + // Test with small row size (should result in large fetch size, capped at + // MAX_INTEGER if implemented, + // but here check calculation logic). 
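+ // The expected values below follow the rule fetchSize = memoryBytes / (4 * cores * rowSize), + // inferred from the worked examples in these comments.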
+ // Memory: 16 GB + // Cores: 4 + // Row Size: 10 bytes + // Denominator = 4 * 4 * 10 = 160 + // Expected Fetch Size = 17,179,869,184 / 160 = 107,374,182 + // Integer.MAX_VALUE = 2,147,483,647. Result is within integer range. + + int fetchSize = FetchSizeCalculator.getFetchSize(tableConfig, 10L, 16.0, 4); + assertEquals(107374182, fetchSize); + } + + @Test + public void testGetFetchSize_LargeRowSize() { + // Test with large row size. + // Memory: 16 GB + // Cores: 4 + // Row Size: 10 MB (10,485,760 bytes) + // Denominator = 4 * 4 * 10,485,760 = 167,772,160 + // Expected Fetch Size = 17,179,869,184 / 167,772,160 = 102 + + int fetchSize = FetchSizeCalculator.getFetchSize(tableConfig, 10485760L, 16.0, 4); + assertEquals(102, fetchSize); + } + + @Test + public void testGetFetchSize_ZeroCores() { + // Test when cores are 0 (should typically be handled by Utils returning null or + // >=1, but testing calculator logic). + // In this case, providing 0 cores. + int fetchSize = FetchSizeCalculator.getFetchSize(tableConfig, 100L, 16.0, 0); + assertEquals(0, fetchSize); + } + + @Test + public void testGetFetchSize_nullInputs() { + // Null memory + assertEquals(0, (int) FetchSizeCalculator.getFetchSize(tableConfig, 100L, null, 4)); + // Null cores + assertEquals(0, (int) FetchSizeCalculator.getFetchSize(tableConfig, 100L, 16.0, null)); + } +} diff --git a/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/rowmapper/provider/MysqlJdbcValueMappingsTest.java b/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/rowmapper/provider/MysqlJdbcValueMappingsTest.java new file mode 100644 index 0000000000..8728c0768a --- /dev/null +++ b/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/rowmapper/provider/MysqlJdbcValueMappingsTest.java @@ -0,0 +1,47 @@ +/* + * Copyright (C) 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.teleport.v2.source.reader.io.jdbc.rowmapper.provider; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.teleport.v2.spanner.migrations.schema.SourceColumnType; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class MysqlJdbcValueMappingsTest { + + @Test + public void testAllMappedTypesHaveRowSizeEstimate() { + MysqlJdbcValueMappings mappings = new MysqlJdbcValueMappings(); + for (String typeName : mappings.getMappings().keySet()) { + SourceColumnType sourceColumnType = new SourceColumnType(typeName, new Long[] {10L}, null); + int size = mappings.estimateColumnSize(sourceColumnType); + assertTrue("Row size estimate for type " + typeName + " should be > 0", size > 0); + } + } + + @Test + public void testUnknownTypeReturnsDefaultSize() { + MysqlJdbcValueMappings mappings = new MysqlJdbcValueMappings(); + SourceColumnType sourceColumnType = + new SourceColumnType("UNKNOWN_TYPE", new Long[] {10L}, null); + int size = mappings.estimateColumnSize(sourceColumnType); + assertEquals(65_535, size); + } +} diff --git a/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/rowmapper/provider/PostgreSQLJdbcValueMappingsTest.java b/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/rowmapper/provider/PostgreSQLJdbcValueMappingsTest.java new file mode 100644 index 0000000000..666519a054 --- /dev/null +++ b/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/source/reader/io/jdbc/rowmapper/provider/PostgreSQLJdbcValueMappingsTest.java @@ -0,0 +1,46 @@ +/* + * Copyright (C) 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.teleport.v2.source.reader.io.jdbc.rowmapper.provider; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.teleport.v2.spanner.migrations.schema.SourceColumnType; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class PostgreSQLJdbcValueMappingsTest { + + @Test + public void testAllMappedTypesHaveRowSizeEstimate() { + PostgreSQLJdbcValueMappings mappings = new PostgreSQLJdbcValueMappings(); + for (String typeName : mappings.getMappings().keySet()) { + SourceColumnType sourceColumnType = new SourceColumnType(typeName, new Long[] {10L}, null); + int size = mappings.estimateColumnSize(sourceColumnType); + assertTrue("Row size estimate for type " + typeName + " should be > 0", size > 0); + } + } + + @Test + public void testUnknownTypeReturnsDefaultSize() { + PostgreSQLJdbcValueMappings mappings = new PostgreSQLJdbcValueMappings(); + SourceColumnType sourceColumnType = + new SourceColumnType("UNKNOWN_TYPE", new Long[] {10L}, null); + int size = mappings.estimateColumnSize(sourceColumnType); + assertEquals(65_535, size); + } +} diff --git a/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/source/reader/io/schema/SchemaTestUtils.java b/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/source/reader/io/schema/SchemaTestUtils.java index ce2de13a96..cb3f943ac7 100644 --- a/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/source/reader/io/schema/SchemaTestUtils.java +++ b/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/source/reader/io/schema/SchemaTestUtils.java @@ -37,6 +37,7 @@ public static SourceTableSchema generateTestTableSchema(String tableName) { TEST_FIELD_NAME_1, new SourceColumnType("varchar", new Long[] {20L}, null)) .addSourceColumnNameToSourceColumnType( TEST_FIELD_NAME_2, new SourceColumnType("varchar", new Long[] {20L}, null)) + .setEstimatedRowSize(0L) .build(); } } diff --git a/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/source/reader/io/schema/SourceTableSchemaTest.java b/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/source/reader/io/schema/SourceTableSchemaTest.java index 3e9ce6ec9f..45d8e70bba 100644 --- a/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/source/reader/io/schema/SourceTableSchemaTest.java +++ b/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/source/reader/io/schema/SourceTableSchemaTest.java @@ -64,7 +64,11 @@ public void testTableSchemaPreConditions() { // Miss Adding any fields to schema.
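+ // Even with an estimated row size set, build() must still fail because no column + // mappings were added.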
Assert.assertThrows( java.lang.IllegalStateException.class, - () -> SourceTableSchema.builder(SQLDialect.MYSQL).setTableName(tableName).build()); + () -> + SourceTableSchema.builder(SQLDialect.MYSQL) + .setTableName(tableName) + .setEstimatedRowSize(0L) + .build()); } @Test diff --git a/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/templates/PipelineControllerTest.java b/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/templates/PipelineControllerTest.java index 8d38f6a4a2..917094f973 100644 --- a/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/templates/PipelineControllerTest.java +++ b/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/templates/PipelineControllerTest.java @@ -362,7 +362,9 @@ public void singleDbConfigContainerWithUrlTest() { List.of("table1", "table2"), Wait.on(dummyPCollection)); assertThat(config.jdbcDriverClassName()).isEqualTo(testDriverClassName); assertThat(config.sourceDbURL()) - .isEqualTo(testUrl + "?allowMultiQueries=true&autoReconnect=true&maxReconnects=10"); + .isEqualTo( + testUrl + + "?allowMultiQueries=true&autoReconnect=true&maxReconnects=10&useCursorFetch=true"); assertThat(config.tables()).containsExactlyElementsIn(new String[] {"table1", "table2"}); assertThat(config.dbAuth().getUserName().get()).isEqualTo(testUser); assertThat(config.dbAuth().getPassword().get()).isEqualTo(testPassword); @@ -410,7 +412,9 @@ public void shardedDbConfigContainerMySqlTest() { assertThat(config.jdbcDriverClassName()).isEqualTo(testDriverClassName); assertThat(config.sourceDbURL()) - .isEqualTo(testUrl + "?allowMultiQueries=true&autoReconnect=true&maxReconnects=10"); + .isEqualTo( + testUrl + + "?allowMultiQueries=true&autoReconnect=true&maxReconnects=10&useCursorFetch=true"); assertThat(config.tables()).containsExactlyElementsIn(new String[] {"table1", "table2"}); assertThat(config.dbAuth().getUserName().get()).isEqualTo(testUser); assertThat(config.dbAuth().getPassword().get()).isEqualTo(testPassword); diff --git a/v2/sourcedb-to-spanner/src/test/resources/DataTypesIT/mysql-data-types.sql b/v2/sourcedb-to-spanner/src/test/resources/DataTypesIT/mysql-data-types.sql index 569ea9e838..10fb30dd71 100644 --- a/v2/sourcedb-to-spanner/src/test/resources/DataTypesIT/mysql-data-types.sql +++ b/v2/sourcedb-to-spanner/src/test/resources/DataTypesIT/mysql-data-types.sql @@ -851,7 +851,7 @@ CREATE TABLE IF NOT EXISTS spatial_linestring ( ); INSERT INTO spatial_linestring (path) -VALUES (LineString(Point(77.5946, 12.9716), Point(77.6100, 12.9600))); +VALUES (LineString(Point(77.5946, 12.9716), Point(77.6100, 12.9600))); CREATE TABLE IF NOT EXISTS spatial_polygon ( id INT AUTO_INCREMENT PRIMARY KEY, diff --git a/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/utils/DataflowWorkerMachineTypeUtils.java b/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/utils/DataflowWorkerMachineTypeUtils.java new file mode 100644 index 0000000000..239d163473 --- /dev/null +++ b/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/utils/DataflowWorkerMachineTypeUtils.java @@ -0,0 +1,178 @@ +/* + * Copyright (C) 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.spanner.migrations.utils; + +import com.google.cloud.compute.v1.MachineType; +import com.google.cloud.compute.v1.MachineTypesClient; +import com.google.common.base.Preconditions; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class DataflowWorkerMachineTypeUtils { + + private static final Logger LOG = LoggerFactory.getLogger(DataflowWorkerMachineTypeUtils.class); + private static final Map<String, MachineSpec> MACHINE_SPEC_CACHE = new ConcurrentHashMap<>(); + private static final String DEFAULT_ZONE = "us-central1-a"; + + public static Double getWorkerMemoryGB(String projectId, String zone, String workerMachineType) { + MachineSpec spec = getMachineSpec(projectId, zone, workerMachineType); + return spec != null ? spec.memoryGB : null; + } + + public static Integer getWorkerCores(String projectId, String zone, String workerMachineType) { + MachineSpec spec = getMachineSpec(projectId, zone, workerMachineType); + return spec != null ? spec.vCPUs : null; + } + + private static MachineSpec getMachineSpec( + String projectId, String zone, String workerMachineType) { + Preconditions.checkArgument( + workerMachineType != null && !StringUtils.isBlank(workerMachineType), + "workerMachineType cannot be null or empty."); + + // Check cache first + if (MACHINE_SPEC_CACHE.containsKey(workerMachineType)) { + return MACHINE_SPEC_CACHE.get(workerMachineType); + } + + // Fetch from Compute Engine API + MachineSpec apiSpec = fetchMachineSpecFromApi(projectId, zone, workerMachineType); + if (apiSpec != null) { + MACHINE_SPEC_CACHE.put(workerMachineType, apiSpec); + return apiSpec; + } + + return null; + } + + private static MachineSpec fetchMachineSpecFromApi( + String projectId, String zone, String workerMachineType) { + if (zone == null) { + LOG.warn("Could not determine Zone. Defaulting to {}.", DEFAULT_ZONE); + zone = DEFAULT_ZONE; + } + + if (projectId == null) { + LOG.warn("Could not determine Project ID.
Cannot fetch machine type details from API."); + return null; + } + + try (MachineTypesClient client = MachineTypesClient.create()) { + // machineTypes.get() returns the resource or throws NotFoundException + // API documentation confirms custom types like custom-CPUS-MEM are supported + MachineType machineType = client.get(projectId, zone, workerMachineType); + + // machineType.getMemoryMb() is int, returns memory in MB + // machineType.getGuestCpus() is int + double memoryGB = machineType.getMemoryMb() / 1024.0; + int vCPUs = machineType.getGuestCpus(); + + LOG.info( + "Fetched machine type {} from API: {} vCPUs, {} GB RAM", + workerMachineType, + vCPUs, + memoryGB); + + return new MachineSpec(memoryGB, vCPUs); + } catch (Exception e) { + LOG.warn( + "Failed to fetch machine type '{}' from Compute Engine API (Project: {}, Zone: {}): {}", + workerMachineType, + projectId, + zone, + e.getMessage()); + return null; + } + } + + private static class MachineSpec { + final double memoryGB; + final int vCPUs; + + MachineSpec(double memoryGB, int vCPUs) { + this.memoryGB = memoryGB; + this.vCPUs = vCPUs; + } + } + + public static void validateMachineSpecs(String workerMachineType, Integer minCPUs) { + Preconditions.checkArgument( + workerMachineType != null && !StringUtils.isBlank(workerMachineType), + "Policy Violation: You must specify a workerMachineType with at least %s vCPUs.", + minCPUs); + + // Handle custom machine types first, format is custom-{vCPU}-{RAM} + if (workerMachineType.startsWith("custom-")) { + String[] parts = workerMachineType.split("-"); + Preconditions.checkArgument( + parts.length == 3, + "Invalid custom machine type format: '%s'. Expected format: custom-{vCPU}-{RAM}.", + workerMachineType); + Integer vCpus = null; + try { + vCpus = Integer.parseInt(parts[1]); + } catch (NumberFormatException e) { + Preconditions.checkArgument( + false, "Invalid vCPU number in custom machine type: '%s'", workerMachineType); + } + Preconditions.checkArgument( + vCpus >= minCPUs, + "Policy Violation: Custom machine type '%s' has %s vCPUs. Minimum allowed is %s. Please use a higher machine type.", + workerMachineType, + vCpus, + minCPUs); + } else { + // Handle standard machine types. + java.util.regex.Pattern pattern = java.util.regex.Pattern.compile(".*-(\\d+)$"); + java.util.regex.Matcher matcher = pattern.matcher(workerMachineType); + + if (matcher.find()) { + Integer vCpus = null; + try { + vCpus = Integer.parseInt(matcher.group(1)); + } catch (NumberFormatException e) { + Preconditions.checkArgument( + false, "Invalid vCPU number in machine type: '%s'", workerMachineType); + } + Preconditions.checkArgument( + vCpus >= minCPUs, + "Policy Violation: Machine type '%s' has %s vCPUs. Minimum allowed is %s.", + workerMachineType, + vCpus, + minCPUs); + } else { + Preconditions.checkArgument( + false, + "Unknown machine type format: '%s'.
Please use a standard machine type (e.g., n1-standard-4) or a custom machine type (e.g., custom-4-4096) with at least %s vCPUs.", + workerMachineType, + minCPUs); + } + } + } + + @com.google.common.annotations.VisibleForTesting + static void putMachineSpecForTesting(String machineType, double memoryGB, int vCPUs) { + MACHINE_SPEC_CACHE.put(machineType, new MachineSpec(memoryGB, vCPUs)); + } + + @com.google.common.annotations.VisibleForTesting + static void resetCacheForTesting() { + MACHINE_SPEC_CACHE.clear(); + } +} diff --git a/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/utils/DataflowWorkerMachineTypeValidator.java b/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/utils/DataflowWorkerMachineTypeValidator.java deleted file mode 100644 index 791a7e528f..0000000000 --- a/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/utils/DataflowWorkerMachineTypeValidator.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright (C) 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not - * use this file except in compliance with the License. You may obtain a copy of - * the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations under - * the License. - */ -package com.google.cloud.teleport.v2.spanner.migrations.utils; - -import com.google.common.base.Preconditions; -import org.apache.commons.lang3.StringUtils; - -public class DataflowWorkerMachineTypeValidator { - - public static void validateMachineSpecs(String workerMachineType, Integer minCPUs) { - Preconditions.checkArgument( - workerMachineType != null && !StringUtils.isBlank(workerMachineType), - "Policy Violation: You must specify a workerMachineType with at least %s vCPUs.", - minCPUs); - - // Handle custom machine types first, format is custom-{vCPU}-{RAM} - if (workerMachineType.startsWith("custom-")) { - String[] parts = workerMachineType.split("-"); - Preconditions.checkArgument( - parts.length == 3, - "Invalid custom machine type format: '%s'. Expected format: custom-{vCPU}-{RAM}.", - workerMachineType); - Integer vCpus = null; - try { - vCpus = Integer.parseInt(parts[1]); - } catch (NumberFormatException e) { - Preconditions.checkArgument( - false, "Invalid vCPU number in custom machine type: '%s'", workerMachineType); - } - Preconditions.checkArgument( - vCpus >= minCPUs, - "Policy Violation: Custom machine type '%s' has %s vCPUs. Minimum allowed is %s. Please use a higher machine type.", - workerMachineType, - vCpus, - minCPUs); - } else { - // Handle standard machine types. - java.util.regex.Pattern pattern = java.util.regex.Pattern.compile(".*-(\\d+)$"); - java.util.regex.Matcher matcher = pattern.matcher(workerMachineType); - - if (matcher.find()) { - Integer vCpus = null; - try { - vCpus = Integer.parseInt(matcher.group(1)); - } catch (NumberFormatException e) { - Preconditions.checkArgument( - false, "Invalid vCPU number in machine type: '%s'", workerMachineType); - } - Preconditions.checkArgument( - vCpus >= minCPUs, - "Policy Violation: Machine type '%s' has %s vCPUs. 
Minimum allowed is %s.", - workerMachineType, - vCpus, - minCPUs); - } else { - Preconditions.checkArgument( - false, - "Unknown machine type format: '%s'. Please use a standard machine type (e.g., n1-standard-4) or a custom machine type (e.g., custom-4-4096) with at least %s vCPUs.", - workerMachineType, - minCPUs); - } - } - } -} diff --git a/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/migrations/utils/DataflowWorkerMachineTypeUtilsTest.java b/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/migrations/utils/DataflowWorkerMachineTypeUtilsTest.java new file mode 100644 index 0000000000..17bad99adb --- /dev/null +++ b/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/migrations/utils/DataflowWorkerMachineTypeUtilsTest.java @@ -0,0 +1,288 @@ +/* + * Copyright (C) 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.spanner.migrations.utils; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.when; + +import com.google.cloud.compute.v1.MachineType; +import com.google.cloud.compute.v1.MachineTypesClient; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.MockedStatic; + +@RunWith(JUnit4.class) +public class DataflowWorkerMachineTypeUtilsTest { + + @Before + public void setUp() { + DataflowWorkerMachineTypeUtils.resetCacheForTesting(); + // Pre-populate cache to avoid API calls during tests + + // Standard types + DataflowWorkerMachineTypeUtils.putMachineSpecForTesting("n1-standard-4", 15.00, 4); + DataflowWorkerMachineTypeUtils.putMachineSpecForTesting("n1-standard-8", 30.00, 8); + DataflowWorkerMachineTypeUtils.putMachineSpecForTesting("n1-standard-96", 360.00, 96); + DataflowWorkerMachineTypeUtils.putMachineSpecForTesting("n1-highmem-8", 52.00, 8); + + // Custom types used in tests + DataflowWorkerMachineTypeUtils.putMachineSpecForTesting("custom-2-4096", 4.0, 2); + DataflowWorkerMachineTypeUtils.putMachineSpecForTesting("n2-custom-4-8192", 8.0, 4); + DataflowWorkerMachineTypeUtils.putMachineSpecForTesting("n2d-custom-2-2048", 2.0, 2); + DataflowWorkerMachineTypeUtils.putMachineSpecForTesting("e2-custom-2-4096", 4.0, 2); + DataflowWorkerMachineTypeUtils.putMachineSpecForTesting("n4-custom-32-131072", 128.0, 32); + DataflowWorkerMachineTypeUtils.putMachineSpecForTesting("n2-custom-4-8192-ext", 8.0, 4); + } + + @Test + public void testValidMachineType() { + DataflowWorkerMachineTypeUtils.validateMachineSpecs("n1-standard-4", 4); + } + + @Test + public void testValidMachineTypeHighCpu() { + DataflowWorkerMachineTypeUtils.validateMachineSpecs("n1-standard-8", 4); + } + + @Test + public void testInvalidMachineTypeLowCpu() { + 
assertThrows( + IllegalArgumentException.class, + () -> { + DataflowWorkerMachineTypeUtils.validateMachineSpecs("n1-standard-2", 4); + }); + } + + @Test + public void testNullMachineType() { + assertThrows( + IllegalArgumentException.class, + () -> { + DataflowWorkerMachineTypeUtils.validateMachineSpecs(null, 4); + }); + } + + @Test + public void testEmptyMachineType() { + assertThrows( + IllegalArgumentException.class, + () -> { + DataflowWorkerMachineTypeUtils.validateMachineSpecs(" ", 4); + }); + } + + @Test + public void testValidCustomMachineType() { + DataflowWorkerMachineTypeUtils.validateMachineSpecs("custom-8-12345", 4); + } + + @Test + public void testValidCustomMachineTypeMinCpu() { + DataflowWorkerMachineTypeUtils.validateMachineSpecs("custom-4-12345", 4); + } + + @Test + public void testInvalidCustomMachineTypeLowCpu() { + assertThrows( + IllegalArgumentException.class, + () -> { + DataflowWorkerMachineTypeUtils.validateMachineSpecs("custom-2-12345", 4); + }); + } + + @Test + public void testInvalidCustomMachineTypeFormat() { + assertThrows( + IllegalArgumentException.class, + () -> { + DataflowWorkerMachineTypeUtils.validateMachineSpecs("custom-2", 4); + }); + } + + @Test + public void testInvalidCustomMachineTypeNonNumericCpu() { + assertThrows( + IllegalArgumentException.class, + () -> { + DataflowWorkerMachineTypeUtils.validateMachineSpecs("custom-abc-12345", 4); + }); + } + + @Test + public void testUnknownMachineType() { + assertThrows( + IllegalArgumentException.class, + () -> { + DataflowWorkerMachineTypeUtils.validateMachineSpecs("unknown-machine-type", 4); + }); + } + + @Test + public void testGetWorkerMemoryGBStandard() { + assertEquals( + 15.00, + DataflowWorkerMachineTypeUtils.getWorkerMemoryGB(null, null, "n1-standard-4"), + 0.001); + assertEquals( + 52.00, DataflowWorkerMachineTypeUtils.getWorkerMemoryGB(null, null, "n1-highmem-8"), 0.001); + } + + @Test + public void testGetWorkerMemoryGBInvalid() { + // Not in the pre-populated cache, so this falls through to the API fetch, which + // returns null because no project ID is available. + assertNull(DataflowWorkerMachineTypeUtils.getWorkerMemoryGB(null, null, "unknown-machine")); + + // Uncached custom-type strings take the same path: getMachineSpec does no local + // parsing, so the API fetch is attempted and likewise returns null here. + assertNull(DataflowWorkerMachineTypeUtils.getWorkerMemoryGB(null, null, "custom-2-invalid")); + + assertThrows( + IllegalArgumentException.class, + () -> { + DataflowWorkerMachineTypeUtils.getWorkerMemoryGB(null, null, null); + }); + } + + @Test + public void testGetWorkerCoresStandard() { + assertEquals( + 4, (int) DataflowWorkerMachineTypeUtils.getWorkerCores(null, null, "n1-standard-4")); + assertEquals( + 8, (int) DataflowWorkerMachineTypeUtils.getWorkerCores(null, null, "n1-highmem-8")); + assertEquals( + 96, (int) DataflowWorkerMachineTypeUtils.getWorkerCores(null, null, "n1-standard-96")); + } + + @Test + public void testGetWorkerCoresInvalid() { + assertNull(DataflowWorkerMachineTypeUtils.getWorkerCores(null, null, "unknown-machine")); + assertNull( + DataflowWorkerMachineTypeUtils.getWorkerCores( + null, null, "custom-2-invalid")); // not cached; API fetch returns null + assertNull( + DataflowWorkerMachineTypeUtils.getWorkerCores( + null, null, "invalid-custom-2-1024")); // not cached; API fetch returns null + + assertThrows( + IllegalArgumentException.class, + () -> { + DataflowWorkerMachineTypeUtils.getWorkerCores(null, null, null); + }); + } + + @Test + public void
testGetWorkerMemoryGBCustom() { + // custom-2-4096 => 4096MB = 4GB + assertEquals( + 4.0, DataflowWorkerMachineTypeUtils.getWorkerMemoryGB(null, null, "custom-2-4096"), 0.001); + // n2-custom-4-8192 => 8192MB = 8GB + assertEquals( + 8.0, + DataflowWorkerMachineTypeUtils.getWorkerMemoryGB(null, null, "n2-custom-4-8192"), + 0.001); + // n2d-custom-2-2048 => 2048MB = 2GB + assertEquals( + 2.0, + DataflowWorkerMachineTypeUtils.getWorkerMemoryGB(null, null, "n2d-custom-2-2048"), + 0.001); + // e2-custom-2-4096 => 4096MB = 4GB + assertEquals( + 4.0, + DataflowWorkerMachineTypeUtils.getWorkerMemoryGB(null, null, "e2-custom-2-4096"), + 0.001); + // n4-custom-32-131072 => 128GB + assertEquals( + 128.0, + DataflowWorkerMachineTypeUtils.getWorkerMemoryGB(null, null, "n4-custom-32-131072"), + 0.001); + // extended memory + assertEquals( + 8.0, + DataflowWorkerMachineTypeUtils.getWorkerMemoryGB(null, null, "n2-custom-4-8192-ext"), + 0.001); + } + + @Test + public void testGetWorkerCoresCustom() { + assertEquals( + 2, (int) DataflowWorkerMachineTypeUtils.getWorkerCores(null, null, "custom-2-4096")); + assertEquals( + 4, (int) DataflowWorkerMachineTypeUtils.getWorkerCores(null, null, "n2-custom-4-8192")); + assertEquals( + 2, (int) DataflowWorkerMachineTypeUtils.getWorkerCores(null, null, "n2d-custom-2-2048")); + assertEquals( + 32, (int) DataflowWorkerMachineTypeUtils.getWorkerCores(null, null, "n4-custom-32-131072")); + } + + @Test + public void testFetchMachineSpecFromApi_Success() { + String projectId = "test-project"; + String zone = "us-central1-a"; + String machineType = "n1-standard-new"; // Not in pre-populated cache + + try (MockedStatic<MachineTypesClient> mockedStaticClient = + mockStatic(MachineTypesClient.class)) { + MachineTypesClient mockClient = mock(MachineTypesClient.class); + mockedStaticClient.when(MachineTypesClient::create).thenReturn(mockClient); + + MachineType mockMachineType = mock(MachineType.class); + when(mockMachineType.getGuestCpus()).thenReturn(4); + when(mockMachineType.getMemoryMb()).thenReturn(15360); // 15 GB + + when(mockClient.get(anyString(), anyString(), anyString())).thenReturn(mockMachineType); + + Double memoryGB = + DataflowWorkerMachineTypeUtils.getWorkerMemoryGB(projectId, zone, machineType); + Integer vCpus = DataflowWorkerMachineTypeUtils.getWorkerCores(projectId, zone, machineType); + + assertEquals(15.0, memoryGB, 0.001); + assertEquals(4, (int) vCpus); + } + } + + @Test + public void testFetchMachineSpecFromApi_Failure() { + String projectId = "test-project"; + String zone = "us-central1-a"; + String machineType = "unknown-machine"; + + try (MockedStatic<MachineTypesClient> mockedStaticClient = + mockStatic(MachineTypesClient.class)) { + MachineTypesClient mockClient = mock(MachineTypesClient.class); + mockedStaticClient.when(MachineTypesClient::create).thenReturn(mockClient); + + when(mockClient.get(anyString(), anyString(), anyString())) + .thenThrow(new RuntimeException("API Error")); + + Double memoryGB = + DataflowWorkerMachineTypeUtils.getWorkerMemoryGB(projectId, zone, machineType); + Integer vCpus = DataflowWorkerMachineTypeUtils.getWorkerCores(projectId, zone, machineType); + + assertNull(memoryGB); + assertNull(vCpus); + } + } +} diff --git a/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/migrations/utils/DataflowWorkerMachineTypeValidatorTest.java b/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/migrations/utils/DataflowWorkerMachineTypeValidatorTest.java deleted file mode 100644 index 25ed9bee7b..0000000000 ---
a/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/migrations/utils/DataflowWorkerMachineTypeValidatorTest.java +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Copyright (C) 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not - * use this file except in compliance with the License. You may obtain a copy of - * the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations under - * the License. - */ -package com.google.cloud.teleport.v2.spanner.migrations.utils; - -import static org.junit.Assert.assertThrows; - -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -@RunWith(JUnit4.class) -public class DataflowWorkerMachineTypeValidatorTest { - - @Test - public void testValidMachineType() { - DataflowWorkerMachineTypeValidator.validateMachineSpecs("n1-standard-4", 4); - } - - @Test - public void testValidMachineTypeHighCpu() { - DataflowWorkerMachineTypeValidator.validateMachineSpecs("n1-standard-8", 4); - } - - @Test - public void testInvalidMachineTypeLowCpu() { - assertThrows( - IllegalArgumentException.class, - () -> { - DataflowWorkerMachineTypeValidator.validateMachineSpecs("n1-standard-2", 4); - }); - } - - @Test - public void testNullMachineType() { - assertThrows( - IllegalArgumentException.class, - () -> { - DataflowWorkerMachineTypeValidator.validateMachineSpecs(null, 4); - }); - } - - @Test - public void testEmptyMachineType() { - assertThrows( - IllegalArgumentException.class, - () -> { - DataflowWorkerMachineTypeValidator.validateMachineSpecs(" ", 4); - }); - } - - @Test - public void testValidCustomMachineType() { - DataflowWorkerMachineTypeValidator.validateMachineSpecs("custom-8-12345", 4); - } - - @Test - public void testValidCustomMachineTypeMinCpu() { - DataflowWorkerMachineTypeValidator.validateMachineSpecs("custom-4-12345", 4); - } - - @Test - public void testInvalidCustomMachineTypeLowCpu() { - assertThrows( - IllegalArgumentException.class, - () -> { - DataflowWorkerMachineTypeValidator.validateMachineSpecs("custom-2-12345", 4); - }); - } - - @Test - public void testInvalidCustomMachineTypeFormat() { - assertThrows( - IllegalArgumentException.class, - () -> { - DataflowWorkerMachineTypeValidator.validateMachineSpecs("custom-2", 4); - }); - } - - @Test - public void testInvalidCustomMachineTypeNonNumericCpu() { - assertThrows( - IllegalArgumentException.class, - () -> { - DataflowWorkerMachineTypeValidator.validateMachineSpecs("custom-abc-12345", 4); - }); - } - - @Test - public void testUnknownMachineType() { - assertThrows( - IllegalArgumentException.class, - () -> { - DataflowWorkerMachineTypeValidator.validateMachineSpecs("unknown-machine-type", 4); - }); - } -} diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/SpannerToSourceDb.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/SpannerToSourceDb.java index 5c5f12b45a..4313f6f384 100644 --- a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/SpannerToSourceDb.java +++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/SpannerToSourceDb.java @@ -37,7 +37,7 @@ import 
com.google.cloud.teleport.v2.spanner.migrations.transformation.CustomTransformation; import com.google.cloud.teleport.v2.spanner.migrations.utils.CassandraConfigFileReader; import com.google.cloud.teleport.v2.spanner.migrations.utils.CassandraDriverConfigLoader; -import com.google.cloud.teleport.v2.spanner.migrations.utils.DataflowWorkerMachineTypeValidator; +import com.google.cloud.teleport.v2.spanner.migrations.utils.DataflowWorkerMachineTypeUtils; import com.google.cloud.teleport.v2.spanner.migrations.utils.SecretManagerAccessorImpl; import com.google.cloud.teleport.v2.spanner.migrations.utils.ShardFileReader; import com.google.cloud.teleport.v2.spanner.sourceddl.CassandraInformationSchemaScanner; @@ -589,7 +589,7 @@ public static PipelineResult run(Options options) { String workerMachineType = pipeline.getOptions().as(DataflowPipelineWorkerPoolOptions.class).getWorkerMachineType(); - DataflowWorkerMachineTypeValidator.validateMachineSpecs(workerMachineType, 4); + DataflowWorkerMachineTypeUtils.validateMachineSpecs(workerMachineType, 4); // Prepare Spanner config SpannerConfig spannerConfig =