Skip to content

Commit e3f3d76

Browse files
committed
fix: add remote yellow taxi data example
1 parent cae9b61 commit e3f3d76

File tree

2 files changed

+87
-1
lines changed

2 files changed

+87
-1
lines changed

src/query_farm_airport_test_server/database_impl.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import pyarrow.compute as pc
1212
import pyarrow.flight as flight
1313
import query_farm_flight_server.flight_inventory as flight_inventory
14+
import query_farm_flight_server.flight_handling as flight_handling
1415
import query_farm_flight_server.parameter_types as parameter_types
1516

1617
from .utils import CaseInsensitiveDict
@@ -166,6 +167,9 @@ class TableInfo:
166167
# the next row id to assign.
167168
row_id_counter: int = 0
168169

170+
# This cannot be serialized, but it is convenient for testing.
171+
endpoint_generator: Callable[[Any], list[flight.FlightEndpoint]] | None = None
172+
169173
def update_table(self, table: pa.Table) -> None:
170174
assert table is not None
171175
assert isinstance(table, pa.Table)
@@ -226,6 +230,7 @@ def deserialize(self, data: dict[str, Any]) -> "TableInfo":
226230
"""
227231
self.table_versions = [deserialize_table_data(table) for table in data["table_versions"]]
228232
self.row_id_counter = data["row_id_counter"]
233+
self.endpoint_generator = None
229234
return self
230235

231236

@@ -360,6 +365,7 @@ class DatabaseContents:
360365
version: int = 1
361366

362367
def __post_init__(self) -> None:
368+
self.schemas_by_name["remote_data"] = remote_data_schema
363369
self.schemas_by_name["static_data"] = static_data_schema
364370
self.schemas_by_name["utils"] = util_schema
365371
return
@@ -383,6 +389,7 @@ def deserialize(self, data: dict[str, Any]) -> "DatabaseContents":
383389
{name: SchemaCollection().deserialize(schema) for name, schema in data["schemas"].items()}
384390
)
385391
self.schemas_by_name["static_data"] = static_data_schema
392+
self.schemas_by_name["remote_data"] = remote_data_schema
386393
self.schemas_by_name["utils"] = util_schema
387394

388395
self.version = data["version"]
@@ -650,6 +657,83 @@ def in_out_handler(
650657
return pa.RecordBatch.from_arrays([["last"], ["row"]], schema=output_schema)
651658

652659

660+
def yellow_taxi_endpoint_generator(ticket_data: Any) -> list[flight.FlightEndpoint]:
    """
    Build the FlightEndpoint list for the NYC Yellow Taxi dataset.

    A single endpoint is returned. Its only location is the value produced by
    ``flight_handling.dict_to_msgpack_duckdb_call_data_uri`` for a
    ``read_parquet`` call over the remote parquet files listed below.

    ``ticket_data`` is passed through unchanged to ``flight_handling.endpoint``.
    """
    parquet_urls = [
        "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2025-01.parquet",
        "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2025-02.parquet",
        # Not enabled:
        # "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2025-03.parquet",
        # "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2025-04.parquet",
        # "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2025-05.parquet",
    ]

    # The call arguments are packed into a one-row Arrow table, so they could
    # just as well be supplied as a record batch.
    call_arguments = pa.Table.from_pylist(
        [
            {
                "arg_0": parquet_urls,
                "hive_partitioning": False,
                "union_by_name": True,
            }
        ],
    )

    call_location = flight_handling.dict_to_msgpack_duckdb_call_data_uri(
        {
            "function_name": "read_parquet",
            "data": flight_handling.serialize_arrow_ipc_table(call_arguments),
        }
    )

    return [
        flight_handling.endpoint(
            ticket_data=ticket_data,
            locations=[call_location],
        )
    ]
695+
696+
697+
# Schema collection registered under the name "remote_data". It defines no
# scalar or table functions and a single table, "nyc_yellow_taxi", whose only
# table version is an empty table carrying just the column schema. The table's
# endpoint_generator is set to yellow_taxi_endpoint_generator.
remote_data_schema = SchemaCollection(
    scalar_functions_by_name=CaseInsensitiveDict(),
    table_functions_by_name=CaseInsensitiveDict(),
    tables_by_name=CaseInsensitiveDict(
        {
            "nyc_yellow_taxi": TableInfo(
                table_versions=[
                    # (name, type) pairs are equivalent to pa.field(name, type).
                    pa.schema(
                        [
                            ("VendorID", pa.int32()),
                            ("tpep_pickup_datetime", pa.timestamp("ms")),
                            ("tpep_dropoff_datetime", pa.timestamp("ms")),
                            ("passenger_count", pa.int64()),
                            ("trip_distance", pa.float64()),
                            ("RatecodeID", pa.int64()),
                            ("store_and_fwd_flag", pa.string()),
                            ("PULocationID", pa.int64()),
                            ("DOLocationID", pa.int64()),
                            ("payment_type", pa.int64()),
                            ("fare_amount", pa.float64()),
                            ("extra", pa.float64()),
                            ("mta_tax", pa.float64()),
                            ("tip_amount", pa.float64()),
                            ("tolls_amount", pa.float64()),
                            ("improvement_surcharge", pa.float64()),
                            ("total_amount", pa.float64()),
                            ("congestion_surcharge", pa.float64()),
                            ("Airport_fee", pa.float64()),
                            ("cbd_congestion_fee", pa.float64()),
                        ]
                    ).empty_table()
                ],
                row_id_counter=0,
                endpoint_generator=yellow_taxi_endpoint_generator,
            ),
        }
    ),
)
735+
736+
653737
static_data_schema = SchemaCollection(
654738
scalar_functions_by_name=CaseInsensitiveDict(),
655739
table_functions_by_name=CaseInsensitiveDict(),

src/query_farm_airport_test_server/server.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def action_endpoints(
173173
filter_sql_where_clause = None
174174

175175
if descriptor_parts.type == "table":
176-
schema.by_name("table", descriptor_parts.name)
176+
table_info = schema.by_name("table", descriptor_parts.name)
177177

178178
ticket_data = FlightTicketData(
179179
descriptor=parameters.descriptor,
@@ -182,6 +182,8 @@ def action_endpoints(
182182
at_value=parameters.parameters.at_value,
183183
)
184184

185+
if table_info.endpoint_generator is not None:
186+
return table_info.endpoint_generator(ticket_data)
185187
return [flight_handling.endpoint(ticket_data=ticket_data, locations=None)]
186188
elif descriptor_parts.type == "table_function":
187189
# So the table function may not exist, because it's a dynamic descriptor.

0 commit comments

Comments
 (0)