diff --git a/src/main/java/io/cdap/plugin/snowflake/actions/loadunload/LoadUnloadConfig.java b/src/main/java/io/cdap/plugin/snowflake/actions/loadunload/LoadUnloadConfig.java index 50ca000..4b2d0f7 100644 --- a/src/main/java/io/cdap/plugin/snowflake/actions/loadunload/LoadUnloadConfig.java +++ b/src/main/java/io/cdap/plugin/snowflake/actions/loadunload/LoadUnloadConfig.java @@ -136,12 +136,13 @@ public abstract class LoadUnloadConfig extends BaseSnowflakeConfig { public LoadUnloadConfig(String accountName, String database, - String schemaName, String username, String password, + String schemaName, String tableName, String username, String password, @Nullable Boolean keyPairEnabled, @Nullable String path, @Nullable String passphrase, @Nullable Boolean oauth2Enabled, @Nullable String clientId, @Nullable String clientSecret, @Nullable String refreshToken, @Nullable String connectionArguments) { - super(accountName, database, schemaName, username, password, keyPairEnabled, path, passphrase, oauth2Enabled, + super(accountName, database, schemaName, tableName, username, password, keyPairEnabled, path, passphrase, + oauth2Enabled, clientId, clientSecret, refreshToken, connectionArguments); } diff --git a/src/main/java/io/cdap/plugin/snowflake/actions/loadunload/load/LoadActionConfig.java b/src/main/java/io/cdap/plugin/snowflake/actions/loadunload/load/LoadActionConfig.java index bcd4724..d14bc70 100644 --- a/src/main/java/io/cdap/plugin/snowflake/actions/loadunload/load/LoadActionConfig.java +++ b/src/main/java/io/cdap/plugin/snowflake/actions/loadunload/load/LoadActionConfig.java @@ -71,11 +71,13 @@ public class LoadActionConfig extends LoadUnloadConfig { @Nullable private String pattern; - public LoadActionConfig(String accountName, String database, String schemaName, String username, String password, + public LoadActionConfig(String accountName, String database, String schemaName, String tableName, + String username, String password, @Nullable Boolean keyPairEnabled, @Nullable String path, @Nullable String passphrase, @Nullable Boolean oauth2Enabled, @Nullable String clientId, @Nullable String clientSecret, @Nullable String refreshToken, @Nullable String connectionArguments) { - super(accountName, database, schemaName, username, password, keyPairEnabled, path, passphrase, oauth2Enabled, + super(accountName, database, tableName, schemaName, username, password, keyPairEnabled, path, passphrase, + oauth2Enabled, clientId, clientSecret, refreshToken, connectionArguments); } diff --git a/src/main/java/io/cdap/plugin/snowflake/actions/loadunload/unload/UnloadActionConfig.java b/src/main/java/io/cdap/plugin/snowflake/actions/loadunload/unload/UnloadActionConfig.java index 5909861..90f6ed3 100644 --- a/src/main/java/io/cdap/plugin/snowflake/actions/loadunload/unload/UnloadActionConfig.java +++ b/src/main/java/io/cdap/plugin/snowflake/actions/loadunload/unload/UnloadActionConfig.java @@ -52,11 +52,13 @@ public class UnloadActionConfig extends LoadUnloadConfig { private Boolean includeHeader; - public UnloadActionConfig(String accountName, String database, String schemaName, String username, String password, + public UnloadActionConfig(String accountName, String database, String schemaName, String tableName, String username, + String password, @Nullable Boolean keyPairEnabled, @Nullable String path, @Nullable String passphrase, @Nullable Boolean oauth2Enabled, @Nullable String clientId, @Nullable String clientSecret, @Nullable String refreshToken, @Nullable String connectionArguments) { - super(accountName, database, schemaName, username, password, keyPairEnabled, path, passphrase, oauth2Enabled, + super(accountName, database, schemaName, tableName, username, password, keyPairEnabled, path, + passphrase, oauth2Enabled, clientId, clientSecret, refreshToken, connectionArguments); } diff --git a/src/main/java/io/cdap/plugin/snowflake/actions/sql/RunSQLConfig.java b/src/main/java/io/cdap/plugin/snowflake/actions/sql/RunSQLConfig.java index 2fce83c..3194b03 100644 --- a/src/main/java/io/cdap/plugin/snowflake/actions/sql/RunSQLConfig.java +++ b/src/main/java/io/cdap/plugin/snowflake/actions/sql/RunSQLConfig.java @@ -33,11 +33,13 @@ public class RunSQLConfig extends BaseSnowflakeConfig { @Macro private String query; - public RunSQLConfig(String accountName, String database, String schemaName, String username, String password, + public RunSQLConfig(String accountName, String database, String schemaName, String tableName, String username, + String password, @Nullable Boolean keyPairEnabled, @Nullable String path, @Nullable String passphrase, @Nullable Boolean oauth2Enabled, @Nullable String clientId, @Nullable String clientSecret, @Nullable String refreshToken, @Nullable String connectionArguments) { - super(accountName, database, schemaName, username, password, keyPairEnabled, path, passphrase, oauth2Enabled, + super(accountName, database, schemaName, tableName, username, password, keyPairEnabled, path, passphrase, + oauth2Enabled, clientId, clientSecret, refreshToken, connectionArguments); } diff --git a/src/main/java/io/cdap/plugin/snowflake/common/BaseSnowflakeConfig.java b/src/main/java/io/cdap/plugin/snowflake/common/BaseSnowflakeConfig.java index 7967cfd..b348c53 100644 --- a/src/main/java/io/cdap/plugin/snowflake/common/BaseSnowflakeConfig.java +++ b/src/main/java/io/cdap/plugin/snowflake/common/BaseSnowflakeConfig.java @@ -35,6 +35,7 @@ public class BaseSnowflakeConfig extends PluginConfig { public static final String PROPERTY_ACCOUNT_NAME = "accountName"; public static final String PROPERTY_DATABASE = "database"; public static final String PROPERTY_SCHEMA_NAME = "schemaName"; + public static final String PROPERTY_TABLE_NAME = "TableName"; public static final String PROPERTY_WAREHOUSE = "warehouse"; public static final String PROPERTY_ROLE = "role"; public static final String PROPERTY_USERNAME = "username"; @@ -63,6 +64,13 @@ public class BaseSnowflakeConfig extends PluginConfig { @Macro private String schemaName; + @Name(PROPERTY_TABLE_NAME) + @Description("Name of the table to import data from. If specified, importQuery will be ignored.") + @Macro + @Nullable + private String tableName; + + @Nullable @Name(PROPERTY_WAREHOUSE) @Description("Warehouse to connect to. If not specified default warehouse is used.") @@ -87,6 +95,7 @@ public class BaseSnowflakeConfig extends PluginConfig { @Nullable private String password; + @Name(PROPERTY_KEY_PAIR_ENABLED) @Description("If true, plugin will perform Key Pair authentication.") @Nullable @@ -136,6 +145,7 @@ public class BaseSnowflakeConfig extends PluginConfig { public BaseSnowflakeConfig(String accountName, String database, String schemaName, + String tableName, String username, String password, @Nullable Boolean keyPairEnabled, @@ -150,6 +160,7 @@ public BaseSnowflakeConfig(String accountName, this.database = database; this.schemaName = schemaName; this.username = username; + this.tableName = tableName; this.password = password; this.keyPairEnabled = keyPairEnabled; this.privateKey = privateKey; @@ -161,6 +172,7 @@ public BaseSnowflakeConfig(String accountName, this.connectionArguments = connectionArguments; } + public String getAccountName() { return accountName; } @@ -173,6 +185,11 @@ public String getSchemaName() { return schemaName; } + @Nullable + public String getTableName() { + return tableName; + } + @Nullable public String getWarehouse() { return warehouse; diff --git a/src/main/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessor.java b/src/main/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessor.java index 911f696..9ef092d 100644 --- a/src/main/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessor.java +++ b/src/main/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessor.java @@ -36,18 +36,24 @@ import java.io.FileWriter; import java.io.IOException; import java.lang.reflect.Field; +//import java.sql.*; import java.sql.Connection; +import java.sql.DatabaseMetaData; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; + import java.util.ArrayList; import java.util.List; import java.util.Properties; /** - * A class which accesses Snowflake API. + * Establishes a connection to Snowflake using BaseSnowflakeConfig. + * Initializes and configures SnowflakeBasicDataSource. + * Sets application name (CDAP) and row limit (LIMIT_ROWS). */ + public class SnowflakeAccessor { private static final String APPLICATION_NAME = "CDAP"; private static final int LIMIT_ROWS = 1; @@ -61,6 +67,10 @@ public SnowflakeAccessor(BaseSnowflakeConfig config) { initDataSource(dataSource, config); } + /** + * A class which will help in connection + */ + public void runSQL(String query) { try (Connection connection = dataSource.getConnection(); PreparedStatement populateStmt = connection.prepareStatement(query);) { @@ -105,6 +115,32 @@ public List describeQuery(String query) throws IOExcep return fieldDescriptors; } + /** + * Returns field descriptors for specified tableName. + * + * @return List of field descriptors. + * @throws IOException thrown if there are any issue with the I/O operations. + */ + + public List describeTable(String schemaName, String tableName) throws SQLException { + List fieldDescriptors = new ArrayList<>(); + + try (Connection connection = dataSource.getConnection()) { + DatabaseMetaData dbMetaData = connection.getMetaData(); + + try (ResultSet columns = dbMetaData.getColumns(null, schemaName, tableName, null)) { + while (columns.next()) { + String columnName = columns.getString("COLUMN_NAME"); + int columnType = columns.getInt("DATA_TYPE"); + boolean nullable = columns.getInt("NULLABLE") == DatabaseMetaData.columnNullable; + + fieldDescriptors.add(new SnowflakeFieldDescriptor(columnName, columnType, nullable)); + } + } + } + return fieldDescriptors; + } + private void initDataSource(SnowflakeBasicDataSource dataSource, BaseSnowflakeConfig config) { dataSource.setDatabaseName(config.getDatabase()); dataSource.setSchema(config.getSchemaName()); @@ -193,4 +229,8 @@ private static String writeTextToTmpFile(String text) { throw new RuntimeException("Cannot write key to temporary file", e); } } + + public String getSchema() { + return config.getSchemaName(); + } } diff --git a/src/main/java/io/cdap/plugin/snowflake/common/util/SchemaHelper.java b/src/main/java/io/cdap/plugin/snowflake/common/util/SchemaHelper.java index 9302eae..dded0bf 100644 --- a/src/main/java/io/cdap/plugin/snowflake/common/util/SchemaHelper.java +++ b/src/main/java/io/cdap/plugin/snowflake/common/util/SchemaHelper.java @@ -27,6 +27,7 @@ import io.cdap.plugin.snowflake.source.batch.SnowflakeInputFormatProvider; import io.cdap.plugin.snowflake.source.batch.SnowflakeSourceAccessor; import java.io.IOException; +import java.sql.SQLException; import java.sql.Types; import java.util.List; import java.util.Map; @@ -64,18 +65,19 @@ public static Schema getSchema(SnowflakeBatchSourceConfig config, FailureCollect } SnowflakeSourceAccessor snowflakeSourceAccessor = - new SnowflakeSourceAccessor(config, SnowflakeInputFormatProvider.PROPERTY_DEFAULT_ESCAPE_CHAR); - return getSchema(snowflakeSourceAccessor, config.getSchema(), collector, config.getImportQuery()); + new SnowflakeSourceAccessor(config, SnowflakeInputFormatProvider.PROPERTY_DEFAULT_ESCAPE_CHAR); + return getSchema(snowflakeSourceAccessor, config.getSchema(), collector, config.getTableName(), + config.getImportQuery()); } public static Schema getSchema(SnowflakeSourceAccessor snowflakeAccessor, String schema, - FailureCollector collector, String importQuery) { + FailureCollector collector, String tableName, String importQuery) { try { if (!Strings.isNullOrEmpty(schema)) { return getParsedSchema(schema); } - return Strings.isNullOrEmpty(importQuery) ? null : getSchema(snowflakeAccessor, importQuery); - } catch (SchemaParseException e) { + return getSchema(snowflakeAccessor, tableName, importQuery); + } catch (SchemaParseException | IllegalArgumentException e) { collector.addFailure(String.format("Unable to retrieve output schema. Reason: '%s'", e.getMessage()), null) .withStacktrace(e.getStackTrace()) @@ -95,15 +97,26 @@ private static Schema getParsedSchema(String schema) { } } - public static Schema getSchema(SnowflakeAccessor snowflakeAccessor, String importQuery) { + public static Schema getSchema(SnowflakeAccessor snowflakeAccessor, + String tableName, String importQuery) { try { - List result = snowflakeAccessor.describeQuery(importQuery); + List result; + // If tableName is provided, describe the table + if (!Strings.isNullOrEmpty(tableName)) { + result = snowflakeAccessor.describeTable(snowflakeAccessor.getSchema(), tableName); + } else { + result = snowflakeAccessor.describeQuery(importQuery); + } + List fields = result.stream() - .map(fieldDescriptor -> Schema.Field.of(fieldDescriptor.getName(), getSchema(fieldDescriptor))) - .collect(Collectors.toList()); + .map(fieldDescriptor -> Schema.Field.of(fieldDescriptor.getName(), + getSchema(fieldDescriptor))) + .collect(Collectors.toList()); return Schema.recordOf("data", fields); - } catch (IOException e) { + } catch (SQLException e) { throw new SchemaParseException(e); + } catch (IOException e) { + throw new RuntimeException(e); } } diff --git a/src/main/java/io/cdap/plugin/snowflake/sink/batch/SnowflakeSinkConfig.java b/src/main/java/io/cdap/plugin/snowflake/sink/batch/SnowflakeSinkConfig.java index 4ae2291..75e7c52 100644 --- a/src/main/java/io/cdap/plugin/snowflake/sink/batch/SnowflakeSinkConfig.java +++ b/src/main/java/io/cdap/plugin/snowflake/sink/batch/SnowflakeSinkConfig.java @@ -58,12 +58,12 @@ public class SnowflakeSinkConfig extends BaseSnowflakeConfig { private String copyOptions; public SnowflakeSinkConfig(String referenceName, String accountName, String database, - String schemaName, String username, String password, + String schemaName, String tableName, String username, String password, @Nullable Boolean keyPairEnabled, @Nullable String path, @Nullable String passphrase, @Nullable Boolean oauth2Enabled, @Nullable String clientId, @Nullable String clientSecret, @Nullable String refreshToken, @Nullable String connectionArguments) { - super(accountName, database, schemaName, username, password, + super(accountName, database, schemaName, tableName, username, password, keyPairEnabled, path, passphrase, oauth2Enabled, clientId, clientSecret, refreshToken, connectionArguments); this.referenceName = referenceName; } @@ -104,8 +104,8 @@ private void validateInputSchema(Schema schema, FailureCollector failureCollecto } SnowflakeAccessor snowflakeAccessor = new SnowflakeAccessor(this); - Schema expectedSchema = SchemaHelper.getSchema(snowflakeAccessor, String.format(GET_FIELDS_QUERY, tableName)); - +// Schema expectedSchema = SchemaHelper.getSchema(snowflakeAccessor, String.format(GET_FIELDS_QUERY, tableName)); + Schema expectedSchema = SchemaHelper.getSchema(snowflakeAccessor, tableName, null); try { SchemaHelper.checkCompatibility(expectedSchema, schema); } catch (IllegalArgumentException ex) { diff --git a/src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfig.java b/src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfig.java index 561a10f..603d644 100644 --- a/src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfig.java +++ b/src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfig.java @@ -16,6 +16,7 @@ package io.cdap.plugin.snowflake.source.batch; +import com.google.common.base.Strings; import io.cdap.cdap.api.annotation.Description; import io.cdap.cdap.api.annotation.Macro; import io.cdap.cdap.api.annotation.Name; @@ -34,6 +35,7 @@ public class SnowflakeBatchSourceConfig extends BaseSnowflakeConfig { public static final String PROPERTY_IMPORT_QUERY = "importQuery"; public static final String PROPERTY_MAX_SPLIT_SIZE = "maxSplitSize"; public static final String PROPERTY_SCHEMA = "schema"; +// public static final String PROPERTY_TABLE_NAME = "tableName"; @Name(PROPERTY_REFERENCE_NAME) @Description("This will be used to uniquely identify this source/sink for lineage, annotating metadata, etc.") @@ -42,8 +44,10 @@ public class SnowflakeBatchSourceConfig extends BaseSnowflakeConfig { @Name(PROPERTY_IMPORT_QUERY) @Description("Query for import data.") @Macro + @Nullable private String importQuery; + @Name(PROPERTY_MAX_SPLIT_SIZE) @Description("Maximum split size specified in bytes.") @Macro @@ -55,17 +59,20 @@ public class SnowflakeBatchSourceConfig extends BaseSnowflakeConfig { @Macro private String schema; + public SnowflakeBatchSourceConfig(String referenceName, String accountName, String database, - String schemaName, String importQuery, String username, String password, + String schemaName, @Nullable String importQuery, @Nullable String tableName, + String username, String password, @Nullable Boolean keyPairEnabled, @Nullable String path, @Nullable String passphrase, @Nullable Boolean oauth2Enabled, @Nullable String clientId, @Nullable String clientSecret, @Nullable String refreshToken, Long maxSplitSize, @Nullable String connectionArguments, @Nullable String schema) { - super(accountName, database, schemaName, username, password, + super(accountName, database, schemaName, tableName, username, password, keyPairEnabled, path, passphrase, oauth2Enabled, clientId, clientSecret, refreshToken, connectionArguments); this.referenceName = referenceName; this.importQuery = importQuery; +// this.tableName = tableName; this.maxSplitSize = maxSplitSize; this.schema = schema; } @@ -74,6 +81,11 @@ public String getImportQuery() { return importQuery; } +// @Nullable +// public String getTableName() { +// return tableName; +// } + public Long getMaxSplitSize() { return maxSplitSize; } @@ -89,11 +101,14 @@ public String getSchema() { public void validate(FailureCollector collector) { super.validate(collector); - - if (!containsMacro(PROPERTY_MAX_SPLIT_SIZE) && Objects.nonNull(maxSplitSize) - && maxSplitSize < 0) { - collector.addFailure("Maximum Slit Size cannot be a negative number.", null) - .withConfigProperty(PROPERTY_MAX_SPLIT_SIZE); + if (!containsMacro(PROPERTY_IMPORT_QUERY) && !containsMacro(PROPERTY_TABLE_NAME)) { + if (Strings.isNullOrEmpty(getTableName()) && Strings.isNullOrEmpty(importQuery)) { + collector.addFailure("Both importQuery and tableName cannot be NULL at the same time.", + "Provide either an importQuery or a tableName.") + .withConfigProperty(PROPERTY_IMPORT_QUERY) + .withConfigProperty(PROPERTY_TABLE_NAME); + } } + } } diff --git a/src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeSourceAccessor.java b/src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeSourceAccessor.java index 202be53..058e8ba 100644 --- a/src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeSourceAccessor.java +++ b/src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeSourceAccessor.java @@ -17,6 +17,7 @@ package io.cdap.plugin.snowflake.source.batch; import au.com.bytecode.opencsv.CSVReader; +import com.google.common.base.Strings; import io.cdap.plugin.snowflake.common.SnowflakeErrorType; import io.cdap.plugin.snowflake.common.client.SnowflakeAccessor; import io.cdap.plugin.snowflake.common.util.DocumentUrlUtil; @@ -77,7 +78,11 @@ public SnowflakeSourceAccessor(SnowflakeBatchSourceConfig config, String escapeC */ public List prepareStageSplits() { LOG.info("Loading data into stage: '{}'", STAGE_PATH); - String copy = String.format(COMAND_COPY_INTO, QueryUtil.removeSemicolon(config.getImportQuery())); + String importQuery = config.getImportQuery(); + if (Strings.isNullOrEmpty(importQuery)) { + importQuery = "SELECT * FROM " + config.getTableName(); + } + String copy = String.format(COMAND_COPY_INTO, QueryUtil.removeSemicolon(importQuery)); if (config.getMaxSplitSize() > 0) { copy = copy + String.format(COMMAND_MAX_FILE_SIZE, config.getMaxSplitSize()); } diff --git a/src/test/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessorTest.java b/src/test/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessorTest.java index 27b0154..e4832a4 100644 --- a/src/test/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessorTest.java +++ b/src/test/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessorTest.java @@ -90,6 +90,21 @@ public void testDescribeQuery() throws Exception { Assert.assertEquals(expected, actual); } + @Test + public void testDescribeTable() throws Exception { + String schemaName = "TEST_SCHEMA"; + String tableName = "TEST_TABLE"; + + List actual = snowflakeAccessor.describeTable(schemaName, tableName); + + Assert.assertNotNull(actual); + Assert.assertFalse(actual.isEmpty()); + // Optionally, verify a known column exists + boolean containsExpectedColumn = actual.stream() + .anyMatch(field -> "COLUMN_NAME".equalsIgnoreCase(field.getName())); + Assert.assertTrue("Expected column is not found in the table description", containsExpectedColumn); + } + @Test public void testPrepareStageSplits() throws Exception { Pattern expected = Pattern.compile("cdap_stage/result.*data__0_0_0\\.csv\\.gz"); diff --git a/src/test/java/io/cdap/plugin/snowflake/common/util/SchemaHelperTest.java b/src/test/java/io/cdap/plugin/snowflake/common/util/SchemaHelperTest.java index 63bf0af..b0689cc 100644 --- a/src/test/java/io/cdap/plugin/snowflake/common/util/SchemaHelperTest.java +++ b/src/test/java/io/cdap/plugin/snowflake/common/util/SchemaHelperTest.java @@ -27,18 +27,23 @@ import org.mockito.Mockito; import java.io.IOException; +import java.sql.SQLException; import java.sql.Types; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import static net.snowflake.client.loader.LoaderProperty.tableName; + /** * Tests for {@link SchemaHelper} */ public class SchemaHelperTest { private static final String MOCK_STAGE = "mockStage"; + private static final String MOCK_SCHEMA = "mockSchema"; + private static final String MOCK_TABLE = "mockTable"; @Test public void testGetSchema() { @@ -48,7 +53,11 @@ public void testGetSchema() { ); MockFailureCollector collector = new MockFailureCollector(MOCK_STAGE); - Schema actual = SchemaHelper.getSchema(null, expected.toString(), collector, null); +// SnowflakeBatchSourceConfig mockConfig = Mockito.mock(SnowflakeBatchSourceConfig.class); +// Mockito.when(mockConfig.canConnect()).thenReturn(false); +// Mockito.when(mockConfig.getSchema()).thenReturn(expected.toString()); + Schema actual = SchemaHelper.getSchema(null, expected.toString(), collector, null, + null); Assert.assertTrue(collector.getValidationFailures().isEmpty()); Assert.assertEquals(expected, actual); @@ -57,8 +66,10 @@ public void testGetSchema() { @Test public void testGetSchemaInvalidJson() { MockFailureCollector collector = new MockFailureCollector(MOCK_STAGE); - SchemaHelper.getSchema(null, "{}", collector, null); - + SnowflakeBatchSourceConfig mockConfig = Mockito.mock(SnowflakeBatchSourceConfig.class); + Mockito.when(mockConfig.getSchema()).thenReturn("{}"); + SchemaHelper.getSchema(null, "{}", collector, null, null); +// SchemaHelper.getSchema(mockConfig, collector); ValidationAssertions.assertValidationFailed( collector, Collections.singletonList(SnowflakeBatchSourceConfig.PROPERTY_SCHEMA)); } @@ -73,16 +84,16 @@ public void testGetSchemaFromSnowflakeUnknownType() throws IOException { sample.add(new SnowflakeFieldDescriptor("field1", -1000, false)); Mockito.when(snowflakeAccessor.describeQuery(importQuery)).thenReturn(sample); - - SchemaHelper.getSchema(snowflakeAccessor, null, collector, importQuery); - + String tableName = "tableName"; + SchemaHelper.getSchema(snowflakeAccessor, null, collector, tableName, importQuery); ValidationAssertions.assertValidationFailed( collector, Collections.singletonList(SnowflakeBatchSourceConfig.PROPERTY_SCHEMA)); } @Test - public void testGetSchemaFromSnowflake() throws IOException { + public void testGetSchemaFromSnowflake() throws IOException, SQLException { String importQuery = "SELECT * FROM someTable"; + String tableName = "tableName"; MockFailureCollector collector = new MockFailureCollector(MOCK_STAGE); SnowflakeSourceAccessor snowflakeAccessor = Mockito.mock(SnowflakeSourceAccessor.class); @@ -141,12 +152,16 @@ public void testGetSchemaFromSnowflake() throws IOException { Schema.Field.of("field131", Schema.nullableOf(Schema.decimalOf(38))), Schema.Field.of("field132", Schema.decimalOf(38)), Schema.Field.of("field133", Schema.of(Schema.LogicalType.TIMESTAMP_MICROS)), - Schema.Field.of("field134", Schema.nullableOf(Schema.of(Schema.LogicalType.TIMESTAMP_MICROS))) + Schema.Field.of("field134", Schema.nullableOf(Schema.of(Schema.LogicalType.TIMESTAMP_MICROS))) + ); Mockito.when(snowflakeAccessor.describeQuery(importQuery)).thenReturn(sample); + Mockito.when(snowflakeAccessor.describeTable(Mockito.any(), String.valueOf (Mockito.eq(tableName)))). + thenReturn(sample); + - Schema actual = SchemaHelper.getSchema(snowflakeAccessor, null, collector, importQuery); + Schema actual = SchemaHelper.getSchema(snowflakeAccessor, null, collector, tableName, importQuery); Assert.assertTrue(collector.getValidationFailures().isEmpty()); Assert.assertEquals(expected, actual); diff --git a/src/test/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfigBuilder.java b/src/test/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfigBuilder.java index 7f2f035..518d792 100644 --- a/src/test/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfigBuilder.java +++ b/src/test/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfigBuilder.java @@ -27,8 +27,9 @@ public class SnowflakeBatchSourceConfigBuilder { "database", "schemaName", "importQuery", - "username", - "password", + "tableName", + "userName", + "Password", false, "", "", @@ -45,6 +46,7 @@ public class SnowflakeBatchSourceConfigBuilder { private String database; private String schemaName; private String importQuery; + private String tableName; private String username; private String password; private Boolean keyPairEnabled; @@ -67,6 +69,7 @@ public SnowflakeBatchSourceConfigBuilder(SnowflakeBatchSourceConfig config) { this.database = config.getDatabase(); this.schemaName = config.getSchemaName(); this.importQuery = config.getImportQuery(); + this.tableName = config.getTableName(); this.username = config.getUsername(); this.password = config.getPassword(); this.keyPairEnabled = config.getKeyPairEnabled(); @@ -106,6 +109,11 @@ public SnowflakeBatchSourceConfigBuilder setImportQuery(String importQuery) { return this; } + public SnowflakeBatchSourceConfigBuilder setTableName(String tableName) { + this.tableName = tableName; + return this; + } + public SnowflakeBatchSourceConfigBuilder setUsername(String username) { this.username = username; return this; @@ -172,6 +180,7 @@ public SnowflakeBatchSourceConfig build() { database, schemaName, importQuery, + tableName, username, password, keyPairEnabled, diff --git a/src/test/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfigTest.java b/src/test/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfigTest.java index a071bea..27e7e55 100644 --- a/src/test/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfigTest.java +++ b/src/test/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfigTest.java @@ -68,4 +68,5 @@ public void validatePassword() { ValidationAssertions.assertValidationFailed( collector, Collections.singletonList(SnowflakeBatchSourceConfig.PROPERTY_PASSWORD)); } + }