1111import pyarrow .compute as pc
1212import pyarrow .flight as flight
1313import query_farm_flight_server .flight_inventory as flight_inventory
14+ import query_farm_flight_server .flight_handling as flight_handling
1415import query_farm_flight_server .parameter_types as parameter_types
1516
1617from .utils import CaseInsensitiveDict
@@ -166,6 +167,9 @@ class TableInfo:
166167 # the next row id to assign.
167168 row_id_counter : int = 0
168169
170+ # This cannot be serialized, but it is convenient for testing.
171+ endpoint_generator : Callable [[Any ], list [flight .FlightEndpoint ]] | None = None
172+
169173 def update_table (self , table : pa .Table ) -> None :
170174 assert table is not None
171175 assert isinstance (table , pa .Table )
@@ -226,6 +230,7 @@ def deserialize(self, data: dict[str, Any]) -> "TableInfo":
226230 """
227231 self .table_versions = [deserialize_table_data (table ) for table in data ["table_versions" ]]
228232 self .row_id_counter = data ["row_id_counter" ]
233+ self .endpoint_generator = None
229234 return self
230235
231236
@@ -360,6 +365,7 @@ class DatabaseContents:
360365 version : int = 1
361366
def __post_init__(self) -> None:
    """Install the built-in schemas into ``schemas_by_name``.

    The module-level ``remote_data_schema``, ``static_data_schema`` and
    ``util_schema`` objects are stored under the keys ``"remote_data"``,
    ``"static_data"`` and ``"utils"``, in that order.
    """
    builtin_schemas = (
        ("remote_data", remote_data_schema),
        ("static_data", static_data_schema),
        ("utils", util_schema),
    )
    # Assign key by key (rather than a bulk update) so the mapping's own
    # item-assignment behaviour is what handles each name.
    for schema_name, schema in builtin_schemas:
        self.schemas_by_name[schema_name] = schema
@@ -383,6 +389,7 @@ def deserialize(self, data: dict[str, Any]) -> "DatabaseContents":
383389 {name : SchemaCollection ().deserialize (schema ) for name , schema in data ["schemas" ].items ()}
384390 )
385391 self .schemas_by_name ["static_data" ] = static_data_schema
392+ self .schemas_by_name ["remote_data" ] = remote_data_schema
386393 self .schemas_by_name ["utils" ] = util_schema
387394
388395 self .version = data ["version" ]
@@ -650,6 +657,83 @@ def in_out_handler(
650657 return pa .RecordBatch .from_arrays ([["last" ], ["row" ]], schema = output_schema )
651658
652659
def yellow_taxi_endpoint_generator(ticket_data: Any) -> list[flight.FlightEndpoint]:
    """
    Build the FlightEndpoint list for the NYC Yellow Taxi dataset.

    A single endpoint is returned. Its only location is the URI produced by
    ``flight_handling.dict_to_msgpack_duckdb_call_data_uri`` for a
    ``read_parquet`` call over the monthly trip-data Parquet files listed
    below. ``ticket_data`` is handed unchanged to ``flight_handling.endpoint``.
    """
    parquet_urls = [
        "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2025-01.parquet",
        "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2025-02.parquet",
        # "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2025-03.parquet",
        # "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2025-04.parquet",
        # "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2025-05.parquet",
    ]

    # The call arguments are packed as a one-row Arrow table so they could
    # equally be supplied as a record batch.
    # NOTE(review): presumably "arg_0" is the first positional argument and
    # the remaining columns are named options -- confirm against the consumer.
    call_arguments = pa.Table.from_pylist(
        [
            {
                "arg_0": parquet_urls,
                "hive_partitioning": False,
                "union_by_name": True,
            }
        ],
    )
    serialized_arguments = flight_handling.serialize_arrow_ipc_table(call_arguments)

    call_uri = flight_handling.dict_to_msgpack_duckdb_call_data_uri(
        {
            "function_name": "read_parquet",
            "data": serialized_arguments,
        }
    )

    taxi_endpoint = flight_handling.endpoint(
        ticket_data=ticket_data,
        locations=[call_uri],
    )
    return [taxi_endpoint]
695+
696+
# Schema collection holding the "nyc_yellow_taxi" table. It defines no scalar
# or table functions. The table's only version is an empty table that carries
# just the column layout, and its endpoints come from
# ``yellow_taxi_endpoint_generator``.
remote_data_schema = SchemaCollection(
    scalar_functions_by_name=CaseInsensitiveDict(),
    table_functions_by_name=CaseInsensitiveDict(),
    tables_by_name=CaseInsensitiveDict(
        {
            "nyc_yellow_taxi": TableInfo(
                table_versions=[
                    # (name, type) pairs; pa.schema turns each into a field.
                    pa.schema(
                        [
                            ("VendorID", pa.int32()),
                            ("tpep_pickup_datetime", pa.timestamp("ms")),
                            ("tpep_dropoff_datetime", pa.timestamp("ms")),
                            ("passenger_count", pa.int64()),
                            ("trip_distance", pa.float64()),
                            ("RatecodeID", pa.int64()),
                            ("store_and_fwd_flag", pa.string()),
                            ("PULocationID", pa.int64()),
                            ("DOLocationID", pa.int64()),
                            ("payment_type", pa.int64()),
                            ("fare_amount", pa.float64()),
                            ("extra", pa.float64()),
                            ("mta_tax", pa.float64()),
                            ("tip_amount", pa.float64()),
                            ("tolls_amount", pa.float64()),
                            ("improvement_surcharge", pa.float64()),
                            ("total_amount", pa.float64()),
                            ("congestion_surcharge", pa.float64()),
                            ("Airport_fee", pa.float64()),
                            ("cbd_congestion_fee", pa.float64()),
                        ]
                    ).empty_table()
                ],
                row_id_counter=0,
                endpoint_generator=yellow_taxi_endpoint_generator,
            ),
        }
    ),
)
735+
736+
653737static_data_schema = SchemaCollection (
654738 scalar_functions_by_name = CaseInsensitiveDict (),
655739 table_functions_by_name = CaseInsensitiveDict (),
0 commit comments