Skip to content

Commit 89bb5df

Browse files
committed
Get schema without import query
1 parent 2dc3fe5 commit 89bb5df

File tree

12 files changed

+216
-20
lines changed

12 files changed

+216
-20
lines changed

docs/Snowflake-batchsource.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ log in to Snowflake, minus the "snowflakecomputing.com"). E.g. "myaccount.us-cen
2525

2626
**Role:** Role to use (e.g. `ACCOUNTADMIN`).
2727

28+
**Table Name:** The name of the table from which to retrieve the schema.
29+
2830
**Import Query:** Query for data import.
2931

3032
### Credentials

src/main/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessor.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import java.io.IOException;
3838
import java.lang.reflect.Field;
3939
import java.sql.Connection;
40+
import java.sql.DatabaseMetaData;
4041
import java.sql.PreparedStatement;
4142
import java.sql.ResultSet;
4243
import java.sql.ResultSetMetaData;
@@ -74,6 +75,39 @@ public void runSQL(String query) {
7475
}
7576
}
7677

78+
/**
79+
* Returns field descriptors for specified table name
80+
*
81+
* @param schemaName The name of schema containing the table
82+
* @param tableName The name of table whose metadata needs to be retrieved
83+
* @return list of field descriptors
84+
* @throws SQLException If an error occurs while retrieving metadata from the database
85+
*/
86+
public List<SnowflakeFieldDescriptor> describeTable(String schemaName, String tableName) throws SQLException {
87+
List<SnowflakeFieldDescriptor> fieldDescriptors = new ArrayList<>();
88+
try (Connection connection = dataSource.getConnection()) {
89+
DatabaseMetaData dbMetaData = connection.getMetaData();
90+
try (ResultSet columns = dbMetaData.getColumns(null, schemaName, tableName, null)) {
91+
while (columns.next()) {
92+
String columnName = columns.getString("COLUMN_NAME");
93+
int columnType = columns.getInt("DATA_TYPE");
94+
boolean nullable = columns.getInt("NULLABLE") == DatabaseMetaData.columnNullable;
95+
fieldDescriptors.add(new SnowflakeFieldDescriptor(columnName, columnType, nullable));
96+
}
97+
}
98+
} catch (SQLException e) {
99+
String errorMessage = String.format(
100+
"Failed to retrieve table metadata with SQL State %s and error code %s with message: %s.",
101+
e.getSQLState(), e.getErrorCode(), e.getMessage()
102+
);
103+
String errorReason = String.format("Failed to retrieve table metadata with SQL State %s and error " +
104+
"code %s. For more details %s", e.getSQLState(), e.getErrorCode(),
105+
DocumentUrlUtil.getSupportedDocumentUrl());
106+
throw SnowflakeErrorType.fetchProgramFailureException(e, errorReason, errorMessage);
107+
}
108+
return fieldDescriptors;
109+
}
110+
77111
/**
78112
* Returns field descriptors for specified import query.
79113
*
@@ -163,6 +197,7 @@ public void checkConnection() {
163197
errorReason, errorMessage, ErrorType.USER, true, e);
164198
}
165199
}
200+
166201
// SnowflakeBasicDataSource doesn't provide access for additional properties.
167202
private void addConnectionArguments(SnowflakeBasicDataSource dataSource, String connectionArguments) {
168203
try {
@@ -193,4 +228,13 @@ private static String writeTextToTmpFile(String text) {
193228
throw new RuntimeException("Cannot write key to temporary file", e);
194229
}
195230
}
231+
232+
/**
233+
* Retrieves schema name from the configuration
234+
*
235+
* @return The schema name
236+
*/
237+
public String getSchema() {
238+
return config.getSchemaName();
239+
}
196240
}

src/main/java/io/cdap/plugin/snowflake/common/util/QueryUtil.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package io.cdap.plugin.snowflake.common.util;
1818

19+
import com.google.common.base.Strings;
20+
1921
/**
2022
* Transforms import query.
2123
*/
@@ -29,6 +31,9 @@ private QueryUtil() {
2931
}
3032

3133
public static String removeSemicolon(String importQuery) {
34+
if (Strings.isNullOrEmpty(importQuery)) {
35+
return null;
36+
}
3237
if (importQuery.endsWith(";")) {
3338
importQuery = importQuery.substring(0, importQuery.length() - 1);
3439
}

src/main/java/io/cdap/plugin/snowflake/common/util/SchemaHelper.java

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,13 @@
2727
import io.cdap.plugin.snowflake.source.batch.SnowflakeInputFormatProvider;
2828
import io.cdap.plugin.snowflake.source.batch.SnowflakeSourceAccessor;
2929
import java.io.IOException;
30+
import java.sql.SQLException;
3031
import java.sql.Types;
3132
import java.util.List;
3233
import java.util.Map;
3334
import java.util.Objects;
3435
import java.util.stream.Collectors;
36+
import javax.annotation.Nullable;
3537

3638
/**
3739
* Resolves schema.
@@ -58,24 +60,47 @@ public class SchemaHelper {
5860
private SchemaHelper() {
5961
}
6062

63+
/**
64+
* Retrieves schema for the Snowflake batch source based on the given configuration.
65+
*
66+
* @param config The configuration for Snowflake batch source
67+
* @param collector The failure collector to capture any schema retrieval errors.
68+
* @return The resolved schema for Snowflake source
69+
*/
6170
public static Schema getSchema(SnowflakeBatchSourceConfig config, FailureCollector collector) {
6271
if (!config.canConnect()) {
6372
return getParsedSchema(config.getSchema());
6473
}
6574

6675
SnowflakeSourceAccessor snowflakeSourceAccessor =
6776
new SnowflakeSourceAccessor(config, SnowflakeInputFormatProvider.PROPERTY_DEFAULT_ESCAPE_CHAR);
68-
return getSchema(snowflakeSourceAccessor, config.getSchema(), collector, config.getImportQuery());
77+
return getSchema(
78+
snowflakeSourceAccessor,
79+
config.getSchema(),
80+
collector,
81+
config.getTableName(),
82+
config.getImportQuery()
83+
);
6984
}
7085

86+
/**
87+
* Retrieves schema for a Snowflake source based on the provided parameters.
88+
*
89+
* @param snowflakeAccessor The {@link SnowflakeSourceAccessor} used to connect to Snowflake.
90+
* @param schema A JSON-format schema string
91+
* @param collector The {@link FailureCollector} to collect errors if schema retrieval fails.
92+
* @param tableName The name of the table in Snowflake.
93+
* @param importQuery The query to fetch data from Snowflake, used when `tableName` is not provided.
94+
* @return The parsed {@link Schema} if successful, or {@code null} if an error occurs.
95+
*/
7196
public static Schema getSchema(SnowflakeSourceAccessor snowflakeAccessor, String schema,
72-
FailureCollector collector, String importQuery) {
97+
FailureCollector collector, String tableName, String importQuery) {
7398
try {
7499
if (!Strings.isNullOrEmpty(schema)) {
75100
return getParsedSchema(schema);
76101
}
77-
return Strings.isNullOrEmpty(importQuery) ? null : getSchema(snowflakeAccessor, importQuery);
78-
} catch (SchemaParseException e) {
102+
return getSchema(snowflakeAccessor, tableName, importQuery);
103+
} catch (SchemaParseException | IllegalArgumentException e) {
79104
collector.addFailure(String.format("Unable to retrieve output schema. Reason: '%s'", e.getMessage()),
80105
null)
81106
.withStacktrace(e.getStackTrace())
@@ -95,6 +120,25 @@ private static Schema getParsedSchema(String schema) {
95120
}
96121
}
97122

123+
private static Schema getSchema(SnowflakeAccessor snowflakeAccessor, @Nullable String tableName, @Nullable String importQuery) {
124+
try {
125+
List<SnowflakeFieldDescriptor> result;
126+
if (!Strings.isNullOrEmpty(tableName)) {
127+
result = snowflakeAccessor.describeTable(snowflakeAccessor.getSchema(), tableName);
128+
} else {
129+
result = snowflakeAccessor.describeQuery(importQuery);
130+
}
131+
List<Schema.Field> fields = result.stream()
132+
.map(fieldDescriptor -> Schema.Field.of(fieldDescriptor.getName(), getSchema(fieldDescriptor)))
133+
.collect(Collectors.toList());
134+
return Schema.recordOf("data", fields);
135+
} catch (SQLException e) {
136+
throw new SchemaParseException(e);
137+
} catch (IOException e) {
138+
throw new RuntimeException(e);
139+
}
140+
}
141+
98142
public static Schema getSchema(SnowflakeAccessor snowflakeAccessor, String importQuery) {
99143
try {
100144
List<SnowflakeFieldDescriptor> result = snowflakeAccessor.describeQuery(importQuery);

src/main/java/io/cdap/plugin/snowflake/sink/batch/SnowflakeSinkConfig.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ public class SnowflakeSinkConfig extends BaseSnowflakeConfig {
4444
@Name(PROPERTY_TABLE_NAME)
4545
@Description("Name of the table to insert records into.")
4646
@Macro
47+
@Nullable
4748
private String tableName;
4849

4950
@Name(PROPERTY_MAX_FILE_SIZE)

src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeBatchSourceConfig.java

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616

1717
package io.cdap.plugin.snowflake.source.batch;
1818

19+
import com.google.common.base.Strings;
1920
import io.cdap.cdap.api.annotation.Description;
2021
import io.cdap.cdap.api.annotation.Macro;
2122
import io.cdap.cdap.api.annotation.Name;
2223
import io.cdap.cdap.etl.api.FailureCollector;
2324
import io.cdap.plugin.snowflake.common.BaseSnowflakeConfig;
24-
2525
import java.util.Objects;
2626
import javax.annotation.Nullable;
2727

@@ -34,6 +34,7 @@ public class SnowflakeBatchSourceConfig extends BaseSnowflakeConfig {
3434
public static final String PROPERTY_IMPORT_QUERY = "importQuery";
3535
public static final String PROPERTY_MAX_SPLIT_SIZE = "maxSplitSize";
3636
public static final String PROPERTY_SCHEMA = "schema";
37+
public static final String PROPERTY_TABLE_NAME = "tableName";
3738

3839
@Name(PROPERTY_REFERENCE_NAME)
3940
@Description("This will be used to uniquely identify this source/sink for lineage, annotating metadata, etc.")
@@ -42,6 +43,7 @@ public class SnowflakeBatchSourceConfig extends BaseSnowflakeConfig {
4243
@Name(PROPERTY_IMPORT_QUERY)
4344
@Description("Query for import data.")
4445
@Macro
46+
@Nullable
4547
private String importQuery;
4648

4749
@Name(PROPERTY_MAX_SPLIT_SIZE)
@@ -55,19 +57,29 @@ public class SnowflakeBatchSourceConfig extends BaseSnowflakeConfig {
5557
@Macro
5658
private String schema;
5759

60+
@Name(PROPERTY_TABLE_NAME)
61+
@Nullable
62+
@Description("The name of the table used to retrieve the schema.")
63+
private final String tableName;
64+
5865
public SnowflakeBatchSourceConfig(String referenceName, String accountName, String database,
59-
String schemaName, String importQuery, String username, String password,
66+
String schemaName, @Nullable String importQuery, String username, String password,
6067
@Nullable Boolean keyPairEnabled, @Nullable String path,
6168
@Nullable String passphrase, @Nullable Boolean oauth2Enabled,
6269
@Nullable String clientId, @Nullable String clientSecret,
6370
@Nullable String refreshToken, Long maxSplitSize,
64-
@Nullable String connectionArguments, @Nullable String schema) {
65-
super(accountName, database, schemaName, username, password,
66-
keyPairEnabled, path, passphrase, oauth2Enabled, clientId, clientSecret, refreshToken, connectionArguments);
71+
@Nullable String connectionArguments,
72+
@Nullable String schema,
73+
@Nullable String tableName) {
74+
super(
75+
accountName, database, schemaName, username, password, keyPairEnabled, path, passphrase,
76+
oauth2Enabled, clientId, clientSecret, refreshToken, connectionArguments
77+
);
6778
this.referenceName = referenceName;
6879
this.importQuery = importQuery;
6980
this.maxSplitSize = maxSplitSize;
7081
this.schema = schema;
82+
this.tableName = tableName;
7183
}
7284

7385
public String getImportQuery() {
@@ -87,6 +99,11 @@ public String getSchema() {
8799
return schema;
88100
}
89101

102+
@Nullable
103+
public String getTableName() {
104+
return tableName;
105+
}
106+
90107
public void validate(FailureCollector collector) {
91108
super.validate(collector);
92109

@@ -95,5 +112,11 @@ public void validate(FailureCollector collector) {
95112
collector.addFailure("Maximum Slit Size cannot be a negative number.", null)
96113
.withConfigProperty(PROPERTY_MAX_SPLIT_SIZE);
97114
}
115+
116+
if (Strings.isNullOrEmpty(importQuery) && Strings.isNullOrEmpty(tableName)) {
117+
collector.addFailure("Either 'Schema' or 'Table Name' must be provided.", null)
118+
.withConfigProperty(PROPERTY_IMPORT_QUERY)
119+
.withConfigProperty(PROPERTY_TABLE_NAME);
120+
}
98121
}
99122
}

src/main/java/io/cdap/plugin/snowflake/source/batch/SnowflakeSourceAccessor.java

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package io.cdap.plugin.snowflake.source.batch;
1818

1919
import au.com.bytecode.opencsv.CSVReader;
20+
import com.google.common.base.Strings;
2021
import io.cdap.plugin.snowflake.common.SnowflakeErrorType;
2122
import io.cdap.plugin.snowflake.common.client.SnowflakeAccessor;
2223
import io.cdap.plugin.snowflake.common.util.DocumentUrlUtil;
@@ -77,7 +78,12 @@ public SnowflakeSourceAccessor(SnowflakeBatchSourceConfig config, String escapeC
7778
*/
7879
public List<String> prepareStageSplits() {
7980
LOG.info("Loading data into stage: '{}'", STAGE_PATH);
80-
String copy = String.format(COMAND_COPY_INTO, QueryUtil.removeSemicolon(config.getImportQuery()));
81+
String importQuery = config.getImportQuery();
82+
if (Strings.isNullOrEmpty(importQuery)) {
83+
String tableName = config.getTableName();
84+
importQuery = String.format("SELECT * FROM %s", tableName);
85+
}
86+
String copy = String.format(COMAND_COPY_INTO, QueryUtil.removeSemicolon(importQuery));
8187
if (config.getMaxSplitSize() > 0) {
8288
copy = copy + String.format(COMMAND_MAX_FILE_SIZE, config.getMaxSplitSize());
8389
}
@@ -94,10 +100,13 @@ public List<String> prepareStageSplits() {
94100
}
95101
} catch (SQLException e) {
96102
String errorReason = String.format("Failed to load data into stage '%s' with sqlState %s and errorCode %s. " +
97-
"For more details, see %s.", STAGE_PATH, e.getErrorCode(), e.getSQLState(),
98-
DocumentUrlUtil.getSupportedDocumentUrl());
99-
String errorMessage = String.format("Failed to load data into stage '%s' with sqlState %s and errorCode %s. " +
100-
"Failed to execute query with message: %s.", STAGE_PATH, e.getSQLState(), e.getErrorCode(), e.getMessage());
103+
"For more details, see %s.", STAGE_PATH, e.getErrorCode(), e.getSQLState(),
104+
DocumentUrlUtil.getSupportedDocumentUrl());
105+
String errorMessage = String.format(
106+
"Failed to load data into stage '%s' with sqlState %s and errorCode %s. "
107+
+ "Failed to execute query with message: %s.",
108+
STAGE_PATH, e.getSQLState(), e.getErrorCode(), e.getMessage()
109+
);
101110
throw SnowflakeErrorType.fetchProgramFailureException(e, errorReason, errorMessage);
102111
}
103112
return stageSplits;

src/test/java/io/cdap/plugin/snowflake/common/client/SnowflakeAccessorTest.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,41 @@ public void testDescribeQuery() throws Exception {
9090
Assert.assertEquals(expected, actual);
9191
}
9292

93+
@Test
94+
public void testDescribeTable() throws Exception {
95+
String schemaName = "TEST_SCHEMA";
96+
String tableName = "TEST_TABLE";
97+
List<SnowflakeFieldDescriptor> expected = Arrays.asList(
98+
new SnowflakeFieldDescriptor("COLUMN_NUMBER", -5, true),
99+
new SnowflakeFieldDescriptor("COLUMN_DECIMAL", -5, true),
100+
new SnowflakeFieldDescriptor("COLUMN_NUMERIC", -5, true),
101+
new SnowflakeFieldDescriptor("COLUMN_INT", -5, true),
102+
new SnowflakeFieldDescriptor("COLUMN_INTEGER", -5, true),
103+
new SnowflakeFieldDescriptor("COLUMN_BIGINT", -5, true),
104+
new SnowflakeFieldDescriptor("COLUMN_SMALLINT", -5, true),
105+
new SnowflakeFieldDescriptor("COLUMN_FLOAT", 8, true),
106+
new SnowflakeFieldDescriptor("COLUMN_DOUBLE", 8, true),
107+
new SnowflakeFieldDescriptor("COLUMN_REAL", 8, true),
108+
new SnowflakeFieldDescriptor("COLUMN_VARCHAR", 12, true),
109+
new SnowflakeFieldDescriptor("COLUMN_CHAR", 12, true),
110+
new SnowflakeFieldDescriptor("COLUMN_TEXT", 12, true),
111+
new SnowflakeFieldDescriptor("COLUMN_BINARY", -2, true),
112+
new SnowflakeFieldDescriptor("COLUMN_BOOLEAN", 16, true),
113+
new SnowflakeFieldDescriptor("COLUMN_DATE", 91, true),
114+
new SnowflakeFieldDescriptor("COLUMN_TIMESTAMP", 93, true),
115+
new SnowflakeFieldDescriptor("COLUMN_VARIANT", 12, true),
116+
new SnowflakeFieldDescriptor("COLUMN_OBJECT", 12, true),
117+
new SnowflakeFieldDescriptor("COLUMN_ARRAY", 12, true)
118+
);
119+
120+
List<SnowflakeFieldDescriptor> actual = snowflakeAccessor.describeTable(String.valueOf(Constants.TEST_TABLE_SCHEMA),
121+
Constants.TEST_TABLE);
122+
123+
Assert.assertNotNull(actual);
124+
Assert.assertFalse(actual.isEmpty());
125+
Assert.assertEquals(expected, actual);
126+
}
127+
93128
@Test
94129
public void testPrepareStageSplits() throws Exception {
95130
Pattern expected = Pattern.compile("cdap_stage/result.*data__0_0_0\\.csv\\.gz");

0 commit comments

Comments
 (0)