Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,13 @@ public abstract class LoadUnloadConfig extends BaseSnowflakeConfig {


public LoadUnloadConfig(String accountName, String database,
String schemaName, String username, String password,
String schemaName, String tableName, String username, String password,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tableName is not required here!

@Nullable Boolean keyPairEnabled, @Nullable String path,
@Nullable String passphrase, @Nullable Boolean oauth2Enabled, @Nullable String clientId,
@Nullable String clientSecret, @Nullable String refreshToken,
@Nullable String connectionArguments) {
super(accountName, database, schemaName, username, password, keyPairEnabled, path, passphrase, oauth2Enabled,
super(accountName, database, schemaName, tableName, username, password, keyPairEnabled, path, passphrase,
oauth2Enabled,
clientId, clientSecret, refreshToken, connectionArguments);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,13 @@ public class LoadActionConfig extends LoadUnloadConfig {
@Nullable
private String pattern;

public LoadActionConfig(String accountName, String database, String schemaName, String username, String password,
public LoadActionConfig(String accountName, String database, String schemaName, String tableName,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tableName not required here as well

String username, String password,
@Nullable Boolean keyPairEnabled, @Nullable String path, @Nullable String passphrase,
@Nullable Boolean oauth2Enabled, @Nullable String clientId, @Nullable String clientSecret,
@Nullable String refreshToken, @Nullable String connectionArguments) {
super(accountName, database, schemaName, username, password, keyPairEnabled, path, passphrase, oauth2Enabled,
super(accountName, database, tableName, schemaName, username, password, keyPairEnabled, path, passphrase,
oauth2Enabled,
clientId, clientSecret, refreshToken, connectionArguments);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,13 @@ public class UnloadActionConfig extends LoadUnloadConfig {
private Boolean includeHeader;


public UnloadActionConfig(String accountName, String database, String schemaName, String username, String password,
public UnloadActionConfig(String accountName, String database, String schemaName, String tableName, String username,
String password,
@Nullable Boolean keyPairEnabled, @Nullable String path, @Nullable String passphrase,
@Nullable Boolean oauth2Enabled, @Nullable String clientId, @Nullable String clientSecret,
@Nullable String refreshToken, @Nullable String connectionArguments) {
super(accountName, database, schemaName, username, password, keyPairEnabled, path, passphrase, oauth2Enabled,
super(accountName, database, schemaName, tableName, username, password, keyPairEnabled, path,
passphrase, oauth2Enabled,
clientId, clientSecret, refreshToken, connectionArguments);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@ public class RunSQLConfig extends BaseSnowflakeConfig {
@Macro
private String query;

public RunSQLConfig(String accountName, String database, String schemaName, String username, String password,
public RunSQLConfig(String accountName, String database, String schemaName, String tableName, String username,
String password,
@Nullable Boolean keyPairEnabled, @Nullable String path, @Nullable String passphrase,
@Nullable Boolean oauth2Enabled, @Nullable String clientId, @Nullable String clientSecret,
@Nullable String refreshToken, @Nullable String connectionArguments) {
super(accountName, database, schemaName, username, password, keyPairEnabled, path, passphrase, oauth2Enabled,
super(accountName, database, schemaName, tableName, username, password, keyPairEnabled, path, passphrase,
oauth2Enabled,
clientId, clientSecret, refreshToken, connectionArguments);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public class BaseSnowflakeConfig extends PluginConfig {
public static final String PROPERTY_ACCOUNT_NAME = "accountName";
public static final String PROPERTY_DATABASE = "database";
public static final String PROPERTY_SCHEMA_NAME = "schemaName";
public static final String PROPERTY_TABLE_NAME = "TableName";
public static final String PROPERTY_WAREHOUSE = "warehouse";
public static final String PROPERTY_ROLE = "role";
public static final String PROPERTY_USERNAME = "username";
Expand Down Expand Up @@ -63,6 +64,13 @@ public class BaseSnowflakeConfig extends PluginConfig {
@Macro
private String schemaName;

@Name(PROPERTY_TABLE_NAME)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Table Name should be moved under Basic section using widgets.json

@Description("Name of the table to import data from. If specified, importQuery will be ignored.")
@Macro
@Nullable
private String tableName;


@Nullable
@Name(PROPERTY_WAREHOUSE)
@Description("Warehouse to connect to. If not specified default warehouse is used.")
Expand All @@ -87,6 +95,7 @@ public class BaseSnowflakeConfig extends PluginConfig {
@Nullable
private String password;


@Name(PROPERTY_KEY_PAIR_ENABLED)
@Description("If true, plugin will perform Key Pair authentication.")
@Nullable
Expand Down Expand Up @@ -136,6 +145,7 @@ public class BaseSnowflakeConfig extends PluginConfig {
public BaseSnowflakeConfig(String accountName,
String database,
String schemaName,
String tableName,
String username,
String password,
@Nullable Boolean keyPairEnabled,
Expand All @@ -150,6 +160,7 @@ public BaseSnowflakeConfig(String accountName,
this.database = database;
this.schemaName = schemaName;
this.username = username;
this.tableName = tableName;
this.password = password;
this.keyPairEnabled = keyPairEnabled;
this.privateKey = privateKey;
Expand All @@ -161,6 +172,7 @@ public BaseSnowflakeConfig(String accountName,
this.connectionArguments = connectionArguments;
}


public String getAccountName() {
return accountName;
}
Expand All @@ -173,6 +185,11 @@ public String getSchemaName() {
return schemaName;
}

@Nullable
public String getTableName() {
return tableName;
}

@Nullable
public String getWarehouse() {
return warehouse;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ public static Schema getSchema(SnowflakeSourceAccessor snowflakeAccessor, String
if (!Strings.isNullOrEmpty(schema)) {
return getParsedSchema(schema);
}
return getSchema(snowflakeAccessor, snowflakeAccessor.getSchema(), tableName, importQuery);
} catch (SchemaParseException e) {
return getSchema(snowflakeAccessor, tableName, importQuery);
} catch (SchemaParseException | IllegalArgumentException e) {
collector.addFailure(String.format("Unable to retrieve output schema. Reason: '%s'", e.getMessage()),
null)
.withStacktrace(e.getStackTrace())
Expand All @@ -97,18 +97,17 @@ private static Schema getParsedSchema(String schema) {
}
}

public static Schema getSchema(SnowflakeAccessor snowflakeAccessor, String schemaName,
public static Schema getSchema(SnowflakeAccessor snowflakeAccessor,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The getSchema method used within the SchemaHelper class only should be made private

String tableName, String importQuery) {
try {
List<SnowflakeFieldDescriptor> result;
// If tableName is provided, describe the table
if (!Strings.isNullOrEmpty(tableName)) {
result = snowflakeAccessor.describeTable(schemaName, tableName);
} else if (!Strings.isNullOrEmpty(importQuery)) {
result = snowflakeAccessor.describeQuery(importQuery);
result = snowflakeAccessor.describeTable(snowflakeAccessor.getSchema(), tableName);
} else {
return null;
result = snowflakeAccessor.describeQuery(importQuery);
}

List<Schema.Field> fields = result.stream()
.map(fieldDescriptor -> Schema.Field.of(fieldDescriptor.getName(),
getSchema(fieldDescriptor)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ public class SnowflakeSinkConfig extends BaseSnowflakeConfig {
private String copyOptions;

public SnowflakeSinkConfig(String referenceName, String accountName, String database,
String schemaName, String username, String password,
String schemaName, String tableName, String username, String password,
@Nullable Boolean keyPairEnabled, @Nullable String path,
@Nullable String passphrase, @Nullable Boolean oauth2Enabled,
@Nullable String clientId, @Nullable String clientSecret,
@Nullable String refreshToken, @Nullable String connectionArguments) {
super(accountName, database, schemaName, username, password,
super(accountName, database, schemaName, tableName, username, password,
keyPairEnabled, path, passphrase, oauth2Enabled, clientId, clientSecret, refreshToken, connectionArguments);
this.referenceName = referenceName;
}
Expand Down Expand Up @@ -105,7 +105,7 @@ private void validateInputSchema(Schema schema, FailureCollector failureCollecto

SnowflakeAccessor snowflakeAccessor = new SnowflakeAccessor(this);
// Schema expectedSchema = SchemaHelper.getSchema(snowflakeAccessor, String.format(GET_FIELDS_QUERY, tableName));

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove comments

Schema expectedSchema = SchemaHelper.getSchema(snowflakeAccessor, getSchemaName(), tableName, null);
Schema expectedSchema = SchemaHelper.getSchema(snowflakeAccessor, tableName, null);
try {
SchemaHelper.checkCompatibility(expectedSchema, schema);
} catch (IllegalArgumentException ex) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public class SnowflakeBatchSourceConfig extends BaseSnowflakeConfig {
public static final String PROPERTY_IMPORT_QUERY = "importQuery";
public static final String PROPERTY_MAX_SPLIT_SIZE = "maxSplitSize";
public static final String PROPERTY_SCHEMA = "schema";
public static final String PROPERTY_TABLE_NAME = "tableName";
// public static final String PROPERTY_TABLE_NAME = "tableName";

@Name(PROPERTY_REFERENCE_NAME)
@Description("This will be used to uniquely identify this source/sink for lineage, annotating metadata, etc.")
Expand All @@ -47,11 +47,6 @@ public class SnowflakeBatchSourceConfig extends BaseSnowflakeConfig {
@Nullable
private String importQuery;

@Name(PROPERTY_TABLE_NAME)
@Description("Name of the table to import data from. If specified, importQuery will be ignored.")
@Macro
@Nullable
private String tableName;

@Name(PROPERTY_MAX_SPLIT_SIZE)
@Description("Maximum split size specified in bytes.")
Expand All @@ -73,11 +68,11 @@ public SnowflakeBatchSourceConfig(String referenceName, String accountName, Stri
@Nullable String clientId, @Nullable String clientSecret,
@Nullable String refreshToken, Long maxSplitSize,
@Nullable String connectionArguments, @Nullable String schema) {
super(accountName, database, schemaName, tableName, password,
super(accountName, database, schemaName, tableName, username, password,
keyPairEnabled, path, passphrase, oauth2Enabled, clientId, clientSecret, refreshToken, connectionArguments);
this.referenceName = referenceName;
this.importQuery = importQuery;
this.tableName = tableName;
// this.tableName = tableName;
this.maxSplitSize = maxSplitSize;
this.schema = schema;
}
Expand All @@ -86,10 +81,10 @@ public String getImportQuery() {
return importQuery;
}

@Nullable
public String getTableName() {
return tableName;
}
// @Nullable
// public String getTableName() {
// return tableName;
// }

public Long getMaxSplitSize() {
return maxSplitSize;
Expand All @@ -106,14 +101,14 @@ public String getSchema() {

public void validate(FailureCollector collector) {
super.validate(collector);

if (!containsMacro(PROPERTY_IMPORT_QUERY) && !containsMacro(PROPERTY_TABLE_NAME)) {
if (Strings.isNullOrEmpty(tableName) && Strings.isNullOrEmpty(importQuery)) {
if (Strings.isNullOrEmpty(getTableName()) && Strings.isNullOrEmpty(importQuery)) {
collector.addFailure("Both importQuery and tableName cannot be NULL at the same time.",
"Provide either an importQuery or a tableName.")
.withConfigProperty(PROPERTY_IMPORT_QUERY)
.withConfigProperty(PROPERTY_TABLE_NAME);
}
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,15 @@
import org.mockito.Mockito;

import java.io.IOException;
import java.sql.SQLException;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static net.snowflake.client.loader.LoaderProperty.tableName;

/**
* Tests for {@link SchemaHelper}
*/
Expand All @@ -50,11 +53,11 @@ public void testGetSchema() {
);

MockFailureCollector collector = new MockFailureCollector(MOCK_STAGE);
SnowflakeBatchSourceConfig mockConfig = Mockito.mock(SnowflakeBatchSourceConfig.class);
Mockito.when(mockConfig.canConnect()).thenReturn(false);
Mockito.when(mockConfig.getSchema()).thenReturn(expected.toString());

Schema actual = SchemaHelper.getSchema(mockConfig, collector);
// SnowflakeBatchSourceConfig mockConfig = Mockito.mock(SnowflakeBatchSourceConfig.class);
// Mockito.when(mockConfig.canConnect()).thenReturn(false);
// Mockito.when(mockConfig.getSchema()).thenReturn(expected.toString());

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove comments

Schema actual = SchemaHelper.getSchema(null, expected.toString(), collector, null,
null);

Assert.assertTrue(collector.getValidationFailures().isEmpty());
Assert.assertEquals(expected, actual);
Expand All @@ -65,8 +68,8 @@ public void testGetSchemaInvalidJson() {
MockFailureCollector collector = new MockFailureCollector(MOCK_STAGE);
SnowflakeBatchSourceConfig mockConfig = Mockito.mock(SnowflakeBatchSourceConfig.class);
Mockito.when(mockConfig.getSchema()).thenReturn("{}");
// SchemaHelper.getSchema(null, "{}", collector, null);
SchemaHelper.getSchema(mockConfig, collector);
SchemaHelper.getSchema(null, "{}", collector, null, null);
// SchemaHelper.getSchema(mockConfig, collector);
ValidationAssertions.assertValidationFailed(
collector, Collections.singletonList(SnowflakeBatchSourceConfig.PROPERTY_SCHEMA));
}
Expand All @@ -81,15 +84,16 @@ public void testGetSchemaFromSnowflakeUnknownType() throws IOException {
sample.add(new SnowflakeFieldDescriptor("field1", -1000, false));

Mockito.when(snowflakeAccessor.describeQuery(importQuery)).thenReturn(sample);

SchemaHelper.getSchema(snowflakeAccessor, MOCK_SCHEMA, MOCK_TABLE, importQuery);
String tableName = "tableName";
SchemaHelper.getSchema(snowflakeAccessor, null, collector, tableName, importQuery);
ValidationAssertions.assertValidationFailed(
collector, Collections.singletonList(SnowflakeBatchSourceConfig.PROPERTY_SCHEMA));
}

@Test
public void testGetSchemaFromSnowflake() throws IOException {
public void testGetSchemaFromSnowflake() throws IOException, SQLException {
String importQuery = "SELECT * FROM someTable";
String tableName = "tableName";
MockFailureCollector collector = new MockFailureCollector(MOCK_STAGE);
SnowflakeSourceAccessor snowflakeAccessor = Mockito.mock(SnowflakeSourceAccessor.class);

Expand Down Expand Up @@ -148,12 +152,16 @@ public void testGetSchemaFromSnowflake() throws IOException {
Schema.Field.of("field131", Schema.nullableOf(Schema.decimalOf(38))),
Schema.Field.of("field132", Schema.decimalOf(38)),
Schema.Field.of("field133", Schema.of(Schema.LogicalType.TIMESTAMP_MICROS)),
Schema.Field.of("field134", Schema.nullableOf(Schema.of(Schema.LogicalType.TIMESTAMP_MICROS)))
Schema.Field.of("field134", Schema.nullableOf(Schema.of(Schema.LogicalType.TIMESTAMP_MICROS)))

);

Mockito.when(snowflakeAccessor.describeQuery(importQuery)).thenReturn(sample);
Mockito.when(snowflakeAccessor.describeTable(Mockito.any(), String.valueOf (Mockito.eq(tableName)))).
thenReturn(sample);


Schema actual = SchemaHelper.getSchema(snowflakeAccessor, MOCK_SCHEMA, MOCK_TABLE, importQuery);
Schema actual = SchemaHelper.getSchema(snowflakeAccessor, null, collector, tableName, importQuery);

Assert.assertTrue(collector.getValidationFailures().isEmpty());
Assert.assertEquals(expected, actual);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ public class SnowflakeBatchSourceConfigBuilder {
"schemaName",
"importQuery",
"tableName",
"bqdiuser",
"Datafusion@321",
"userName",
"Password",
false,
"",
"",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,35 +69,4 @@ public void validatePassword() {
collector, Collections.singletonList(SnowflakeBatchSourceConfig.PROPERTY_PASSWORD));
}

/*
* Creating a config where both tableName and importQuery are null
*/

@Test
public void testValidateTableNameAndImportQueryNull() {
SnowflakeBatchSourceConfig config = new SnowflakeBatchSourceConfigBuilder()
.setReferenceName("testRef")
.setAccountName("testAccount")
.setDatabase("testDB")
.setSchemaName("testSchema")
.setUsername("testUser")
.setPassword("testPassword")
.setMaxSplitSize(1024L)
.setTableName(null)
.setImportQuery(null)
.build();

// Mock FailureCollector to capture validation errors
MockFailureCollector collector = new MockFailureCollector(MOCK_STAGE);
config.validate(collector);
Assert.assertFalse(collector.getValidationFailures().isEmpty());
ValidationAssertions.assertValidationFailed(
collector,
Collections.singletonList(SnowflakeBatchSourceConfig.PROPERTY_IMPORT_QUERY)
);
ValidationAssertions.assertValidationFailed(
collector,
Collections.singletonList(SnowflakeBatchSourceConfig.PROPERTY_TABLE_NAME)
);
}
}