Skip to content

Commit 92fdb3c

Browse files
authored
Snowflake export update (#3124)
1 parent efb989a commit 92fdb3c

File tree

17 files changed

+2358
-3367
lines changed

17 files changed

+2358
-3367
lines changed

athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/BlockAllocatorImpl.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,9 @@ public synchronized Block createBlock(Schema schema)
128128
List<FieldVector> vectors = new ArrayList();
129129
try {
130130
for (Field next : schema.getFields()) {
131-
vectors.add(next.createVector(rootAllocator));
131+
FieldVector vector = next.createVector(rootAllocator);
132+
vector.allocateNew();
133+
vectors.add(vector);
132134
}
133135
vectorSchemaRoot = new VectorSchemaRoot(schema, vectors, 0);
134136
block = new Block(id, schema, vectorSchemaRoot);

athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpiller.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,6 @@ public void writeRows(RowWriter rowWriter)
204204
throw (ex instanceof RuntimeException) ? (RuntimeException) ex : new RuntimeException(ex);
205205
}
206206

207-
if (rows > maxRowsPerCall) {
208-
throw new AthenaConnectorException("Call generated more than " + maxRowsPerCall + "rows. Generating " +
209-
"too many rows per call to writeRows(...) can result in blocks that exceed the max size.", ErrorDetails.builder().errorCode(FederationSourceErrorCode.INVALID_INPUT_EXCEPTION.toString()).build());
210-
}
211207
if (rows > 0) {
212208
block.setRowCount(rowCount + rows);
213209
}

athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/exceptions/AthenaConnectorException.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,17 @@ public AthenaConnectorException(@Nonnull final Object response,
8181
requireNonNull(e);
8282
}
8383

84+
public AthenaConnectorException(@Nonnull final String message,
85+
@Nonnull final Exception e,
86+
@Nonnull final ErrorDetails errorDetails)
87+
{
88+
super(message, e);
89+
this.errorDetails = requireNonNull(errorDetails);
90+
this.response = null;
91+
requireNonNull(message);
92+
requireNonNull(e);
93+
}
94+
8495
public Object getResponse()
8596
{
8697
return response;

athena-google-bigquery/src/test/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryRecordHandlerTest.java

Lines changed: 101 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
import com.google.cloud.bigquery.Table;
5151
import com.google.cloud.bigquery.TableDefinition;
5252
import com.google.cloud.bigquery.TableResult;
53+
import com.google.cloud.bigquery.storage.v1.ArrowRecordBatch;
5354
import com.google.cloud.bigquery.storage.v1.ArrowSchema;
5455
import com.google.cloud.bigquery.storage.v1.BigQueryReadClient;
5556
import com.google.cloud.bigquery.storage.v1.CreateReadSessionRequest;
@@ -69,10 +70,16 @@
6970
import org.apache.arrow.vector.VectorSchemaRoot;
7071
import org.apache.arrow.vector.complex.ListVector;
7172
import org.apache.arrow.vector.complex.impl.UnionListWriter;
73+
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
7274
import org.apache.arrow.vector.ipc.ReadChannel;
73-
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
75+
import org.apache.arrow.vector.ipc.WriteChannel;
7476
import org.apache.arrow.vector.ipc.message.MessageSerializer;
77+
import org.apache.arrow.vector.types.FloatingPointPrecision;
78+
import org.apache.arrow.vector.types.pojo.ArrowType;
79+
import org.apache.arrow.vector.types.pojo.Field;
80+
import org.apache.arrow.vector.types.pojo.FieldType;
7581
import org.apache.arrow.vector.types.pojo.Schema;
82+
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
7683
import org.junit.After;
7784
import org.junit.Before;
7885
import org.junit.Test;
@@ -90,6 +97,7 @@
9097
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest;
9198
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse;
9299

100+
import java.io.ByteArrayOutputStream;
93101
import java.nio.charset.StandardCharsets;
94102
import java.util.Arrays;
95103
import java.util.ArrayList;
@@ -150,63 +158,9 @@ public class BigQueryRecordHandlerTest
150158
.build();
151159
private FederatedIdentity federatedIdentity;
152160
private MockedStatic<BigQueryUtils> mockedStatic;
153-
private MockedStatic<MessageSerializer> messageSer;
154-
MockedConstruction<VectorSchemaRoot> mockedDefaultVectorSchemaRoot;
155-
MockedConstruction<VectorLoader> mockedDefaultVectorLoader;
156161
@Mock
157162
private Job queryJob;
158163

159-
public List<FieldVector> getFieldVectors()
160-
{
161-
List<FieldVector> fieldVectors = new ArrayList<>();
162-
IntVector intVector = new IntVector("int1", rootAllocator);
163-
intVector.allocateNew(1024);
164-
intVector.setSafe(0, 42); // Example: Set the value at index 0 to 42
165-
intVector.setSafe(1, 3);
166-
intVector.setValueCount(2);
167-
fieldVectors.add(intVector);
168-
VarCharVector varcharVector = new VarCharVector("string1", rootAllocator);
169-
varcharVector.allocateNew(1024);
170-
varcharVector.setSafe(0, "test".getBytes(StandardCharsets.UTF_8)); // Example: Set the value at index 0 to 42
171-
varcharVector.setSafe(1, "test1".getBytes(StandardCharsets.UTF_8));
172-
varcharVector.setValueCount(2);
173-
fieldVectors.add(varcharVector);
174-
BitVector bitVector = new BitVector("bool1", rootAllocator);
175-
bitVector.allocateNew(1024);
176-
bitVector.setSafe(0, 1); // Example: Set the value at index 0 to 42
177-
bitVector.setSafe(1, 0);
178-
bitVector.setValueCount(2);
179-
fieldVectors.add(bitVector);
180-
Float8Vector float8Vector = new Float8Vector("float1", rootAllocator);
181-
float8Vector.allocateNew(1024);
182-
float8Vector.setSafe(0, 1.00f); // Example: Set the value at index 0 to 42
183-
float8Vector.setSafe(1, 0.0f);
184-
float8Vector.setValueCount(2);
185-
fieldVectors.add(float8Vector);
186-
IntVector innerVector = new IntVector("innerVector", rootAllocator);
187-
innerVector.allocateNew(1024);
188-
innerVector.setSafe(0, 10);
189-
innerVector.setSafe(1, 20);
190-
innerVector.setSafe(2, 30);
191-
innerVector.setValueCount(3);
192-
193-
// Create a ListVector and add the inner vector to it
194-
ListVector listVector = ListVector.empty("listVector", rootAllocator);
195-
UnionListWriter writer = listVector.getWriter();
196-
for (int i = 0; i < 2; i++) {
197-
writer.startList();
198-
writer.setPosition(i);
199-
for (int j = 0; j < 5; j++) {
200-
writer.writeInt(j * i);
201-
}
202-
writer.setValueCount(5);
203-
writer.endList();
204-
}
205-
listVector.setValueCount(2);
206-
fieldVectors.add(listVector);
207-
return fieldVectors;
208-
}
209-
210164
@Before
211165
public void init()
212166
{
@@ -229,10 +183,9 @@ public void init()
229183
//Create Spill config
230184
spillConfig = SpillConfig.newBuilder()
231185
.withEncryptionKey(encryptionKey)
232-
//This will be enough for a single block
233-
.withMaxBlockBytes(100000)
234186
//This will force the writer to spill.
235-
.withMaxInlineBlockBytes(100)
187+
.withMaxBlockBytes(20)
188+
.withMaxInlineBlockBytes(1)
236189
//Async Writing.
237190
.withNumSpillThreads(0)
238191
.withRequestId(UUID.randomUUID().toString())
@@ -278,47 +231,40 @@ public void testReadWithConstraint()
278231
try (ReadRecordsRequest request = getReadRecordsRequest(Collections.emptyMap())) {
279232
// Mocking necessary dependencies
280233
ReadSession readSession = mock(ReadSession.class);
281-
ReadRowsResponse readRowsResponse = mock(ReadRowsResponse.class);
282234
ServerStreamingCallable ssCallable = mock(ServerStreamingCallable.class);
283235

284236
// Mocking method calls
285237
mockStatic(BigQueryReadClient.class);
286238
when(BigQueryReadClient.create()).thenReturn(bigQueryReadClient);
287-
messageSer = mockStatic(MessageSerializer.class);
288-
when(MessageSerializer.deserializeSchema((ReadChannel) any())).thenReturn(BigQueryTestUtils.getBlockTestSchema());
289-
mockedDefaultVectorLoader = Mockito.mockConstruction(VectorLoader.class,
290-
(mock, context) -> {
291-
Mockito.doNothing().when(mock).load(any());
292-
});
293-
mockedDefaultVectorSchemaRoot = Mockito.mockConstruction(VectorSchemaRoot.class,
294-
(mock, context) -> {
295-
when(mock.getRowCount()).thenReturn(2);
296-
when(mock.getFieldVectors()).thenReturn(getFieldVectors());
297-
});
298239
when(bigQueryReadClient.createReadSession(any(CreateReadSessionRequest.class))).thenReturn(readSession);
299240
when(readSession.getArrowSchema()).thenReturn(arrowSchema);
300241
when(readSession.getStreamsCount()).thenReturn(1);
301242
ReadStream readStream = mock(ReadStream.class);
302243
when(readSession.getStreams(anyInt())).thenReturn(readStream);
303244
when(readStream.getName()).thenReturn("testStream");
304-
byte[] byteArray1 = {(byte) 0xFF};
305-
ByteString byteString1 = ByteString.copyFrom(byteArray1);
245+
246+
// Create proper schema serialization
247+
Schema schema = new Schema(Arrays.asList(
248+
new Field("int1", FieldType.nullable(new ArrowType.Int(32, true)), null),
249+
new Field("string1", FieldType.nullable(new ArrowType.Utf8()), null),
250+
new Field("bool1", FieldType.nullable(new ArrowType.Bool()), null),
251+
new Field("float1", FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), null)
252+
));
253+
254+
ByteArrayOutputStream schemaOut = new ByteArrayOutputStream();
255+
MessageSerializer.serialize(new WriteChannel(java.nio.channels.Channels.newChannel(schemaOut)), schema);
306256

307257
ByteString bs = mock(ByteString.class);
308258
when(arrowSchema.getSerializedSchema()).thenReturn(bs);
309-
when(bs.toByteArray()).thenReturn(byteArray1);
259+
when(bs.toByteArray()).thenReturn(schemaOut.toByteArray());
310260
when(bigQueryReadClient.readRowsCallable()).thenReturn(ssCallable);
261+
311262
when(ssCallable.call(any(ReadRowsRequest.class))).thenReturn(serverStream);
312-
when(serverStream.iterator()).thenReturn(ImmutableList.of(readRowsResponse).iterator());
313-
when(readRowsResponse.hasArrowRecordBatch()).thenReturn(true);
314-
com.google.cloud.bigquery.storage.v1.ArrowRecordBatch arrowRecordBatch = mock(com.google.cloud.bigquery.storage.v1.ArrowRecordBatch.class);
315-
when(readRowsResponse.getArrowRecordBatch()).thenReturn(arrowRecordBatch);
316-
byte[] byteArray = {(byte) 0xFF};
317-
ByteString byteString = ByteString.copyFrom(byteArray);
318-
when(arrowRecordBatch.getSerializedRecordBatch()).thenReturn(byteString);
319-
ArrowRecordBatch apacheArrowRecordBatch = mock(ArrowRecordBatch.class);
320-
when(MessageSerializer.deserializeRecordBatch(any(ReadChannel.class), any())).thenReturn(apacheArrowRecordBatch);
321-
Mockito.doNothing().when(apacheArrowRecordBatch).close();
263+
264+
// Create real ReadRowsResponse instead of mocking
265+
ReadRowsResponse realReadRowsResponse = createReadRowsResponseExample();
266+
267+
when(serverStream.iterator()).thenReturn(ImmutableList.of(realReadRowsResponse).iterator());
322268

323269
QueryStatusChecker queryStatusChecker = mock(QueryStatusChecker.class);
324270

@@ -327,9 +273,6 @@ public void testReadWithConstraint()
327273

328274
//Ensure that there was a spill so that we can read the spilled block.
329275
assertTrue(spillWriter.spilled());
330-
mockedDefaultVectorLoader.close();
331-
mockedDefaultVectorSchemaRoot.close();
332-
messageSer.close();
333276
}
334277
}
335278

@@ -429,4 +372,76 @@ private TableResult setupMockTableResult() {
429372

430373
return result;
431374
}
375+
376+
public static com.google.cloud.bigquery.storage.v1.ReadRowsResponse createReadRowsResponseExample() throws Exception {
377+
com.google.cloud.bigquery.storage.v1.ArrowRecordBatch arrowRecordBatch = createExample();
378+
379+
ReadRowsResponse build = ReadRowsResponse.newBuilder()
380+
.setArrowRecordBatch(arrowRecordBatch)
381+
.setRowCount(2)
382+
.build();
383+
return build;
384+
}
385+
386+
public static com.google.cloud.bigquery.storage.v1.ArrowRecordBatch createExample() throws Exception {
387+
try(RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
388+
// Create schema
389+
Schema schema = new Schema(Arrays.asList(
390+
new Field("int1", FieldType.nullable(new ArrowType.Int(32, true)), null),
391+
new Field("string1", FieldType.nullable(new ArrowType.Utf8()), null),
392+
new Field("bool1", FieldType.nullable(new ArrowType.Bool()), null),
393+
new Field("float1", FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), null)
394+
));
395+
396+
// Create vectors with data
397+
VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator);
398+
399+
IntVector intVector = (IntVector) root.getVector("int1");
400+
intVector.allocateNew(2);
401+
intVector.set(0, 42);
402+
intVector.set(1, 3);
403+
intVector.setValueCount(2);
404+
405+
VarCharVector stringVector = (VarCharVector) root.getVector("string1");
406+
stringVector.allocateNew(2);
407+
stringVector.set(0, "test".getBytes(StandardCharsets.UTF_8));
408+
stringVector.set(1, "test1".getBytes(StandardCharsets.UTF_8));
409+
stringVector.setValueCount(2);
410+
411+
BitVector boolVector = (BitVector) root.getVector("bool1");
412+
boolVector.allocateNew(2);
413+
boolVector.set(0, 1); // true
414+
boolVector.set(1, 0); // false
415+
boolVector.setValueCount(2);
416+
417+
Float8Vector floatVector = (Float8Vector) root.getVector("float1");
418+
floatVector.allocateNew(2);
419+
floatVector.set(0, 1.0);
420+
floatVector.set(1, 0.0);
421+
floatVector.setValueCount(2);
422+
423+
root.setRowCount(2);
424+
425+
// Use VectorUnloader to create proper ArrowRecordBatch
426+
org.apache.arrow.vector.VectorUnloader unloader = new org.apache.arrow.vector.VectorUnloader(root);
427+
org.apache.arrow.vector.ipc.message.ArrowRecordBatch batch = unloader.getRecordBatch();
428+
429+
// Serialize using MessageSerializer
430+
ByteArrayOutputStream out = new ByteArrayOutputStream();
431+
MessageSerializer.serialize(new WriteChannel(java.nio.channels.Channels.newChannel(out)), batch);
432+
433+
// Create BigQuery ArrowRecordBatch
434+
com.google.cloud.bigquery.storage.v1.ArrowRecordBatch recordBatch =
435+
com.google.cloud.bigquery.storage.v1.ArrowRecordBatch.newBuilder()
436+
.setSerializedRecordBatch(ByteString.copyFrom(out.toByteArray()))
437+
.setRowCount(2)
438+
.build();
439+
440+
batch.close();
441+
root.close();
442+
allocator.close();
443+
444+
return recordBatch;
445+
}
446+
}
432447
}

athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcSplitQueryBuilder.java

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -130,22 +130,16 @@ public PreparedStatement buildSql(
130130
return prepareStatementWithSql(jdbcConnection, catalog, schema, table, tableSchema, constraints, split, columnNames);
131131
}
132132

133-
protected PreparedStatement prepareStatementWithSql(
134-
final Connection jdbcConnection,
133+
protected String buildSQLStringLiteral(
135134
final String catalog,
136135
final String schema,
137136
final String table,
138137
final Schema tableSchema,
139138
final Constraints constraints,
140139
final Split split,
141-
final String columnNames)
142-
throws SQLException
140+
final String columnNames,
141+
List<TypeAndValue> accumulator)
143142
{
144-
if (constraints.getQueryPlan() != null) {
145-
SqlDialect sqlDialect = getSqlDialect();
146-
return prepareStatementWithCalciteSql(jdbcConnection, constraints, sqlDialect, split);
147-
}
148-
149143
StringBuilder sql = new StringBuilder();
150144
sql.append("SELECT ");
151145
sql.append(columnNames);
@@ -155,8 +149,6 @@ protected PreparedStatement prepareStatementWithSql(
155149
}
156150
sql.append(getFromClauseWithSplit(catalog, schema, table, split));
157151

158-
List<TypeAndValue> accumulator = new ArrayList<>();
159-
160152
List<String> clauses = toConjuncts(tableSchema.getFields(), constraints, accumulator, split.getProperties());
161153
clauses.addAll(getPartitionWhereClauses(split));
162154
if (!clauses.isEmpty()) {
@@ -177,7 +169,27 @@ protected PreparedStatement prepareStatementWithSql(
177169
sql.append(appendLimitOffset(split)); // legacy method to preserve functionality of existing connector impls
178170
}
179171
LOGGER.info("Generated SQL : {}", sql.toString());
180-
PreparedStatement statement = jdbcConnection.prepareStatement(sql.toString());
172+
return sql.toString();
173+
}
174+
175+
protected PreparedStatement prepareStatementWithSql(
176+
final Connection jdbcConnection,
177+
final String catalog,
178+
final String schema,
179+
final String table,
180+
final Schema tableSchema,
181+
final Constraints constraints,
182+
final Split split,
183+
final String columnNames)
184+
throws SQLException
185+
{
186+
if (constraints.getQueryPlan() != null) {
187+
SqlDialect sqlDialect = getSqlDialect();
188+
return prepareStatementWithCalciteSql(jdbcConnection, constraints, sqlDialect, split);
189+
}
190+
List<TypeAndValue> accumulator = new ArrayList<>();
191+
PreparedStatement statement = jdbcConnection.prepareStatement(
192+
this.buildSQLStringLiteral(catalog, schema, table, tableSchema, constraints, split, columnNames, accumulator));
181193
// TODO all types, converts Arrow values to JDBC.
182194
for (int i = 0; i < accumulator.size(); i++) {
183195
TypeAndValue typeAndValue = accumulator.get(i);

athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeCompositeHandler.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public class SnowflakeCompositeHandler
4242
{
4343
public SnowflakeCompositeHandler() throws CertificateEncodingException, IOException, NoSuchAlgorithmException, KeyStoreException
4444
{
45-
super(new SnowflakeMetadataHandler(new SnowflakeEnvironmentProperties(System.getenv()).createEnvironment()), new SnowflakeRecordHandler(new SnowflakeEnvironmentProperties(System.getenv()).createEnvironment()));
45+
super(new SnowflakeMetadataHandler(new SnowflakeEnvironmentProperties().createEnvironment()), new SnowflakeRecordHandler(new SnowflakeEnvironmentProperties().createEnvironment()));
4646
installCaCertificate();
4747
setupNativeEnvironmentVariables();
4848
}

0 commit comments

Comments (0)