Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,14 @@
import java.io.FileWriter;
import java.io.IOException;
import java.lang.reflect.Field;
//import java.sql.*;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove comments

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;

import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
Expand Down Expand Up @@ -105,6 +108,25 @@ public List<SnowflakeFieldDescriptor> describeQuery(String query) throws IOExcep
return fieldDescriptors;
}

public List<SnowflakeFieldDescriptor> describeTable(String schemaName, String tableName) throws SQLException {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add Javadoc for public methods

List<SnowflakeFieldDescriptor> fieldDescriptors = new ArrayList<>();

try (Connection connection = dataSource.getConnection()) {
DatabaseMetaData dbMetaData = connection.getMetaData();

try (ResultSet columns = dbMetaData.getColumns(null, schemaName, tableName, null)) {
while (columns.next()) {
String columnName = columns.getString("COLUMN_NAME");
int columnType = columns.getInt("DATA_TYPE");
boolean nullable = columns.getInt("NULLABLE") == DatabaseMetaData.columnNullable;

fieldDescriptors.add(new SnowflakeFieldDescriptor(columnName, columnType, nullable));
}
}
}
return fieldDescriptors;
}

private void initDataSource(SnowflakeBasicDataSource dataSource, BaseSnowflakeConfig config) {
dataSource.setDatabaseName(config.getDatabase());
dataSource.setSchema(config.getSchemaName());
Expand Down Expand Up @@ -193,4 +215,8 @@ private static String writeTextToTmpFile(String text) {
throw new RuntimeException("Cannot write key to temporary file", e);
}
}

public String getSchema() {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add javadoc here as well

return config.getSchemaName();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io.cdap.plugin.snowflake.source.batch.SnowflakeInputFormatProvider;
import io.cdap.plugin.snowflake.source.batch.SnowflakeSourceAccessor;
import java.io.IOException;
import java.sql.SQLException;
import java.sql.Types;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -64,17 +65,18 @@ public static Schema getSchema(SnowflakeBatchSourceConfig config, FailureCollect
}

SnowflakeSourceAccessor snowflakeSourceAccessor =
new SnowflakeSourceAccessor(config, SnowflakeInputFormatProvider.PROPERTY_DEFAULT_ESCAPE_CHAR);
return getSchema(snowflakeSourceAccessor, config.getSchema(), collector, config.getImportQuery());
new SnowflakeSourceAccessor(config, SnowflakeInputFormatProvider.PROPERTY_DEFAULT_ESCAPE_CHAR);
return getSchema(snowflakeSourceAccessor, config.getSchema(), collector, config.getTableName(),
config.getImportQuery());
}

public static Schema getSchema(SnowflakeSourceAccessor snowflakeAccessor, String schema,
FailureCollector collector, String importQuery) {
FailureCollector collector, String tableName, String importQuery) {
try {
if (!Strings.isNullOrEmpty(schema)) {
return getParsedSchema(schema);
}
return Strings.isNullOrEmpty(importQuery) ? null : getSchema(snowflakeAccessor, importQuery);
return getSchema(snowflakeAccessor, snowflakeAccessor.getSchema(), tableName, importQuery);
} catch (SchemaParseException e) {
collector.addFailure(String.format("Unable to retrieve output schema. Reason: '%s'", e.getMessage()),
null)
Expand All @@ -95,15 +97,27 @@ private static Schema getParsedSchema(String schema) {
}
}

public static Schema getSchema(SnowflakeAccessor snowflakeAccessor, String importQuery) {
public static Schema getSchema(SnowflakeAccessor snowflakeAccessor, String schemaName,
String tableName, String importQuery) {
try {
List<SnowflakeFieldDescriptor> result = snowflakeAccessor.describeQuery(importQuery);
List<SnowflakeFieldDescriptor> result;
// If tableName is provided, describe the table
if (!Strings.isNullOrEmpty(tableName)) {
result = snowflakeAccessor.describeTable(schemaName, tableName);
} else if (!Strings.isNullOrEmpty(importQuery)) {
result = snowflakeAccessor.describeQuery(importQuery);
} else {
return null;
}
List<Schema.Field> fields = result.stream()
.map(fieldDescriptor -> Schema.Field.of(fieldDescriptor.getName(), getSchema(fieldDescriptor)))
.collect(Collectors.toList());
.map(fieldDescriptor -> Schema.Field.of(fieldDescriptor.getName(),
getSchema(fieldDescriptor)))
.collect(Collectors.toList());
return Schema.recordOf("data", fields);
} catch (IOException e) {
} catch (SQLException e) {
throw new SchemaParseException(e);
} catch (IOException e) {
throw new RuntimeException(e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ private void validateInputSchema(Schema schema, FailureCollector failureCollecto
}

SnowflakeAccessor snowflakeAccessor = new SnowflakeAccessor(this);
Schema expectedSchema = SchemaHelper.getSchema(snowflakeAccessor, String.format(GET_FIELDS_QUERY, tableName));

// Schema expectedSchema = SchemaHelper.getSchema(snowflakeAccessor, String.format(GET_FIELDS_QUERY, tableName));

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove comments

Schema expectedSchema = SchemaHelper.getSchema(snowflakeAccessor, getSchemaName(), tableName, null);
try {
SchemaHelper.checkCompatibility(expectedSchema, schema);
} catch (IllegalArgumentException ex) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public class SnowflakeBatchSourceConfig extends BaseSnowflakeConfig {
public static final String PROPERTY_IMPORT_QUERY = "importQuery";
public static final String PROPERTY_MAX_SPLIT_SIZE = "maxSplitSize";
public static final String PROPERTY_SCHEMA = "schema";
public static final String PROPERTY_TABLE_NAME = "tableName";

@Name(PROPERTY_REFERENCE_NAME)
@Description("This will be used to uniquely identify this source/sink for lineage, annotating metadata, etc.")
Expand All @@ -42,8 +43,15 @@ public class SnowflakeBatchSourceConfig extends BaseSnowflakeConfig {
@Name(PROPERTY_IMPORT_QUERY)
@Description("Query for import data.")
@Macro
@Nullable
private String importQuery;

@Name(PROPERTY_TABLE_NAME)
@Description("Name of the table to import data from. If specified, importQuery will be ignored.")
@Macro
@Nullable
private String tableName;

@Name(PROPERTY_MAX_SPLIT_SIZE)
@Description("Maximum split size specified in bytes.")
@Macro
Expand All @@ -55,17 +63,20 @@ public class SnowflakeBatchSourceConfig extends BaseSnowflakeConfig {
@Macro
private String schema;


/**
 * Creates a batch-source configuration.
 *
 * @param referenceName unique name used for lineage and metadata annotation
 * @param importQuery   query to import data with; mutually exclusive with {@code tableName}
 * @param tableName     table to import data from; mutually exclusive with {@code importQuery}
 */
public SnowflakeBatchSourceConfig(String referenceName, String accountName, String database,
                                  String schemaName, @Nullable String importQuery, @Nullable String tableName,
                                  String username, String password,
                                  @Nullable Boolean keyPairEnabled, @Nullable String path,
                                  @Nullable String passphrase, @Nullable Boolean oauth2Enabled,
                                  @Nullable String clientId, @Nullable String clientSecret,
                                  @Nullable String refreshToken, Long maxSplitSize,
                                  @Nullable String connectionArguments, @Nullable String schema) {
  // Bug fix: 'username' must be forwarded to the parent config here; the previous revision
  // accidentally passed 'tableName' in the username position of the super() call.
  super(accountName, database, schemaName, username, password,
    keyPairEnabled, path, passphrase, oauth2Enabled, clientId, clientSecret, refreshToken, connectionArguments);
  this.referenceName = referenceName;
  this.importQuery = importQuery;
  this.tableName = tableName;
  this.maxSplitSize = maxSplitSize;
  this.schema = schema;
}
Expand All @@ -74,6 +85,11 @@ public String getImportQuery() {
return importQuery;
}

/**
 * Returns the table to import data from, or {@code null} when an import query is used instead.
 *
 * @return the configured table name, possibly {@code null}
 */
@Nullable
public String getTableName() {
return tableName;
}

public Long getMaxSplitSize() {
return maxSplitSize;
}
Expand All @@ -90,10 +106,11 @@ public String getSchema() {
public void validate(FailureCollector collector) {
super.validate(collector);

if (!containsMacro(PROPERTY_MAX_SPLIT_SIZE) && Objects.nonNull(maxSplitSize)
&& maxSplitSize < 0) {
collector.addFailure("Maximum Slit Size cannot be a negative number.", null)
.withConfigProperty(PROPERTY_MAX_SPLIT_SIZE);
if (tableName != null && importQuery != null) {
collector.addFailure("Both importQuery and tableName cannot be specified at the same time.",
"Provide either an importQuery or a tableName.")
.withConfigProperty(PROPERTY_IMPORT_QUERY)
.withConfigProperty(PROPERTY_TABLE_NAME);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package io.cdap.plugin.snowflake.source.batch;

import au.com.bytecode.opencsv.CSVReader;
import com.google.common.base.Strings;
import io.cdap.plugin.snowflake.common.SnowflakeErrorType;
import io.cdap.plugin.snowflake.common.client.SnowflakeAccessor;
import io.cdap.plugin.snowflake.common.util.DocumentUrlUtil;
Expand Down Expand Up @@ -77,7 +78,11 @@ public SnowflakeSourceAccessor(SnowflakeBatchSourceConfig config, String escapeC
*/
public List<String> prepareStageSplits() {
LOG.info("Loading data into stage: '{}'", STAGE_PATH);
String copy = String.format(COMAND_COPY_INTO, QueryUtil.removeSemicolon(config.getImportQuery()));
String importQuery = config.getImportQuery();
if (Strings.isNullOrEmpty(importQuery)) {
importQuery = "SELECT * FROM " + config.getTableName();

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to use String.format method

}
String copy = String.format(COMAND_COPY_INTO, QueryUtil.removeSemicolon(importQuery));
if (config.getMaxSplitSize() > 0) {
copy = copy + String.format(COMMAND_MAX_FILE_SIZE, config.getMaxSplitSize());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
public class SchemaHelperTest {

private static final String MOCK_STAGE = "mockStage";
private static final String MOCK_SCHEMA = "mockSchema";
private static final String MOCK_TABLE = "mockTable";

@Test
public void testGetSchema() {
Expand All @@ -48,7 +50,11 @@ public void testGetSchema() {
);

MockFailureCollector collector = new MockFailureCollector(MOCK_STAGE);
Schema actual = SchemaHelper.getSchema(null, expected.toString(), collector, null);
SnowflakeBatchSourceConfig mockConfig = Mockito.mock(SnowflakeBatchSourceConfig.class);
Mockito.when(mockConfig.canConnect()).thenReturn(false);
Mockito.when(mockConfig.getSchema()).thenReturn(expected.toString());

Schema actual = SchemaHelper.getSchema(mockConfig, collector);

Assert.assertTrue(collector.getValidationFailures().isEmpty());
Assert.assertEquals(expected, actual);
Expand All @@ -57,8 +63,10 @@ public void testGetSchema() {
@Test
public void testGetSchemaInvalidJson() {
MockFailureCollector collector = new MockFailureCollector(MOCK_STAGE);
SchemaHelper.getSchema(null, "{}", collector, null);

SnowflakeBatchSourceConfig mockConfig = Mockito.mock(SnowflakeBatchSourceConfig.class);
Mockito.when(mockConfig.getSchema()).thenReturn("{}");
// SchemaHelper.getSchema(null, "{}", collector, null);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove comments

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better not to remove the comments; they will help other developers easily see what changes I made to that particular line of code.

SchemaHelper.getSchema(mockConfig, collector);
ValidationAssertions.assertValidationFailed(
collector, Collections.singletonList(SnowflakeBatchSourceConfig.PROPERTY_SCHEMA));
}
Expand All @@ -74,8 +82,7 @@ public void testGetSchemaFromSnowflakeUnknownType() throws IOException {

Mockito.when(snowflakeAccessor.describeQuery(importQuery)).thenReturn(sample);

SchemaHelper.getSchema(snowflakeAccessor, null, collector, importQuery);

SchemaHelper.getSchema(snowflakeAccessor, MOCK_SCHEMA, MOCK_TABLE, importQuery);
ValidationAssertions.assertValidationFailed(
collector, Collections.singletonList(SnowflakeBatchSourceConfig.PROPERTY_SCHEMA));
}
Expand Down Expand Up @@ -146,7 +153,7 @@ public void testGetSchemaFromSnowflake() throws IOException {

Mockito.when(snowflakeAccessor.describeQuery(importQuery)).thenReturn(sample);

Schema actual = SchemaHelper.getSchema(snowflakeAccessor, null, collector, importQuery);
Schema actual = SchemaHelper.getSchema(snowflakeAccessor, MOCK_SCHEMA, MOCK_TABLE, importQuery);

Assert.assertTrue(collector.getValidationFailures().isEmpty());
Assert.assertEquals(expected, actual);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public class SnowflakeBatchSourceConfigBuilder {
"database",
"schemaName",
"importQuery",
"tableName",
"username",
"password",
false,
Expand All @@ -45,6 +46,7 @@ public class SnowflakeBatchSourceConfigBuilder {
private String database;
private String schemaName;
private String importQuery;
private String tableName;
private String username;
private String password;
private Boolean keyPairEnabled;
Expand All @@ -67,6 +69,7 @@ public SnowflakeBatchSourceConfigBuilder(SnowflakeBatchSourceConfig config) {
this.database = config.getDatabase();
this.schemaName = config.getSchemaName();
this.importQuery = config.getImportQuery();
this.tableName = config.getTableName();
this.username = config.getUsername();
this.password = config.getPassword();
this.keyPairEnabled = config.getKeyPairEnabled();
Expand Down Expand Up @@ -172,6 +175,7 @@ public SnowflakeBatchSourceConfig build() {
database,
schemaName,
importQuery,
tableName,
username,
password,
keyPairEnabled,
Expand Down