Skip to content

Commit 756fb99

Browse files
Error management for Snowflake source and sink, Added new validation for maximum split size and NPE issue handled
1 parent 1d9cacd commit 756fb99

16 files changed

+495
-80
lines changed

src/main/java/io/cdap/plugin/snowflake/common/BaseSnowflakeConfig.java

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -230,33 +230,10 @@ public String getConnectionArguments() {
230230
}
231231

232232
public void validate(FailureCollector collector) {
233-
if (getOauth2Enabled()) {
234-
if (!containsMacro(PROPERTY_CLIENT_ID)
235-
&& Strings.isNullOrEmpty(getClientId())) {
236-
collector.addFailure("Client ID is not set.", null)
237-
.withConfigProperty(PROPERTY_CLIENT_ID);
238-
}
239-
if (!containsMacro(PROPERTY_CLIENT_SECRET)
240-
&& Strings.isNullOrEmpty(getClientSecret())) {
241-
collector.addFailure("Client Secret is not set.", null)
242-
.withConfigProperty(PROPERTY_CLIENT_SECRET);
243-
}
244-
if (!containsMacro(PROPERTY_REFRESH_TOKEN)
245-
&& Strings.isNullOrEmpty(getRefreshToken())) {
246-
collector.addFailure("Refresh Token is not set.", null)
247-
.withConfigProperty(PROPERTY_REFRESH_TOKEN);
248-
}
249-
} else if (getKeyPairEnabled()) {
250-
if (!containsMacro(PROPERTY_USERNAME)
251-
&& Strings.isNullOrEmpty(getUsername())) {
252-
collector.addFailure("Username is not set.", null)
253-
.withConfigProperty(PROPERTY_USERNAME);
254-
}
255-
if (!containsMacro(PROPERTY_PRIVATE_KEY)
256-
&& Strings.isNullOrEmpty(getPrivateKey())) {
257-
collector.addFailure("Private Key is not set.", null)
258-
.withConfigProperty(PROPERTY_PRIVATE_KEY);
259-
}
233+
if (Boolean.TRUE.equals(getOauth2Enabled())) {
234+
validateWhenOath2Enabled(collector);
235+
} else if (Boolean.TRUE.equals(getKeyPairEnabled())) {
236+
validateWhenKeyPairEnabled(collector);
260237
} else {
261238
if (!containsMacro(PROPERTY_USERNAME)
262239
&& Strings.isNullOrEmpty(getUsername())) {
@@ -272,6 +249,37 @@ public void validate(FailureCollector collector) {
272249
validateConnection(collector);
273250
}
274251

252+
private void validateWhenKeyPairEnabled(FailureCollector collector) {
253+
if (!containsMacro(PROPERTY_USERNAME)
254+
&& Strings.isNullOrEmpty(getUsername())) {
255+
collector.addFailure("Username is not set.", null)
256+
.withConfigProperty(PROPERTY_USERNAME);
257+
}
258+
if (!containsMacro(PROPERTY_PRIVATE_KEY)
259+
&& Strings.isNullOrEmpty(getPrivateKey())) {
260+
collector.addFailure("Private Key is not set.", null)
261+
.withConfigProperty(PROPERTY_PRIVATE_KEY);
262+
}
263+
}
264+
265+
private void validateWhenOath2Enabled(FailureCollector collector) {
266+
if (!containsMacro(PROPERTY_CLIENT_ID)
267+
&& Strings.isNullOrEmpty(getClientId())) {
268+
collector.addFailure("Client ID is not set.", null)
269+
.withConfigProperty(PROPERTY_CLIENT_ID);
270+
}
271+
if (!containsMacro(PROPERTY_CLIENT_SECRET)
272+
&& Strings.isNullOrEmpty(getClientSecret())) {
273+
collector.addFailure("Client Secret is not set.", null)
274+
.withConfigProperty(PROPERTY_CLIENT_SECRET);
275+
}
276+
if (!containsMacro(PROPERTY_REFRESH_TOKEN)
277+
&& Strings.isNullOrEmpty(getRefreshToken())) {
278+
collector.addFailure("Refresh Token is not set.", null)
279+
.withConfigProperty(PROPERTY_REFRESH_TOKEN);
280+
}
281+
}
282+
275283
public boolean canConnect() {
276284
return (!containsMacro(PROPERTY_DATABASE) && !containsMacro(PROPERTY_SCHEMA_NAME)
277285
&& !containsMacro(PROPERTY_ACCOUNT_NAME) && !containsMacro(PROPERTY_USERNAME)
@@ -299,7 +307,7 @@ protected void validateConnection(FailureCollector collector) {
299307
.withConfigProperty(PROPERTY_USERNAME);
300308

301309
// TODO: for oauth2
302-
if (keyPairEnabled) {
310+
if (Boolean.TRUE.equals(keyPairEnabled)) {
303311
failure.withConfigProperty(PROPERTY_PRIVATE_KEY);
304312
} else {
305313
failure.withConfigProperty(PROPERTY_PASSWORD);

src/main/java/io/cdap/plugin/snowflake/common/OAuthUtil.java

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,27 @@
1919
import com.google.gson.JsonElement;
2020
import com.google.gson.JsonParser;
2121
import com.google.gson.JsonSyntaxException;
22+
import io.cdap.cdap.api.exception.ErrorCategory;
23+
import io.cdap.cdap.api.exception.ErrorType;
24+
import io.cdap.cdap.api.exception.ErrorUtils;
25+
import io.cdap.cdap.api.exception.ProgramFailureException;
26+
import io.cdap.cdap.etl.api.exception.ErrorPhase;
2227
import io.cdap.plugin.snowflake.common.exception.ConnectionTimeoutException;
28+
import io.cdap.plugin.snowflake.common.exception.SchemaParseException;
2329
import org.apache.http.client.methods.CloseableHttpResponse;
2430
import org.apache.http.client.methods.HttpPost;
2531
import org.apache.http.client.utils.URIBuilder;
2632
import org.apache.http.entity.StringEntity;
2733
import org.apache.http.impl.client.CloseableHttpClient;
2834
import org.apache.http.util.EntityUtils;
35+
import scala.xml.Null;
36+
2937
import java.io.IOException;
3038
import java.net.URI;
3139
import java.net.URISyntaxException;
3240
import java.net.URLEncoder;
3341
import java.util.Base64;
42+
import java.util.Objects;
3443

3544
/**
3645
* A class which contains utilities to make OAuth2 specific calls.
@@ -50,9 +59,15 @@ public static String getAccessTokenByRefreshToken(CloseableHttpClient httpclient
5059
httppost.setHeader("Content-type", "application/x-www-form-urlencoded");
5160

5261
// set grant type and refresh_token. It should be in body not url!
53-
StringEntity entity = new StringEntity(String.format("refresh_token=%s&grant_type=refresh_token",
54-
URLEncoder.encode(config.getRefreshToken(), "UTF-8")));
55-
httppost.setEntity(entity);
62+
try {
63+
StringEntity entity = new StringEntity(String.format("refresh_token=%s&grant_type=refresh_token",
64+
URLEncoder.encode(Objects.requireNonNull(Objects.requireNonNull(config).getRefreshToken()), "UTF-8")));
65+
httppost.setEntity(entity);
66+
} catch (NullPointerException e) {
67+
String errorMessage = "Error encoding URL due to missing Refresh Token.";
68+
throw ErrorUtils.getProgramFailureException(new ErrorCategory(ErrorCategory.ErrorCategoryEnum.PLUGIN),
69+
errorMessage, String.format("Error message: %s", errorMessage), ErrorType.SYSTEM, true, e);
70+
}
5671

5772
// set 'Authorization' header
5873
String stringToEncode = config.getClientId() + ":" + config.getClientSecret();
@@ -72,7 +87,7 @@ public static String getAccessTokenByRefreshToken(CloseableHttpClient httpclient
7287

7388
// if exception happened during parsing OR if json does not contain 'access_token' key.
7489
if (jsonElement == null) {
75-
throw new RuntimeException(String.format("Unexpected response '%s' from '%s'", responseString, uri.toString()));
90+
throw new RuntimeException(String.format("Unexpected response '%s' from '%s'", responseString, uri));
7691
}
7792

7893
return jsonElement.getAsString();
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
/*
2+
* Copyright © 2024 Cask Data, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
5+
* use this file except in compliance with the License. You may obtain a copy of
6+
* the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13+
* License for the specific language governing permissions and limitations under
14+
* the License.
15+
*/
16+
17+
package io.cdap.plugin.snowflake.common;
18+
19+
import com.google.common.base.Throwables;
20+
import io.cdap.cdap.api.data.format.UnexpectedFormatException;
21+
import io.cdap.cdap.api.exception.ErrorCategory;
22+
import io.cdap.cdap.api.exception.ErrorCodeType;
23+
import io.cdap.cdap.api.exception.ErrorType;
24+
import io.cdap.cdap.api.exception.ErrorUtils;
25+
import io.cdap.cdap.api.exception.ProgramFailureException;
26+
import io.cdap.cdap.etl.api.exception.ErrorContext;
27+
import io.cdap.cdap.etl.api.exception.ErrorDetailsProvider;
28+
import io.cdap.plugin.snowflake.common.exception.ConnectionTimeoutException;
29+
import io.cdap.plugin.snowflake.common.exception.SchemaParseException;
30+
31+
import java.net.URISyntaxException;
32+
import java.util.List;
33+
34+
35+
/**
36+
* Error details provided for the Snowflake
37+
**/
38+
public class SnowflakeErrorDetailsProvider implements ErrorDetailsProvider {
39+
40+
@Override
41+
public ProgramFailureException getExceptionDetails(Exception e, ErrorContext errorContext) {
42+
List<Throwable> causalChain = Throwables.getCausalChain(e);
43+
for (Throwable t : causalChain) {
44+
if (t instanceof ProgramFailureException) {
45+
// if causal chain already has program failure exception, return null to avoid double wrap.
46+
return null;
47+
}
48+
if (t instanceof IllegalArgumentException) {
49+
return getProgramFailureException((IllegalArgumentException) t, errorContext);
50+
}
51+
if (t instanceof IllegalStateException) {
52+
return getProgramFailureException((IllegalStateException) t, errorContext);
53+
}
54+
if (t instanceof URISyntaxException) {
55+
return getProgramFailureException((URISyntaxException) t, errorContext);
56+
}
57+
if (t instanceof SchemaParseException) {
58+
return getProgramFailureException((SchemaParseException) t, errorContext);
59+
}
60+
if (t instanceof UnexpectedFormatException) {
61+
return getProgramFailureException((UnexpectedFormatException) t, errorContext);
62+
}
63+
if (t instanceof ConnectionTimeoutException) {
64+
return getProgramFailureException((ConnectionTimeoutException) t, errorContext);
65+
}
66+
}
67+
return null;
68+
}
69+
70+
/**
71+
* Get a ProgramFailureException with the given error
72+
* information from {@link IllegalArgumentException}.
73+
*
74+
* @param e The IllegalArgumentException to get the error information from.
75+
* @return A ProgramFailureException with the given error information.
76+
*/
77+
private ProgramFailureException getProgramFailureException(IllegalArgumentException e, ErrorContext errorContext) {
78+
String errorMessage = e.getMessage();
79+
String errorMessageFormat = "Error occurred in the phase: '%s'. Error message: %s";
80+
81+
return ErrorUtils.getProgramFailureException(new ErrorCategory(ErrorCategory.ErrorCategoryEnum.PLUGIN),
82+
errorMessage,
83+
String.format(errorMessageFormat, errorContext.getPhase(), errorMessage), ErrorType.USER, false, e);
84+
}
85+
86+
/**
87+
* Get a ProgramFailureException with the given error
88+
* information from {@link IllegalStateException}.
89+
*
90+
* @param e The IllegalStateException to get the error information from.
91+
* @return A ProgramFailureException with the given error information.
92+
*/
93+
private ProgramFailureException getProgramFailureException(IllegalStateException e, ErrorContext errorContext) {
94+
String errorMessage = e.getMessage();
95+
String errorMessageFormat = "Error occurred in the phase: '%s'. Error message: %s";
96+
return ErrorUtils.getProgramFailureException(new ErrorCategory(ErrorCategory.ErrorCategoryEnum.PLUGIN),
97+
errorMessage,
98+
String.format(errorMessageFormat, errorContext.getPhase(), errorMessage), ErrorType.SYSTEM, false, e);
99+
}
100+
101+
/**
102+
* Get a ProgramFailureException with the given error
103+
* information from {@link URISyntaxException}.
104+
*
105+
* @param e The URISyntaxException to get the error information from.
106+
* @return A ProgramFailureException with the given error information.
107+
*/
108+
private ProgramFailureException getProgramFailureException(URISyntaxException e,
109+
ErrorContext errorContext) {
110+
String errorMessage = e.getMessage();
111+
String errorMessageFormat = "Error occurred in the phase: '%s'. Error message: %s";
112+
return ErrorUtils.getProgramFailureException(new ErrorCategory(ErrorCategory.ErrorCategoryEnum.PLUGIN),
113+
errorMessage,
114+
String.format(errorMessageFormat, errorContext.getPhase(), errorMessage), ErrorType.USER, false, e);
115+
}
116+
117+
/**
118+
* Get a ProgramFailureException with the given error
119+
* information from {@link SchemaParseException}.
120+
*
121+
* @param e The SchemaParseException to get the error information from.
122+
* @return A ProgramFailureException with the given error information.
123+
*/
124+
private ProgramFailureException getProgramFailureException(SchemaParseException e, ErrorContext errorContext) {
125+
String errorMessage = e.getMessage();
126+
String errorMessageFormat = "Error occurred in the phase: '%s'. Error message: %s";
127+
return ErrorUtils.getProgramFailureException(new ErrorCategory(ErrorCategory.ErrorCategoryEnum.PLUGIN),
128+
errorMessage,
129+
String.format(errorMessageFormat, errorContext.getPhase(), errorMessage), ErrorType.USER, false, e);
130+
}
131+
132+
/**
133+
* Get a ProgramFailureException with the given error
134+
* information from {@link UnexpectedFormatException}.
135+
*
136+
* @param e The UnexpectedFormatException to get the error information from.
137+
* @return A ProgramFailureException with the given error information.
138+
*/
139+
private ProgramFailureException getProgramFailureException(UnexpectedFormatException e, ErrorContext errorContext) {
140+
String errorMessage = e.getMessage();
141+
String errorMessageFormat = "Error occurred in the phase: '%s'. Error message: %s";
142+
return ErrorUtils.getProgramFailureException(new ErrorCategory(ErrorCategory.ErrorCategoryEnum.PLUGIN),
143+
errorMessage,
144+
String.format(errorMessageFormat, errorContext.getPhase(), errorMessage), ErrorType.USER, false, e);
145+
}
146+
147+
/**
148+
* Get a ProgramFailureException with the given error
149+
* information from {@link ConnectionTimeoutException}.
150+
*
151+
* @param e The ConnectionTimeoutException to get the error information from.
152+
* @return A ProgramFailureException with the given error information.
153+
*/
154+
private ProgramFailureException getProgramFailureException(ConnectionTimeoutException e, ErrorContext errorContext) {
155+
String errorMessage = e.getMessage();
156+
String errorMessageFormat = "Error occurred in the phase: '%s'. Error message: %s";
157+
return ErrorUtils.getProgramFailureException(new ErrorCategory(ErrorCategory.ErrorCategoryEnum.PLUGIN),
158+
errorMessage,
159+
String.format(errorMessageFormat, errorContext.getPhase(), errorMessage), ErrorType.SYSTEM, false, e);
160+
}
161+
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright © 2024 Cask Data, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
5+
* use this file except in compliance with the License. You may obtain a copy of
6+
* the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13+
* License for the specific language governing permissions and limitations under
14+
* the License.
15+
*/
16+
17+
package io.cdap.plugin.snowflake.common;
18+
19+
import io.cdap.cdap.api.exception.ErrorType;
20+
21+
import java.util.Arrays;
22+
import java.util.HashSet;
23+
import java.util.Set;
24+
25+
/**
26+
* Error Type provided based on the Snowflake error message code
27+
*
28+
**/
29+
public class SnowflakeErrorType {
30+
31+
//https://github.com/snowflakedb/snowflake-jdbc/blob/master/src/main/java/net/snowflake/client/jdbc/ErrorCode.java
32+
private static final Set<Integer> USER_ERRORS = new HashSet<>(Arrays.asList(
33+
200004, 200006, 200007, 200008, 200009, 200010, 200011, 200012, 200014,
34+
200017, 200018, 200019, 200021, 200023, 200024, 200025, 200026, 200028,
35+
200029, 200030, 200031, 200032, 200033, 200034, 200035, 200036, 200037,
36+
200038, 200045, 200046, 200047, 200056
37+
));
38+
39+
private static final Set<Integer> SYSTEM_ERRORS = new HashSet<>(Arrays.asList(
40+
200001, 200002, 200003, 200013, 200015, 200016, 200020, 200022, 200039,
41+
200040, 200044, 200061
42+
));
43+
44+
/**
45+
* Method to get the error type based on the error code.
46+
*
47+
* @param errorCode the error code to classify
48+
* @return the corresponding ErrorType (USER, SYSTEM, UNKNOWN)
49+
*/
50+
public static ErrorType getErrorType(int errorCode) {
51+
if (USER_ERRORS.contains(errorCode)) {
52+
return ErrorType.USER;
53+
} else if (SYSTEM_ERRORS.contains(errorCode)) {
54+
return ErrorType.SYSTEM;
55+
} else {
56+
return ErrorType.UNKNOWN;
57+
}
58+
}
59+
}

0 commit comments

Comments
 (0)