Skip to content

Commit 2acb278

Browse files
committed
Get schema without import query
1 parent 2dc3fe5 commit 2acb278

File tree

12 files changed

+231
-35
lines changed

12 files changed

+231
-35
lines changed

docs/Snowflake-batchsource.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ log in to Snowflake, minus the "snowflakecomputing.com"). E.g. "myaccount.us-cen
2525

2626
**Role:** Role to use (e.g. `ACCOUNTADMIN`).
2727

28+
**Table Name:** The name of the table from which the schema is retrieved. Either this or Import Query must be provided.
29+
2830
**Import Query:** Query for data import.
2931

3032
### Credentials

src/main/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessor.java

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import java.io.IOException;
3838
import java.lang.reflect.Field;
3939
import java.sql.Connection;
40+
import java.sql.DatabaseMetaData;
4041
import java.sql.PreparedStatement;
4142
import java.sql.ResultSet;
4243
import java.sql.ResultSetMetaData;
@@ -67,13 +68,46 @@ public void runSQL(String query) {
6768
populateStmt.execute();
6869
} catch (SQLException e) {
6970
String errorMessage = String.format("Statement '%s' failed with SQL state %s and error code %s due to '%s'",
70-
query, e.getSQLState(), e.getErrorCode(), e.getMessage());
71+
query, e.getSQLState(), e.getErrorCode(), e.getMessage());
7172
String errorReason = String.format("Statement '%s' failed with SQL state %s and error code %s. For more " +
72-
"details see %s.", query, e.getSQLState(), e.getErrorCode(), DocumentUrlUtil.getSupportedDocumentUrl());
73+
"details see %s.", query, e.getSQLState(), e.getErrorCode(), DocumentUrlUtil.getSupportedDocumentUrl());
7374
throw SnowflakeErrorType.fetchProgramFailureException(e, errorReason, errorMessage);
7475
}
7576
}
7677

78+
/**
 * Returns field descriptors for every column of the specified table, using JDBC
 * {@link DatabaseMetaData#getColumns} metadata rather than executing an import query.
 *
 * @param schemaName The name of schema containing the table
 * @param tableName The name of table whose metadata needs to be retrieved
 * @return list of field descriptors, one per column; empty if no columns match
 * @throws SQLException declared for API symmetry only — SQL failures inside the method are
 *                      caught below and rethrown as an unchecked program-failure exception
 */
public List<SnowflakeFieldDescriptor> describeTable(String schemaName, String tableName) throws SQLException {
  List<SnowflakeFieldDescriptor> fieldDescriptors = new ArrayList<>();
  try (Connection connection = dataSource.getConnection()) {
    DatabaseMetaData dbMetaData = connection.getMetaData();
    // NOTE(review): catalog is null, so every database visible to the connection is searched;
    // Snowflake stores unquoted identifiers upper-cased — confirm callers supply matching case.
    try (ResultSet columns = dbMetaData.getColumns(null, schemaName, tableName, null)) {
      while (columns.next()) {
        String columnName = columns.getString("COLUMN_NAME");
        int columnType = columns.getInt("DATA_TYPE"); // a java.sql.Types constant
        boolean nullable = columns.getInt("NULLABLE") == DatabaseMetaData.columnNullable;
        fieldDescriptors.add(new SnowflakeFieldDescriptor(columnName, columnType, nullable));
      }
    }
  } catch (SQLException e) {
    String errorMessage = String.format(
      "Failed to retrieve table metadata with SQL State %s and error code %s with message: %s.",
      e.getSQLState(), e.getErrorCode(), e.getMessage()
    );
    String errorReason = String.format("Failed to retrieve table metadata with SQL State %s and error " +
      "code %s. For more details %s", e.getSQLState(), e.getErrorCode(),
      DocumentUrlUtil.getSupportedDocumentUrl());
    throw SnowflakeErrorType.fetchProgramFailureException(e, errorReason, errorMessage);
  }
  return fieldDescriptors;
}
110+
77111
/**
78112
* Returns field descriptors for specified import query.
79113
*
@@ -97,9 +131,9 @@ public List<SnowflakeFieldDescriptor> describeQuery(String query) throws IOExcep
97131
}
98132
} catch (SQLException e) {
99133
String errorMessage = String.format("Failed to execute query to fetch descriptors with SQL State %s and error " +
100-
"code %s with message: %s.", e.getSQLState(), e.getErrorCode(), e.getMessage());
134+
"code %s with message: %s.", e.getSQLState(), e.getErrorCode(), e.getMessage());
101135
String errorReason = String.format("Failed to execute query to fetch descriptors with SQL State %s and error " +
102-
"code %s. For more details %s", e.getSQLState(), e.getErrorCode(), DocumentUrlUtil.getSupportedDocumentUrl());
136+
"code %s. For more details %s", e.getSQLState(), e.getErrorCode(), DocumentUrlUtil.getSupportedDocumentUrl());
103137
throw SnowflakeErrorType.fetchProgramFailureException(e, errorReason, errorMessage);
104138
}
105139
return fieldDescriptors;
@@ -157,12 +191,13 @@ public void checkConnection() {
157191
throw new ConnectionTimeoutException("Cannot create Snowflake connection.", e);
158192
} catch (NullPointerException e) {
159193
String errorMessage = String.format("Failed to create Snowflake connection due to missing Username or password " +
160-
"with message: %s.", e.getMessage());
194+
"with message: %s.", e.getMessage());
161195
String errorReason = "Cannot create Snowflake connection. Username or password is missing.";
162196
throw ErrorUtils.getProgramFailureException(new ErrorCategory(ErrorCategory.ErrorCategoryEnum.PLUGIN),
163-
errorReason, errorMessage, ErrorType.USER, true, e);
197+
errorReason, errorMessage, ErrorType.USER, true, e);
164198
}
165199
}
200+
166201
// SnowflakeBasicDataSource doesn't provide access for additional properties.
167202
private void addConnectionArguments(SnowflakeBasicDataSource dataSource, String connectionArguments) {
168203
try {
@@ -193,4 +228,13 @@ private static String writeTextToTmpFile(String text) {
193228
throw new RuntimeException("Cannot write key to temporary file", e);
194229
}
195230
}
231+
232+
/**
 * Retrieves the Snowflake schema (database namespace) name from the configuration.
 * Note: this is the schema that qualifies table names, not the CDAP output schema.
 *
 * @return The schema name
 */
public String getSchema() {
  return config.getSchemaName();
}
196240
}

src/main/java/io/cdap/plugin/snowflake/common/util/QueryUtil.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package io.cdap.plugin.snowflake.common.util;
1818

19+
import com.google.common.base.Strings;
20+
1921
/**
2022
* Transforms import query.
2123
*/
@@ -29,6 +31,9 @@ private QueryUtil() {
2931
}
3032

3133
public static String removeSemicolon(String importQuery) {
34+
if (Strings.isNullOrEmpty(importQuery)) {
35+
return null;
36+
}
3237
if (importQuery.endsWith(";")) {
3338
importQuery = importQuery.substring(0, importQuery.length() - 1);
3439
}

src/main/java/io/cdap/plugin/snowflake/common/util/SchemaHelper.java

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,13 @@
2727
import io.cdap.plugin.snowflake.source.batch.SnowflakeInputFormatProvider;
2828
import io.cdap.plugin.snowflake.source.batch.SnowflakeSourceAccessor;
2929
import java.io.IOException;
30+
import java.sql.SQLException;
3031
import java.sql.Types;
3132
import java.util.List;
3233
import java.util.Map;
3334
import java.util.Objects;
3435
import java.util.stream.Collectors;
36+
import javax.annotation.Nullable;
3537

3638
/**
3739
* Resolves schema.
@@ -58,24 +60,47 @@ public class SchemaHelper {
5860
private SchemaHelper() {
5961
}
6062

63+
/**
64+
* Retrieves schema for the Snowflake batch source based on the given configuration.
65+
*
66+
* @param config The configuration for Snowflake batch source
67+
* @param collector The failure collector to capture any schema retrieval errors.
68+
* @return The resolved schema for Snowflake source
69+
*/
6170
public static Schema getSchema(SnowflakeBatchSourceConfig config, FailureCollector collector) {
6271
if (!config.canConnect()) {
6372
return getParsedSchema(config.getSchema());
6473
}
6574

6675
SnowflakeSourceAccessor snowflakeSourceAccessor =
6776
new SnowflakeSourceAccessor(config, SnowflakeInputFormatProvider.PROPERTY_DEFAULT_ESCAPE_CHAR);
68-
return getSchema(snowflakeSourceAccessor, config.getSchema(), collector, config.getImportQuery());
77+
return getSchema(
78+
snowflakeSourceAccessor,
79+
config.getSchema(),
80+
collector,
81+
config.getTableName(),
82+
config.getImportQuery()
83+
);
6984
}
7085

86+
/**
87+
* Retrieves schema for a Snowflake source based on the provided parameters.
88+
*
89+
* @param snowflakeAccessor The {@link SnowflakeSourceAccessor} used to connect to Snowflake.
90+
* @param schema A JSON-format schema string
91+
* @param collector The {@link FailureCollector} to collect errors if schema retrieval fails.
92+
* @param tableName The name of the table in Snowflake.
93+
* @param importQuery The query to fetch data from Snowflake, used when `tableName` is not provided.
94+
* @return The parsed {@link Schema} if successful, or {@code null} if an error occurs.
95+
*/
7196
public static Schema getSchema(SnowflakeSourceAccessor snowflakeAccessor, String schema,
72-
FailureCollector collector, String importQuery) {
97+
FailureCollector collector, String tableName, String importQuery) {
7398
try {
7499
if (!Strings.isNullOrEmpty(schema)) {
75100
return getParsedSchema(schema);
76101
}
77-
return Strings.isNullOrEmpty(importQuery) ? null : getSchema(snowflakeAccessor, importQuery);
78-
} catch (SchemaParseException e) {
102+
return getSchema(snowflakeAccessor, tableName, importQuery);
103+
} catch (SchemaParseException | IllegalArgumentException e) {
79104
collector.addFailure(String.format("Unable to retrieve output schema. Reason: '%s'", e.getMessage()),
80105
null)
81106
.withStacktrace(e.getStackTrace())
@@ -95,6 +120,25 @@ private static Schema getParsedSchema(String schema) {
95120
}
96121
}
97122

123+
private static Schema getSchema(SnowflakeAccessor snowflakeAccessor, @Nullable String tableName, @Nullable String importQuery) {
124+
try {
125+
List<SnowflakeFieldDescriptor> result;
126+
if (!Strings.isNullOrEmpty(tableName)) {
127+
result = snowflakeAccessor.describeTable(snowflakeAccessor.getSchema(), tableName);
128+
} else {
129+
result = snowflakeAccessor.describeQuery(importQuery);
130+
}
131+
List<Schema.Field> fields = result.stream()
132+
.map(fieldDescriptor -> Schema.Field.of(fieldDescriptor.getName(), getSchema(fieldDescriptor)))
133+
.collect(Collectors.toList());
134+
return Schema.recordOf("data", fields);
135+
} catch (SQLException e) {
136+
throw new SchemaParseException(e);
137+
} catch (IOException e) {
138+
throw new RuntimeException(e);
139+
}
140+
}
141+
98142
public static Schema getSchema(SnowflakeAccessor snowflakeAccessor, String importQuery) {
99143
try {
100144
List<SnowflakeFieldDescriptor> result = snowflakeAccessor.describeQuery(importQuery);

src/main/java/io/cdap/plugin/snowflake/sink/batch/SnowflakeSinkConfig.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ public class SnowflakeSinkConfig extends BaseSnowflakeConfig {
4444
@Name(PROPERTY_TABLE_NAME)
4545
@Description("Name of the table to insert records into.")
4646
@Macro
47+
@Nullable
4748
private String tableName;
4849

4950
@Name(PROPERTY_MAX_FILE_SIZE)

src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfig.java

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616

1717
package io.cdap.plugin.snowflake.source.batch;
1818

19+
import com.google.common.base.Strings;
1920
import io.cdap.cdap.api.annotation.Description;
2021
import io.cdap.cdap.api.annotation.Macro;
2122
import io.cdap.cdap.api.annotation.Name;
2223
import io.cdap.cdap.etl.api.FailureCollector;
2324
import io.cdap.plugin.snowflake.common.BaseSnowflakeConfig;
24-
2525
import java.util.Objects;
2626
import javax.annotation.Nullable;
2727

@@ -34,6 +34,7 @@ public class SnowflakeBatchSourceConfig extends BaseSnowflakeConfig {
3434
public static final String PROPERTY_IMPORT_QUERY = "importQuery";
3535
public static final String PROPERTY_MAX_SPLIT_SIZE = "maxSplitSize";
3636
public static final String PROPERTY_SCHEMA = "schema";
37+
public static final String PROPERTY_TABLE_NAME = "tableName";
3738

3839
@Name(PROPERTY_REFERENCE_NAME)
3940
@Description("This will be used to uniquely identify this source/sink for lineage, annotating metadata, etc.")
@@ -42,6 +43,7 @@ public class SnowflakeBatchSourceConfig extends BaseSnowflakeConfig {
4243
@Name(PROPERTY_IMPORT_QUERY)
4344
@Description("Query for import data.")
4445
@Macro
46+
@Nullable
4547
private String importQuery;
4648

4749
@Name(PROPERTY_MAX_SPLIT_SIZE)
@@ -55,19 +57,29 @@ public class SnowflakeBatchSourceConfig extends BaseSnowflakeConfig {
5557
@Macro
5658
private String schema;
5759

60+
@Name(PROPERTY_TABLE_NAME)
61+
@Nullable
62+
@Description("The name of the table used to retrieve the schema.")
63+
private final String tableName;
64+
5865
/**
 * Constructs a batch source configuration.
 *
 * @param referenceName unique name identifying this source for lineage and metadata
 * @param accountName Snowflake account name
 * @param database database name
 * @param schemaName Snowflake schema (namespace) name
 * @param importQuery query used to import data; may be absent when a table name is given
 * @param username login username
 * @param password login password
 * @param keyPairEnabled whether key-pair authentication is enabled
 * @param path path to the private key file
 * @param passphrase passphrase for the private key
 * @param oauth2Enabled whether OAuth2 authentication is enabled
 * @param clientId OAuth2 client id
 * @param clientSecret OAuth2 client secret
 * @param refreshToken OAuth2 refresh token
 * @param maxSplitSize maximum split size
 * @param connectionArguments additional JDBC connection arguments
 * @param schema output schema as a JSON string
 * @param tableName table used to retrieve the schema; may be absent when an import query is given
 */
public SnowflakeBatchSourceConfig(String referenceName, String accountName, String database,
                                  String schemaName, @Nullable String importQuery, String username, String password,
                                  @Nullable Boolean keyPairEnabled, @Nullable String path,
                                  @Nullable String passphrase, @Nullable Boolean oauth2Enabled,
                                  @Nullable String clientId, @Nullable String clientSecret,
                                  @Nullable String refreshToken, Long maxSplitSize,
                                  @Nullable String connectionArguments,
                                  @Nullable String schema,
                                  @Nullable String tableName) {
  super(accountName, database, schemaName, username, password, keyPairEnabled, path, passphrase,
        oauth2Enabled, clientId, clientSecret, refreshToken, connectionArguments);
  this.referenceName = referenceName;
  this.tableName = tableName;
  this.schema = schema;
  this.importQuery = importQuery;
  this.maxSplitSize = maxSplitSize;
}
7284

7385
public String getImportQuery() {
@@ -87,6 +99,11 @@ public String getSchema() {
8799
return schema;
88100
}
89101

102+
// Table whose columns define the output schema when no import query is given; optional —
// validate() requires at least one of tableName/importQuery to be set.
@Nullable
public String getTableName() {
  return tableName;
}
106+
90107
public void validate(FailureCollector collector) {
91108
super.validate(collector);
92109

@@ -95,5 +112,11 @@ public void validate(FailureCollector collector) {
95112
collector.addFailure("Maximum Slit Size cannot be a negative number.", null)
96113
.withConfigProperty(PROPERTY_MAX_SPLIT_SIZE);
97114
}
115+
116+
if (Strings.isNullOrEmpty(importQuery) && Strings.isNullOrEmpty(tableName)) {
117+
collector.addFailure("Either 'Schema' or 'Table Name' must be provided.", null)
118+
.withConfigProperty(PROPERTY_IMPORT_QUERY)
119+
.withConfigProperty(PROPERTY_TABLE_NAME);
120+
}
98121
}
99-
}
122+
}

src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeSourceAccessor.java

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package io.cdap.plugin.snowflake.source.batch;
1818

1919
import au.com.bytecode.opencsv.CSVReader;
20+
import com.google.common.base.Strings;
2021
import io.cdap.plugin.snowflake.common.SnowflakeErrorType;
2122
import io.cdap.plugin.snowflake.common.client.SnowflakeAccessor;
2223
import io.cdap.plugin.snowflake.common.util.DocumentUrlUtil;
@@ -77,7 +78,12 @@ public SnowflakeSourceAccessor(SnowflakeBatchSourceConfig config, String escapeC
7778
*/
7879
public List<String> prepareStageSplits() {
7980
LOG.info("Loading data into stage: '{}'", STAGE_PATH);
80-
String copy = String.format(COMAND_COPY_INTO, QueryUtil.removeSemicolon(config.getImportQuery()));
81+
String importQuery = config.getImportQuery();
82+
if (Strings.isNullOrEmpty(importQuery)) {
83+
String tableName = config.getTableName();
84+
importQuery = String.format("SELECT * FROM %s", tableName);
85+
}
86+
String copy = String.format(COMAND_COPY_INTO, QueryUtil.removeSemicolon(importQuery));
8187
if (config.getMaxSplitSize() > 0) {
8288
copy = copy + String.format(COMMAND_MAX_FILE_SIZE, config.getMaxSplitSize());
8389
}
@@ -94,10 +100,13 @@ public List<String> prepareStageSplits() {
94100
}
95101
} catch (SQLException e) {
96102
String errorReason = String.format("Failed to load data into stage '%s' with sqlState %s and errorCode %s. " +
97-
"For more details, see %s.", STAGE_PATH, e.getErrorCode(), e.getSQLState(),
98-
DocumentUrlUtil.getSupportedDocumentUrl());
99-
String errorMessage = String.format("Failed to load data into stage '%s' with sqlState %s and errorCode %s. " +
100-
"Failed to execute query with message: %s.", STAGE_PATH, e.getSQLState(), e.getErrorCode(), e.getMessage());
103+
"For more details, see %s.", STAGE_PATH, e.getErrorCode(), e.getSQLState(),
104+
DocumentUrlUtil.getSupportedDocumentUrl());
105+
String errorMessage = String.format(
106+
"Failed to load data into stage '%s' with sqlState %s and errorCode %s. "
107+
+ "Failed to execute query with message: %s.",
108+
STAGE_PATH, e.getSQLState(), e.getErrorCode(), e.getMessage()
109+
);
101110
throw SnowflakeErrorType.fetchProgramFailureException(e, errorReason, errorMessage);
102111
}
103112
return stageSplits;
@@ -126,11 +135,11 @@ public CSVReader buildCsvReader(String stageSplit) {
126135
return new CSVReader(inputStreamReader, ',', '"', escapeChar);
127136
} catch (SQLException e) {
128137
String errorReason = String.format("Failed to execute the query with sqlState: '%s' & errorCode: '%s'. " +
129-
"For more details, see %s.", e.getSQLState(), e.getErrorCode(), DocumentUrlUtil.getSupportedDocumentUrl());
138+
"For more details, see %s.", e.getSQLState(), e.getErrorCode(), DocumentUrlUtil.getSupportedDocumentUrl());
130139
String errorMessage = String.format("Failed to execute the query with sqlState: '%s' & errorCode: '%s' " +
131-
"with message: %s, stage split at %s.", e.getSQLState(), e.getErrorCode(),
132-
e.getMessage(), stageSplit);
140+
"with message: %s, stage split at %s.", e.getSQLState(), e.getErrorCode(),
141+
e.getMessage(), stageSplit);
133142
throw SnowflakeErrorType.fetchProgramFailureException(e, errorReason, errorMessage);
134143
}
135144
}
136-
}
145+
}

0 commit comments

Comments
 (0)