@@ -37,10 +37,7 @@ class TableFunction:
3737 output_schema_source : pa .Schema | TableFunctionDynamicOutput
3838
3939 # The function to call to process a chunk of rows.
40- handler : Callable [
41- [parameter_types .TableFunctionParameters , pa .Schema ],
42- Generator [pa .RecordBatch , pa .RecordBatch , pa .RecordBatch ],
43- ]
40+ handler : Callable [[parameter_types .TableFunctionParameters , pa .Schema ], parameter_types .TableFunctionInOutGenerator ]
4441
4542 estimated_rows : int | Callable [[parameter_types .TableFunctionFlightInfo ], int ] = - 1
4643
@@ -578,6 +575,11 @@ def dynamic_schema_handler_output_schema(
578575 return parameters .schema
579576
580577
578+ def in_out_long_schema_handler (parameters : pa .RecordBatch , input_schema : pa .Schema | None = None ) -> pa .Schema :
579+ assert input_schema is not None
580+ return pa .schema ([input_schema .field (0 )])
581+
582+
581583def in_out_schema_handler (parameters : pa .RecordBatch , input_schema : pa .Schema | None = None ) -> pa .Schema :
582584 assert input_schema is not None
583585 return pa .schema ([parameters .schema .field (0 ), input_schema .field (0 )])
@@ -613,35 +615,43 @@ def in_out_echo_handler(
613615def in_out_wide_handler (
614616 parameters : parameter_types .TableFunctionParameters ,
615617 output_schema : pa .Schema ,
616- ) -> Generator [ pa . RecordBatch , pa . RecordBatch , None ] :
618+ ) -> parameter_types . TableFunctionInOutGenerator :
617619 result = output_schema .empty_table ()
618620
619621 while True :
620- input_chunk = yield result
622+ input_chunk = yield ( result , True )
621623
622624 if input_chunk is None :
623625 break
624626
627+ if isinstance (input_chunk , bool ):
628+ raise NotImplementedError ("Not expecting continuing output for input chunk." )
629+
630+ chunk_length = len (input_chunk )
631+
625632 result = pa .RecordBatch .from_arrays (
626- [[i ] * len ( input_chunk ) for i in range (20 )],
633+ [[i ] * chunk_length for i in range (20 )],
627634 schema = output_schema ,
628635 )
629636
630- return
637+ return None
631638
632639
633640def in_out_handler (
634641 parameters : parameter_types .TableFunctionParameters ,
635642 output_schema : pa .Schema ,
636- ) -> Generator [ pa . RecordBatch , pa . RecordBatch , None ] :
643+ ) -> parameter_types . TableFunctionInOutGenerator :
637644 result = output_schema .empty_table ()
638645
639646 while True :
640- input_chunk = yield result
647+ input_chunk = yield ( result , True )
641648
642649 if input_chunk is None :
643650 break
644651
652+ if isinstance (input_chunk , bool ):
653+ raise NotImplementedError ("Not expecting continuing output for input chunk." )
654+
645655 assert parameters .parameters is not None
646656 parameter_value = parameters .parameters .column (0 ).to_pylist ()[0 ]
647657
@@ -654,7 +664,75 @@ def in_out_handler(
654664 schema = output_schema ,
655665 )
656666
657- return pa .RecordBatch .from_arrays ([["last" ], ["row" ]], schema = output_schema )
667+ return [pa .RecordBatch .from_arrays ([["last" ], ["row" ]], schema = output_schema )]
668+
669+
670+ def in_out_long_handler (
671+ parameters : parameter_types .TableFunctionParameters ,
672+ output_schema : pa .Schema ,
673+ ) -> parameter_types .TableFunctionInOutGenerator :
674+ result = output_schema .empty_table ()
675+
676+ while True :
677+ input_chunk = yield (result , True )
678+
679+ if input_chunk is None :
680+ break
681+
682+ if isinstance (input_chunk , bool ):
683+ raise NotImplementedError ("Not expecting continuing output for input chunk." )
684+
685+ # Return the input chunk ten times.
686+ multiplier = 10
687+ copied_results = [
688+ pa .RecordBatch .from_arrays (
689+ [
690+ input_chunk .column (0 ),
691+ ],
692+ schema = output_schema ,
693+ )
694+ for index in range (multiplier )
695+ ]
696+
697+ for item in copied_results [0 :- 1 ]:
698+ yield (item , False )
699+ result = copied_results [- 1 ]
700+
701+ return None
702+
703+
704+ def in_out_huge_chunk_handler (
705+ parameters : parameter_types .TableFunctionParameters ,
706+ output_schema : pa .Schema ,
707+ ) -> parameter_types .TableFunctionInOutGenerator :
708+ result = output_schema .empty_table ()
709+ multiplier = 10
710+ chunk_length = 5000
711+
712+ while True :
713+ input_chunk = yield (result , True )
714+
715+ if input_chunk is None :
716+ break
717+
718+ if isinstance (input_chunk , bool ):
719+ raise NotImplementedError ("Not expecting continuing output for input chunk." )
720+
721+ for index , _i in enumerate (range (multiplier )):
722+ output = pa .RecordBatch .from_arrays (
723+ [list (range (chunk_length )), list ([index ] * chunk_length )],
724+ schema = output_schema ,
725+ )
726+ if index < multiplier - 1 :
727+ yield (output , False )
728+ else :
729+ result = output
730+
731+ # test big chunks returned as the last results.
732+ return [
733+ pa .RecordBatch .from_arrays ([list (range (chunk_length )), list ([footer_id ] * chunk_length )], schema = output_schema )
734+ for footer_id in (- 1 , - 2 , - 3 )
735+ ]
658736
659737
660738def yellow_taxi_endpoint_generator (ticket_data : Any ) -> list [flight .FlightEndpoint ]:
@@ -739,6 +817,17 @@ def yellow_taxi_endpoint_generator(ticket_data: Any) -> list[flight.FlightEndpoi
739817 table_functions_by_name = CaseInsensitiveDict (),
740818 tables_by_name = CaseInsensitiveDict (
741819 {
820+ "big_chunk" : TableInfo (
821+ table_versions = [
822+ pa .Table .from_arrays (
823+ [
824+ list (range (100000 )),
825+ ],
826+ schema = pa .schema ([pa .field ("id" , pa .int64 ())]),
827+ )
828+ ],
829+ row_id_counter = 0 ,
830+ ),
742831 "employees" : TableInfo (
743832 table_versions = [
744833 pa .Table .from_arrays (
@@ -775,7 +864,7 @@ def yellow_taxi_endpoint_generator(ticket_data: Any) -> list[flight.FlightEndpoi
775864 )
776865 ],
777866 row_id_counter = 2 ,
778- )
867+ ),
779868 }
780869 ),
781870)
@@ -948,6 +1037,41 @@ def collatz_steps(n: int) -> list[int]:
9481037 ),
9491038 handler = in_out_handler ,
9501039 ),
1040+ "test_table_in_out_long" : TableFunction (
1041+ input_schema = pa .schema (
1042+ [
1043+ pa .field (
1044+ "table_input" ,
1045+ pa .string (),
1046+ metadata = {"is_table_type" : "1" },
1047+ ),
1048+ ]
1049+ ),
1050+ output_schema_source = TableFunctionDynamicOutput (
1051+ schema_creator = in_out_long_schema_handler ,
1052+ default_values = (
1053+ pa .RecordBatch .from_arrays (
1054+ [pa .array ([1 ], type = pa .int32 ())],
1055+ schema = pa .schema ([pa .field ("input" , pa .int32 ())]),
1056+ ),
1057+ pa .schema ([pa .field ("input" , pa .int32 ())]),
1058+ ),
1059+ ),
1060+ handler = in_out_long_handler ,
1061+ ),
1062+ "test_table_in_out_huge" : TableFunction (
1063+ input_schema = pa .schema (
1064+ [
1065+ pa .field (
1066+ "table_input" ,
1067+ pa .string (),
1068+ metadata = {"is_table_type" : "1" },
1069+ ),
1070+ ]
1071+ ),
1072+ output_schema_source = pa .schema ([("multiplier" , pa .int64 ()), ("value" , pa .int64 ())]),
1073+ handler = in_out_huge_chunk_handler ,
1074+ ),
9511075 "test_table_in_out_wide" : TableFunction (
9521076 input_schema = pa .schema (
9531077 [
0 commit comments