Skip to content

Commit cae9b61

Browse files
committed
fix: add readonly flag for ops that don't change things
1 parent a777eb0 commit cae9b61

File tree

2 files changed

+70
-70
lines changed

2 files changed

+70
-70
lines changed

src/query_farm_airport_test_server/database_impl.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -473,12 +473,13 @@ def write_to_file(self, token: str) -> None:
473473

474474
with tempfile.NamedTemporaryFile("wb", dir=dir_name, delete=False) as tmp_file:
475475
pickle.dump(data, tmp_file)
476-
os.replace(tmp_file.name, file_path)
476+
os.replace(tmp_file.name, file_path)
477477

478478

479479
class DatabaseLibraryContext:
480-
def __init__(self, token: str):
480+
def __init__(self, token: str, readonly: bool = False) -> None:
481481
self.token = token
482+
self.readonly = readonly
482483

483484
def __enter__(self) -> DatabaseLibrary:
484485
self.db = DatabaseLibrary.read_from_file(self.token)
@@ -489,7 +490,8 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
489490
print(f"An error occurred: {exc_val}")
490491
# Optionally return True to suppress the exception
491492
return
492-
self.db.write_to_file(self.token)
493+
if not self.readonly:
494+
self.db.write_to_file(self.token)
493495

494496

495497
def add_handler(table: pa.Table) -> pa.Array:

src/query_farm_airport_test_server/server.py

Lines changed: 65 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,6 @@ def __init__(
144144
self._auth_manager = auth_manager
145145
super().__init__(location=location, **kwargs)
146146

147-
# token, database name, schema, table_name
148-
# self.contents: dict[str, DatabaseLibrary] = {}
149-
150147
self.ROWID_FIELD_NAME = "rowid"
151148
self.rowid_field = pa.field(self.ROWID_FIELD_NAME, pa.int64(), metadata={"is_rowid": "1"})
152149

@@ -159,7 +156,7 @@ def action_endpoints(
159156
assert context.caller is not None
160157

161158
descriptor_parts = descriptor_unpack_(parameters.descriptor)
162-
with DatabaseLibraryContext(context.caller.token.token) as library:
159+
with DatabaseLibraryContext(context.caller.token.token, readonly=True) as library:
163160
database = library.by_name(descriptor_parts.catalog_name)
164161
schema = database.by_name(descriptor_parts.schema_name)
165162

@@ -292,7 +289,7 @@ def action_catalog_version(
292289
) -> base_server.GetCatalogVersionResult:
293290
assert context.caller is not None
294291

295-
with DatabaseLibraryContext(context.caller.token.token) as library:
292+
with DatabaseLibraryContext(context.caller.token.token, readonly=True) as library:
296293
database = library.by_name(parameters.catalog_name)
297294

298295
context.logger.debug(
@@ -402,11 +399,11 @@ def action_create_table(
402399

403400
database.version += 1
404401

405-
return table_info.flight_info(
406-
name=parameters.table_name,
407-
catalog_name=parameters.catalog_name,
408-
schema_name=parameters.schema_name,
409-
)[0]
402+
return table_info.flight_info(
403+
name=parameters.table_name,
404+
catalog_name=parameters.catalog_name,
405+
schema_name=parameters.schema_name,
406+
)[0]
410407

411408
def impl_do_action(
412409
self,
@@ -554,7 +551,7 @@ def exchange_update(
554551

555552
table_info.update_table(existing_table)
556553

557-
return change_count
554+
return change_count
558555

559556
def exchange_delete(
560557
self,
@@ -613,7 +610,7 @@ def exchange_delete(
613610

614611
table_info.update_table(existing_table)
615612

616-
return change_count
613+
return change_count
617614

618615
def exchange_insert(
619616
self,
@@ -630,6 +627,7 @@ def exchange_insert(
630627

631628
if descriptor_parts.type != "table":
632629
raise flight.FlightServerError(f"Unsupported descriptor type: {descriptor_parts.type}")
630+
633631
with DatabaseLibraryContext(context.caller.token.token) as library:
634632
database = library.by_name(descriptor_parts.catalog_name)
635633
schema = database.by_name(descriptor_parts.schema_name)
@@ -686,11 +684,11 @@ def exchange_insert(
686684
# )
687685
new_rows = conform_nullable(existing_table.schema, new_rows)
688686

689-
table_info.update_table(
690-
pa.concat_tables([existing_table, new_rows.select(existing_table.schema.names)])
691-
)
687+
existing_table = pa.concat_tables([existing_table, new_rows.select(existing_table.schema.names)])
688+
689+
table_info.update_table(existing_table)
692690

693-
return change_count
691+
return change_count
694692

695693
def exchange_table_function_in_out(
696694
self,
@@ -713,7 +711,7 @@ def exchange_table_function_in_out(
713711

714712
output_schema = table_info.output_schema(parameters=parameters.parameters, input_schema=input_schema)
715713
gen = table_info.handler(parameters, output_schema)
716-
return (output_schema, gen)
714+
return (output_schema, gen)
717715

718716
def action_add_column(
719717
self,
@@ -748,11 +746,11 @@ def action_add_column(
748746
)
749747
database.version += 1
750748

751-
return table_info.flight_info(
752-
name=parameters.name,
753-
catalog_name=parameters.catalog,
754-
schema_name=parameters.schema_name,
755-
)[0]
749+
return table_info.flight_info(
750+
name=parameters.name,
751+
catalog_name=parameters.catalog,
752+
schema_name=parameters.schema_name,
753+
)[0]
756754

757755
def action_remove_column(
758756
self,
@@ -771,11 +769,11 @@ def action_remove_column(
771769
table_info.update_table(table_info.version().drop(parameters.removed_column))
772770
database.version += 1
773771

774-
return table_info.flight_info(
775-
name=parameters.name,
776-
catalog_name=parameters.catalog,
777-
schema_name=parameters.schema_name,
778-
)[0]
772+
return table_info.flight_info(
773+
name=parameters.name,
774+
catalog_name=parameters.catalog,
775+
schema_name=parameters.schema_name,
776+
)[0]
779777

780778
def action_rename_column(
781779
self,
@@ -794,11 +792,11 @@ def action_rename_column(
794792
table_info.update_table(table_info.version().rename_columns({parameters.old_name: parameters.new_name}))
795793
database.version += 1
796794

797-
return table_info.flight_info(
798-
name=parameters.name,
799-
catalog_name=parameters.catalog,
800-
schema_name=parameters.schema_name,
801-
)[0]
795+
return table_info.flight_info(
796+
name=parameters.name,
797+
catalog_name=parameters.catalog,
798+
schema_name=parameters.schema_name,
799+
)[0]
802800

803801
def action_rename_table(
804802
self,
@@ -818,11 +816,11 @@ def action_rename_table(
818816

819817
database.version += 1
820818

821-
return table_info.flight_info(
822-
name=parameters.new_table_name,
823-
catalog_name=parameters.catalog,
824-
schema_name=parameters.schema_name,
825-
)[0]
819+
return table_info.flight_info(
820+
name=parameters.new_table_name,
821+
catalog_name=parameters.catalog,
822+
schema_name=parameters.schema_name,
823+
)[0]
826824

827825
def action_set_default(
828826
self,
@@ -858,11 +856,11 @@ def action_set_default(
858856

859857
database.version += 1
860858

861-
return table_info.flight_info(
862-
name=parameters.name,
863-
catalog_name=parameters.catalog,
864-
schema_name=parameters.schema_name,
865-
)[0]
859+
return table_info.flight_info(
860+
name=parameters.name,
861+
catalog_name=parameters.catalog,
862+
schema_name=parameters.schema_name,
863+
)[0]
866864

867865
def action_set_not_null(
868866
self,
@@ -895,11 +893,11 @@ def action_set_not_null(
895893

896894
database.version += 1
897895

898-
return table_info.flight_info(
899-
name=parameters.name,
900-
catalog_name=parameters.catalog,
901-
schema_name=parameters.schema_name,
902-
)[0]
896+
return table_info.flight_info(
897+
name=parameters.name,
898+
catalog_name=parameters.catalog,
899+
schema_name=parameters.schema_name,
900+
)[0]
903901

904902
def action_drop_not_null(
905903
self,
@@ -929,11 +927,11 @@ def action_drop_not_null(
929927

930928
database.version += 1
931929

932-
return table_info.flight_info(
933-
name=parameters.name,
934-
catalog_name=parameters.catalog,
935-
schema_name=parameters.schema_name,
936-
)[0]
930+
return table_info.flight_info(
931+
name=parameters.name,
932+
catalog_name=parameters.catalog,
933+
schema_name=parameters.schema_name,
934+
)[0]
937935

938936
def action_change_column_type(
939937
self,
@@ -970,11 +968,11 @@ def action_change_column_type(
970968

971969
database.version += 1
972970

973-
return table_info.flight_info(
974-
name=parameters.name,
975-
catalog_name=parameters.catalog,
976-
schema_name=parameters.schema_name,
977-
)[0]
971+
return table_info.flight_info(
972+
name=parameters.name,
973+
catalog_name=parameters.catalog,
974+
schema_name=parameters.schema_name,
975+
)[0]
978976

979977
def action_column_statistics(
980978
self,
@@ -985,7 +983,7 @@ def action_column_statistics(
985983
assert context.caller is not None
986984

987985
descriptor_parts = descriptor_unpack_(parameters.flight_descriptor)
988-
with DatabaseLibraryContext(context.caller.token.token) as library:
986+
with DatabaseLibraryContext(context.caller.token.token, readonly=True) as library:
989987
database = library.by_name(descriptor_parts.catalog_name)
990988
schema = database.by_name(descriptor_parts.schema_name)
991989

@@ -1026,7 +1024,7 @@ def action_column_statistics(
10261024
]
10271025
),
10281026
)
1029-
return result_table
1027+
return result_table
10301028

10311029
def impl_do_get(
10321030
self,
@@ -1039,7 +1037,7 @@ def impl_do_get(
10391037
ticket_data = flight_handling.decode_ticket_model(ticket, FlightTicketData)
10401038

10411039
descriptor_parts = descriptor_unpack_(ticket_data.descriptor)
1042-
with DatabaseLibraryContext(context.caller.token.token) as library:
1040+
with DatabaseLibraryContext(context.caller.token.token, readonly=True) as library:
10431041
database = library.by_name(descriptor_parts.catalog_name)
10441042
schema = database.by_name(descriptor_parts.schema_name)
10451043

@@ -1111,13 +1109,13 @@ def action_table_function_flight_info(
11111109
# Now to get the table function, its a bit harder since they are named by action.
11121110
table_function = schema.by_name("table_function", descriptor_parts.name)
11131111

1114-
return table_function.flight_info(
1115-
name=descriptor_parts.name,
1116-
catalog_name=descriptor_parts.catalog_name,
1117-
schema_name=descriptor_parts.schema_name,
1118-
# Pass the real parameters here.
1119-
parameters=parameters,
1120-
)[0]
1112+
return table_function.flight_info(
1113+
name=descriptor_parts.name,
1114+
catalog_name=descriptor_parts.catalog_name,
1115+
schema_name=descriptor_parts.schema_name,
1116+
# Pass the real parameters here.
1117+
parameters=parameters,
1118+
)[0]
11211119

11221120
def action_flight_info(
11231121
self,

0 commit comments

Comments
 (0)