Skip to content

Commit a11deb3

Browse files
committed
fix: initial work on primary keys
1 parent 37f1da8 commit a11deb3

File tree

3 files changed

+59
-3
lines changed

3 files changed

+59
-3
lines changed

requirements-dev.lock

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ python-levenshtein==0.27.1
8686
# via query-farm-flight-server
8787
query-farm-duckdb-json-serialization==0.1.2
8888
# via query-farm-airport-test-server
89-
query-farm-flight-server==0.1.8
89+
query-farm-flight-server==0.1.10
9090
# via query-farm-airport-test-server
9191
rapidfuzz==3.13.0
9292
# via levenshtein

src/query_farm_airport_test_server/database_impl.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,11 @@ class TableInfo:
161161
# To enable version history keep track of tables.
162162
table_versions: list[pa.Table] = field(default_factory=list)
163163

164+
primary_key_columns: list[str] = field(default_factory=list)
165+
unique_columns: list[str] = field(default_factory=list)
166+
multi_key_primary_keys: list[str] = field(default_factory=list)
167+
extra_constraints: list[str] = field(default_factory=list)
168+
164169
# the next row id to assign.
165170
row_id_counter: int = 0
166171

@@ -219,6 +224,10 @@ def serialize(self) -> dict[str, Any]:
219224
return {
220225
"table_versions": [serialize_table_data(table) for table in self.table_versions],
221226
"row_id_counter": self.row_id_counter,
227+
"primary_key_columns": self.primary_key_columns,
228+
"unique_columns": self.unique_columns,
229+
"multi_key_primary_keys": self.multi_key_primary_keys,
230+
"extra_constraints": self.extra_constraints,
222231
}
223232

224233
def deserialize(self, data: dict[str, Any]) -> "TableInfo":
@@ -228,6 +237,10 @@ def deserialize(self, data: dict[str, Any]) -> "TableInfo":
228237
self.table_versions = [deserialize_table_data(table) for table in data["table_versions"]]
229238
self.row_id_counter = data["row_id_counter"]
230239
self.endpoint_generator = None
240+
self.primary_key_columns = data.get("primary_key_columns", [])
241+
self.unique_columns = data.get("unique_columns", [])
242+
self.multi_key_primary_keys = data.get("multi_key_primary_keys", [])
243+
self.extra_constraints = data.get("extra_constraints", [])
231244
return self
232245

233246

@@ -374,7 +387,11 @@ def by_name(self, name: str) -> SchemaCollection:
374387

375388
def serialize(self) -> dict[str, Any]:
376389
return {
377-
"schemas": {name: schema.serialize() for name, schema in self.schemas_by_name.items()},
390+
"schemas": {
391+
name: schema.serialize()
392+
for name, schema in self.schemas_by_name.items()
393+
if name not in ("static_data", "remote_data", "utils")
394+
},
378395
"version": self.version,
379396
}
380397

src/query_farm_airport_test_server/server.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,14 @@ def action_create_table(
395395

396396
schema_with_row_id = parameters.arrow_schema.append(self.rowid_field)
397397

398-
table_info = TableInfo([schema_with_row_id.empty_table()], 0)
398+
table_info = TableInfo(
399+
table_versions=[schema_with_row_id.empty_table()],
400+
row_id_counter=0,
401+
primary_key_columns=parameters.primary_key_columns,
402+
unique_columns=parameters.unique_columns,
403+
multi_key_primary_keys=parameters.multi_key_primary_keys,
404+
extra_constraints=parameters.extra_constraints,
405+
)
399406

400407
schema.tables_by_name[parameters.table_name] = table_info
401408

@@ -653,6 +660,38 @@ def exchange_insert(
653660
for chunk in reader:
654661
if chunk.data is not None:
655662
new_rows = pa.Table.from_batches([chunk.data])
663+
664+
for column_name in table_info.primary_key_columns:
665+
insert_values = new_rows.column(column_name)
666+
667+
# Make sure that all elements are unique
668+
value_counts = pc.value_counts(insert_values)
669+
duplicates_struct = pc.filter(value_counts, pc.greater(value_counts.field("counts"), 1))
670+
duplicates = duplicates_struct.field("values")
671+
if len(duplicates) > 0:
672+
raise flight.FlightServerError(
673+
f"Cannot insert rows with duplicate values in primary key column {column_name}: {duplicates}",
674+
extra_info=msgpack.packb(
675+
{
676+
"exception_type": "ConstraintException",
677+
"message": f"""Duplicate key "{column_name}: {duplicates[0]}" violates primary key constraint.""",
678+
}
679+
),
680+
)
681+
682+
already_exists = pc.is_in(insert_values, value_set=existing_table.column(column_name))
683+
existing_rows = pc.filter(insert_values, already_exists)
684+
if len(existing_rows) > 0:
685+
raise flight.FlightServerError(
686+
f"Cannot insert rows with duplicate values in primary key column {column_name}: {existing_rows}",
687+
extra_info=msgpack.packb(
688+
{
689+
"exception_type": "ConstraintException",
690+
"message": f"""Duplicate key "{column_name}: {existing_rows[0]}" violates primary key constraint.""",
691+
}
692+
),
693+
)
694+
656695
assert new_rows.num_rows > 0
657696

658697
# append the row id column to the new rows.

0 commit comments

Comments
 (0)