5050import com .google .cloud .bigquery .Table ;
5151import com .google .cloud .bigquery .TableDefinition ;
5252import com .google .cloud .bigquery .TableResult ;
53+ import com .google .cloud .bigquery .storage .v1 .ArrowRecordBatch ;
5354import com .google .cloud .bigquery .storage .v1 .ArrowSchema ;
5455import com .google .cloud .bigquery .storage .v1 .BigQueryReadClient ;
5556import com .google .cloud .bigquery .storage .v1 .CreateReadSessionRequest ;
6970import org .apache .arrow .vector .VectorSchemaRoot ;
7071import org .apache .arrow .vector .complex .ListVector ;
7172import org .apache .arrow .vector .complex .impl .UnionListWriter ;
73+ import org .apache .arrow .vector .ipc .ArrowStreamWriter ;
7274import org .apache .arrow .vector .ipc .ReadChannel ;
73- import org .apache .arrow .vector .ipc .message . ArrowRecordBatch ;
75+ import org .apache .arrow .vector .ipc .WriteChannel ;
7476import org .apache .arrow .vector .ipc .message .MessageSerializer ;
77+ import org .apache .arrow .vector .types .FloatingPointPrecision ;
78+ import org .apache .arrow .vector .types .pojo .ArrowType ;
79+ import org .apache .arrow .vector .types .pojo .Field ;
80+ import org .apache .arrow .vector .types .pojo .FieldType ;
7581import org .apache .arrow .vector .types .pojo .Schema ;
82+ import org .apache .arrow .vector .util .ByteArrayReadableSeekableByteChannel ;
7683import org .junit .After ;
7784import org .junit .Before ;
7885import org .junit .Test ;
9097import software .amazon .awssdk .services .secretsmanager .model .GetSecretValueRequest ;
9198import software .amazon .awssdk .services .secretsmanager .model .GetSecretValueResponse ;
9299
100+ import java .io .ByteArrayOutputStream ;
93101import java .nio .charset .StandardCharsets ;
94102import java .util .Arrays ;
95103import java .util .ArrayList ;
@@ -150,63 +158,9 @@ public class BigQueryRecordHandlerTest
150158 .build ();
151159 private FederatedIdentity federatedIdentity ;
152160 private MockedStatic <BigQueryUtils > mockedStatic ;
153- private MockedStatic <MessageSerializer > messageSer ;
154- MockedConstruction <VectorSchemaRoot > mockedDefaultVectorSchemaRoot ;
155- MockedConstruction <VectorLoader > mockedDefaultVectorLoader ;
156161 @ Mock
157162 private Job queryJob ;
158163
159- public List <FieldVector > getFieldVectors ()
160- {
161- List <FieldVector > fieldVectors = new ArrayList <>();
162- IntVector intVector = new IntVector ("int1" , rootAllocator );
163- intVector .allocateNew (1024 );
164- intVector .setSafe (0 , 42 ); // Example: Set the value at index 0 to 42
165- intVector .setSafe (1 , 3 );
166- intVector .setValueCount (2 );
167- fieldVectors .add (intVector );
168- VarCharVector varcharVector = new VarCharVector ("string1" , rootAllocator );
169- varcharVector .allocateNew (1024 );
170- varcharVector .setSafe (0 , "test" .getBytes (StandardCharsets .UTF_8 )); // Example: Set the value at index 0 to 42
171- varcharVector .setSafe (1 , "test1" .getBytes (StandardCharsets .UTF_8 ));
172- varcharVector .setValueCount (2 );
173- fieldVectors .add (varcharVector );
174- BitVector bitVector = new BitVector ("bool1" , rootAllocator );
175- bitVector .allocateNew (1024 );
176- bitVector .setSafe (0 , 1 ); // Example: Set the value at index 0 to 42
177- bitVector .setSafe (1 , 0 );
178- bitVector .setValueCount (2 );
179- fieldVectors .add (bitVector );
180- Float8Vector float8Vector = new Float8Vector ("float1" , rootAllocator );
181- float8Vector .allocateNew (1024 );
182- float8Vector .setSafe (0 , 1.00f ); // Example: Set the value at index 0 to 42
183- float8Vector .setSafe (1 , 0.0f );
184- float8Vector .setValueCount (2 );
185- fieldVectors .add (float8Vector );
186- IntVector innerVector = new IntVector ("innerVector" , rootAllocator );
187- innerVector .allocateNew (1024 );
188- innerVector .setSafe (0 , 10 );
189- innerVector .setSafe (1 , 20 );
190- innerVector .setSafe (2 , 30 );
191- innerVector .setValueCount (3 );
192-
193- // Create a ListVector and add the inner vector to it
194- ListVector listVector = ListVector .empty ("listVector" , rootAllocator );
195- UnionListWriter writer = listVector .getWriter ();
196- for (int i = 0 ; i < 2 ; i ++) {
197- writer .startList ();
198- writer .setPosition (i );
199- for (int j = 0 ; j < 5 ; j ++) {
200- writer .writeInt (j * i );
201- }
202- writer .setValueCount (5 );
203- writer .endList ();
204- }
205- listVector .setValueCount (2 );
206- fieldVectors .add (listVector );
207- return fieldVectors ;
208- }
209-
210164 @ Before
211165 public void init ()
212166 {
@@ -229,10 +183,9 @@ public void init()
229183 //Create Spill config
230184 spillConfig = SpillConfig .newBuilder ()
231185 .withEncryptionKey (encryptionKey )
232- //This will be enough for a single block
233- .withMaxBlockBytes (100000 )
234186 //This will force the writer to spill.
235- .withMaxInlineBlockBytes (100 )
187+ .withMaxBlockBytes (20 )
188+ .withMaxInlineBlockBytes (1 )
236189 //Async Writing.
237190 .withNumSpillThreads (0 )
238191 .withRequestId (UUID .randomUUID ().toString ())
@@ -278,47 +231,40 @@ public void testReadWithConstraint()
278231 try (ReadRecordsRequest request = getReadRecordsRequest (Collections .emptyMap ())) {
279232 // Mocking necessary dependencies
280233 ReadSession readSession = mock (ReadSession .class );
281- ReadRowsResponse readRowsResponse = mock (ReadRowsResponse .class );
282234 ServerStreamingCallable ssCallable = mock (ServerStreamingCallable .class );
283235
284236 // Mocking method calls
285237 mockStatic (BigQueryReadClient .class );
286238 when (BigQueryReadClient .create ()).thenReturn (bigQueryReadClient );
287- messageSer = mockStatic (MessageSerializer .class );
288- when (MessageSerializer .deserializeSchema ((ReadChannel ) any ())).thenReturn (BigQueryTestUtils .getBlockTestSchema ());
289- mockedDefaultVectorLoader = Mockito .mockConstruction (VectorLoader .class ,
290- (mock , context ) -> {
291- Mockito .doNothing ().when (mock ).load (any ());
292- });
293- mockedDefaultVectorSchemaRoot = Mockito .mockConstruction (VectorSchemaRoot .class ,
294- (mock , context ) -> {
295- when (mock .getRowCount ()).thenReturn (2 );
296- when (mock .getFieldVectors ()).thenReturn (getFieldVectors ());
297- });
298239 when (bigQueryReadClient .createReadSession (any (CreateReadSessionRequest .class ))).thenReturn (readSession );
299240 when (readSession .getArrowSchema ()).thenReturn (arrowSchema );
300241 when (readSession .getStreamsCount ()).thenReturn (1 );
301242 ReadStream readStream = mock (ReadStream .class );
302243 when (readSession .getStreams (anyInt ())).thenReturn (readStream );
303244 when (readStream .getName ()).thenReturn ("testStream" );
304- byte [] byteArray1 = {(byte ) 0xFF };
305- ByteString byteString1 = ByteString .copyFrom (byteArray1 );
245+
246+ // Create proper schema serialization
247+ Schema schema = new Schema (Arrays .asList (
248+ new Field ("int1" , FieldType .nullable (new ArrowType .Int (32 , true )), null ),
249+ new Field ("string1" , FieldType .nullable (new ArrowType .Utf8 ()), null ),
250+ new Field ("bool1" , FieldType .nullable (new ArrowType .Bool ()), null ),
251+ new Field ("float1" , FieldType .nullable (new ArrowType .FloatingPoint (FloatingPointPrecision .DOUBLE )), null )
252+ ));
253+
254+ ByteArrayOutputStream schemaOut = new ByteArrayOutputStream ();
255+ MessageSerializer .serialize (new WriteChannel (java .nio .channels .Channels .newChannel (schemaOut )), schema );
306256
307257 ByteString bs = mock (ByteString .class );
308258 when (arrowSchema .getSerializedSchema ()).thenReturn (bs );
309- when (bs .toByteArray ()).thenReturn (byteArray1 );
259+ when (bs .toByteArray ()).thenReturn (schemaOut . toByteArray () );
310260 when (bigQueryReadClient .readRowsCallable ()).thenReturn (ssCallable );
261+
311262 when (ssCallable .call (any (ReadRowsRequest .class ))).thenReturn (serverStream );
312- when (serverStream .iterator ()).thenReturn (ImmutableList .of (readRowsResponse ).iterator ());
313- when (readRowsResponse .hasArrowRecordBatch ()).thenReturn (true );
314- com .google .cloud .bigquery .storage .v1 .ArrowRecordBatch arrowRecordBatch = mock (com .google .cloud .bigquery .storage .v1 .ArrowRecordBatch .class );
315- when (readRowsResponse .getArrowRecordBatch ()).thenReturn (arrowRecordBatch );
316- byte [] byteArray = {(byte ) 0xFF };
317- ByteString byteString = ByteString .copyFrom (byteArray );
318- when (arrowRecordBatch .getSerializedRecordBatch ()).thenReturn (byteString );
319- ArrowRecordBatch apacheArrowRecordBatch = mock (ArrowRecordBatch .class );
320- when (MessageSerializer .deserializeRecordBatch (any (ReadChannel .class ), any ())).thenReturn (apacheArrowRecordBatch );
321- Mockito .doNothing ().when (apacheArrowRecordBatch ).close ();
263+
264+ // Create real ReadRowsResponse instead of mocking
265+ ReadRowsResponse realReadRowsResponse = createReadRowsResponseExample ();
266+
267+ when (serverStream .iterator ()).thenReturn (ImmutableList .of (realReadRowsResponse ).iterator ());
322268
323269 QueryStatusChecker queryStatusChecker = mock (QueryStatusChecker .class );
324270
@@ -327,9 +273,6 @@ public void testReadWithConstraint()
327273
328274 //Ensure that there was a spill so that we can read the spilled block.
329275 assertTrue (spillWriter .spilled ());
330- mockedDefaultVectorLoader .close ();
331- mockedDefaultVectorSchemaRoot .close ();
332- messageSer .close ();
333276 }
334277 }
335278
@@ -429,4 +372,76 @@ private TableResult setupMockTableResult() {
429372
430373 return result ;
431374 }
375+
376+ public static com .google .cloud .bigquery .storage .v1 .ReadRowsResponse createReadRowsResponseExample () throws Exception {
377+ com .google .cloud .bigquery .storage .v1 .ArrowRecordBatch arrowRecordBatch = createExample ();
378+
379+ ReadRowsResponse build = ReadRowsResponse .newBuilder ()
380+ .setArrowRecordBatch (arrowRecordBatch )
381+ .setRowCount (2 )
382+ .build ();
383+ return build ;
384+ }
385+
386+ public static com .google .cloud .bigquery .storage .v1 .ArrowRecordBatch createExample () throws Exception {
387+ try (RootAllocator allocator = new RootAllocator (Long .MAX_VALUE )) {
388+ // Create schema
389+ Schema schema = new Schema (Arrays .asList (
390+ new Field ("int1" , FieldType .nullable (new ArrowType .Int (32 , true )), null ),
391+ new Field ("string1" , FieldType .nullable (new ArrowType .Utf8 ()), null ),
392+ new Field ("bool1" , FieldType .nullable (new ArrowType .Bool ()), null ),
393+ new Field ("float1" , FieldType .nullable (new ArrowType .FloatingPoint (FloatingPointPrecision .DOUBLE )), null )
394+ ));
395+
396+ // Create vectors with data
397+ VectorSchemaRoot root = VectorSchemaRoot .create (schema , allocator );
398+
399+ IntVector intVector = (IntVector ) root .getVector ("int1" );
400+ intVector .allocateNew (2 );
401+ intVector .set (0 , 42 );
402+ intVector .set (1 , 3 );
403+ intVector .setValueCount (2 );
404+
405+ VarCharVector stringVector = (VarCharVector ) root .getVector ("string1" );
406+ stringVector .allocateNew (2 );
407+ stringVector .set (0 , "test" .getBytes (StandardCharsets .UTF_8 ));
408+ stringVector .set (1 , "test1" .getBytes (StandardCharsets .UTF_8 ));
409+ stringVector .setValueCount (2 );
410+
411+ BitVector boolVector = (BitVector ) root .getVector ("bool1" );
412+ boolVector .allocateNew (2 );
413+ boolVector .set (0 , 1 ); // true
414+ boolVector .set (1 , 0 ); // false
415+ boolVector .setValueCount (2 );
416+
417+ Float8Vector floatVector = (Float8Vector ) root .getVector ("float1" );
418+ floatVector .allocateNew (2 );
419+ floatVector .set (0 , 1.0 );
420+ floatVector .set (1 , 0.0 );
421+ floatVector .setValueCount (2 );
422+
423+ root .setRowCount (2 );
424+
425+ // Use VectorUnloader to create proper ArrowRecordBatch
426+ org .apache .arrow .vector .VectorUnloader unloader = new org .apache .arrow .vector .VectorUnloader (root );
427+ org .apache .arrow .vector .ipc .message .ArrowRecordBatch batch = unloader .getRecordBatch ();
428+
429+ // Serialize using MessageSerializer
430+ ByteArrayOutputStream out = new ByteArrayOutputStream ();
431+ MessageSerializer .serialize (new WriteChannel (java .nio .channels .Channels .newChannel (out )), batch );
432+
433+ // Create BigQuery ArrowRecordBatch
434+ com .google .cloud .bigquery .storage .v1 .ArrowRecordBatch recordBatch =
435+ com .google .cloud .bigquery .storage .v1 .ArrowRecordBatch .newBuilder ()
436+ .setSerializedRecordBatch (ByteString .copyFrom (out .toByteArray ()))
437+ .setRowCount (2 )
438+ .build ();
439+
440+ batch .close ();
441+ root .close ();
442+ allocator .close ();
443+
444+ return recordBatch ;
445+ }
446+ }
432447}
0 commit comments