
Commit 237333d

Update-schema: Add support for initial-default (#1770)
# Rationale for this change

This allows for V3 initial defaults. This PR took a bit longer than anticipated, mostly because of the Pydantic JSON deserialization: Python types have to be serialized in a specific way to match [JSON single-value encoding](https://iceberg.apache.org/spec/#json-single-value-serialization).

# Are these changes tested?

Added new tests.

# Are there any user-facing changes?

After this PR, initial defaults can be set through the API. This enables users to add required fields.
1 parent b85127e commit 237333d
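A minimal usage sketch of the new user-facing API (the catalog and table names are illustrative, not part of this commit; per the rationale, initial defaults target format version 3):

```python
from pyiceberg.catalog import load_catalog
from pyiceberg.types import StringType

catalog = load_catalog("default")           # hypothetical catalog name
table = catalog.load_table("demo.events")   # hypothetical table name

with table.update_schema() as update:
    # With an initial default, the column can be required even on an existing table:
    # rows written before this change read back the default value.
    update.add_column("region", StringType(), required=True, default_value="unknown")
```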

File tree

10 files changed (+311, -47 lines)

pyiceberg/avro/resolver.py

Lines changed: 1 addition & 1 deletion
@@ -290,7 +290,7 @@ def struct(self, file_schema: StructType, record_struct: Optional[IcebergType],
             # There is a default value
             if file_field.write_default is not None:
                 # The field is not in the record, but there is a write default value
-                results.append((None, DefaultWriter(writer=writer, value=file_field.write_default)))  # type: ignore
+                results.append((None, DefaultWriter(writer=writer, value=file_field.write_default)))
             elif file_field.required:
                 raise ValueError(f"Field is required, and there is no write default: {file_field}")
             else:

pyiceberg/conversions.py

Lines changed: 46 additions & 17 deletions
@@ -503,27 +503,47 @@ def _(_: Union[IntegerType, LongType], val: int) -> int:


 @from_json.register(DateType)
-def _(_: DateType, val: str) -> date:
+def _(_: DateType, val: Union[str, int, date]) -> date:
     """JSON date is string encoded."""
-    return days_to_date(date_str_to_days(val))
+    if isinstance(val, str):
+        val = date_str_to_days(val)
+    if isinstance(val, int):
+        return days_to_date(val)
+    else:
+        return val


 @from_json.register(TimeType)
-def _(_: TimeType, val: str) -> time:
+def _(_: TimeType, val: Union[str, int, time]) -> time:
     """JSON ISO8601 string into Python time."""
-    return micros_to_time(time_str_to_micros(val))
+    if isinstance(val, str):
+        val = time_str_to_micros(val)
+    if isinstance(val, int):
+        return micros_to_time(val)
+    else:
+        return val


 @from_json.register(TimestampType)
-def _(_: PrimitiveType, val: str) -> datetime:
+def _(_: PrimitiveType, val: Union[str, int, datetime]) -> datetime:
     """JSON ISO8601 string into Python datetime."""
-    return micros_to_timestamp(timestamp_to_micros(val))
+    if isinstance(val, str):
+        val = timestamp_to_micros(val)
+    if isinstance(val, int):
+        return micros_to_timestamp(val)
+    else:
+        return val


 @from_json.register(TimestamptzType)
-def _(_: TimestamptzType, val: str) -> datetime:
+def _(_: TimestamptzType, val: Union[str, int, datetime]) -> datetime:
     """JSON ISO8601 string into Python datetime."""
-    return micros_to_timestamptz(timestamptz_to_micros(val))
+    if isinstance(val, str):
+        val = timestamptz_to_micros(val)
+    if isinstance(val, int):
+        return micros_to_timestamptz(val)
+    else:
+        return val


 @from_json.register(FloatType)
@@ -540,20 +560,24 @@ def _(_: StringType, val: str) -> str:


 @from_json.register(FixedType)
-def _(t: FixedType, val: str) -> bytes:
+def _(t: FixedType, val: Union[str, bytes]) -> bytes:
     """JSON hexadecimal encoded string into bytes."""
-    b = codecs.decode(val.encode(UTF8), "hex")
+    if isinstance(val, str):
+        val = codecs.decode(val.encode(UTF8), "hex")

-    if len(t) != len(b):
-        raise ValueError(f"FixedType has length {len(t)}, which is different from the value: {len(b)}")
+    if len(t) != len(val):
+        raise ValueError(f"FixedType has length {len(t)}, which is different from the value: {len(val)}")

-    return b
+    return val


 @from_json.register(BinaryType)
-def _(_: BinaryType, val: str) -> bytes:
+def _(_: BinaryType, val: Union[bytes, str]) -> bytes:
     """JSON hexadecimal encoded string into bytes."""
-    return codecs.decode(val.encode(UTF8), "hex")
+    if isinstance(val, str):
+        return codecs.decode(val.encode(UTF8), "hex")
+    else:
+        return val


 @from_json.register(DecimalType)
@@ -563,6 +587,11 @@ def _(_: DecimalType, val: str) -> Decimal:


 @from_json.register(UUIDType)
-def _(_: UUIDType, val: str) -> uuid.UUID:
+def _(_: UUIDType, val: Union[str, bytes, uuid.UUID]) -> uuid.UUID:
     """Convert JSON string into Python UUID."""
-    return uuid.UUID(val)
+    if isinstance(val, str):
+        return uuid.UUID(val)
+    elif isinstance(val, bytes):
+        return uuid.UUID(bytes=val)
+    else:
+        return val
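To make the conversions change concrete, a rough sketch of the single-value decoding these handlers now accept (values and expected outputs are illustrative, derived from the diff above):

```python
import uuid

from pyiceberg.conversions import from_json
from pyiceberg.types import DateType, UUIDType

# Dates may arrive as the spec's ISO string, as days-since-epoch integers,
# or as already-parsed date objects.
print(from_json(DateType(), "2021-03-01"))  # datetime.date(2021, 3, 1)
print(from_json(DateType(), 18687))         # same date, decoded from days since epoch

# UUIDs may arrive as canonical strings, raw bytes, or uuid.UUID instances.
print(from_json(UUIDType(), "f79c3e09-677c-4bbd-a479-3f349cb785e7"))
print(from_json(UUIDType(), uuid.UUID("f79c3e09-677c-4bbd-a479-3f349cb785e7").bytes))
```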

pyiceberg/expressions/literals.py

Lines changed: 4 additions & 1 deletion
@@ -23,7 +23,7 @@

 import struct
 from abc import ABC, abstractmethod
-from datetime import date, datetime
+from datetime import date, datetime, time
 from decimal import ROUND_HALF_UP, Decimal
 from functools import singledispatchmethod
 from math import isnan
@@ -54,6 +54,7 @@
     datetime_to_micros,
     micros_to_days,
     time_str_to_micros,
+    time_to_micros,
     timestamp_to_micros,
     timestamptz_to_micros,
 )
@@ -152,6 +153,8 @@ def literal(value: L) -> Literal[L]:
         return TimestampLiteral(datetime_to_micros(value))  # type: ignore
     elif isinstance(value, date):
         return DateLiteral(date_to_days(value))  # type: ignore
+    elif isinstance(value, time):
+        return TimeLiteral(time_to_micros(value))  # type: ignore
     else:
         raise TypeError(f"Invalid literal value: {repr(value)}")
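The literals change means Python `time` values can now be turned into Iceberg literals, which is what lets a time-typed default pass through `literal(...).to(field_type)` in the schema update below. A small sketch (the printed value is an assumption based on `time_to_micros` returning microseconds since midnight):

```python
from datetime import time

from pyiceberg.expressions.literals import literal

# 12:30:45 becomes a TimeLiteral holding microseconds since midnight.
lit = literal(time(12, 30, 45))
print(lit.value)  # 45045000000
```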

pyiceberg/table/update/schema.py

Lines changed: 121 additions & 11 deletions
@@ -20,9 +20,10 @@
 from copy import copy
 from dataclasses import dataclass
 from enum import Enum
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union

 from pyiceberg.exceptions import ResolveError, ValidationError
+from pyiceberg.expressions import literal  # type: ignore
 from pyiceberg.schema import (
     PartnerAccessor,
     Schema,
@@ -47,6 +48,7 @@
     UpdatesAndRequirements,
     UpdateTableMetadata,
 )
+from pyiceberg.typedef import L
 from pyiceberg.types import IcebergType, ListType, MapType, NestedField, PrimitiveType, StructType

 if TYPE_CHECKING:
@@ -153,7 +155,12 @@ def union_by_name(self, new_schema: Union[Schema, "pa.Schema"]) -> UpdateSchema:
         return self

     def add_column(
-        self, path: Union[str, Tuple[str, ...]], field_type: IcebergType, doc: Optional[str] = None, required: bool = False
+        self,
+        path: Union[str, Tuple[str, ...]],
+        field_type: IcebergType,
+        doc: Optional[str] = None,
+        required: bool = False,
+        default_value: Optional[L] = None,
     ) -> UpdateSchema:
         """Add a new column to a nested struct or Add a new top-level column.

@@ -168,6 +175,7 @@ def add_column(
             field_type: Type for the new column.
             doc: Documentation string for the new column.
             required: Whether the new column is required.
+            default_value: Default value for the new column.

         Returns:
             This for method chaining.
@@ -177,10 +185,6 @@ def add_column(
                 raise ValueError(f"Cannot add column with ambiguous name: {path}, provide a tuple instead")
             path = (path,)

-        if required and not self._allow_incompatible_changes:
-            # Table format version 1 and 2 cannot add required column because there is no initial value
-            raise ValueError(f"Incompatible change: cannot add required column: {'.'.join(path)}")
-
         name = path[-1]
         parent = path[:-1]

@@ -212,13 +216,34 @@ def add_column(

         # assign new IDs in order
         new_id = self.assign_new_column_id()
+        new_type = assign_fresh_schema_ids(field_type, self.assign_new_column_id)
+
+        if default_value is not None:
+            try:
+                # To make sure that the value is valid for the type
+                initial_default = literal(default_value).to(new_type).value
+            except ValueError as e:
+                raise ValueError(f"Invalid default value: {e}") from e
+        else:
+            initial_default = default_value  # type: ignore
+
+        if (required and initial_default is None) and not self._allow_incompatible_changes:
+            # Table format version 1 and 2 cannot add required column because there is no initial value
+            raise ValueError(f"Incompatible change: cannot add required column: {'.'.join(path)}")

         # update tracking for moves
         self._added_name_to_id[full_name] = new_id
         self._id_to_parent[new_id] = parent_full_path

-        new_type = assign_fresh_schema_ids(field_type, self.assign_new_column_id)
-        field = NestedField(field_id=new_id, name=name, field_type=new_type, required=required, doc=doc)
+        field = NestedField(
+            field_id=new_id,
+            name=name,
+            field_type=new_type,
+            required=required,
+            doc=doc,
+            initial_default=initial_default,
+            write_default=initial_default,
+        )

         if parent_id in self._adds:
             self._adds[parent_id].append(field)
@@ -250,6 +275,19 @@ def delete_column(self, path: Union[str, Tuple[str, ...]]) -> UpdateSchema:

         return self

+    def set_default_value(self, path: Union[str, Tuple[str, ...]], default_value: Optional[L]) -> UpdateSchema:
+        """Set the default value of a column.
+
+        Args:
+            path: The path to the column.
+
+        Returns:
+            The UpdateSchema with the delete operation staged.
+        """
+        self._set_column_default_value(path, default_value)
+
+        return self
+
     def rename_column(self, path_from: Union[str, Tuple[str, ...]], new_name: str) -> UpdateSchema:
         """Update the name of a column.

@@ -273,6 +311,8 @@ def rename_column(self, path_from: Union[str, Tuple[str, ...]], new_name: str) -
                 field_type=updated.field_type,
                 doc=updated.doc,
                 required=updated.required,
+                initial_default=updated.initial_default,
+                write_default=updated.write_default,
             )
         else:
             self._updates[field_from.field_id] = NestedField(
@@ -281,6 +321,8 @@ def rename_column(self, path_from: Union[str, Tuple[str, ...]], new_name: str) -
                 field_type=field_from.field_type,
                 doc=field_from.doc,
                 required=field_from.required,
+                initial_default=field_from.initial_default,
+                write_default=field_from.write_default,
             )

         # Lookup the field because of casing
@@ -330,6 +372,8 @@ def _set_column_requirement(self, path: Union[str, Tuple[str, ...]], required: b
                 field_type=updated.field_type,
                 doc=updated.doc,
                 required=required,
+                initial_default=updated.initial_default,
+                write_default=updated.write_default,
             )
         else:
             self._updates[field.field_id] = NestedField(
@@ -338,6 +382,52 @@ def _set_column_requirement(self, path: Union[str, Tuple[str, ...]], required: b
                 field_type=field.field_type,
                 doc=field.doc,
                 required=required,
+                initial_default=field.initial_default,
+                write_default=field.write_default,
+            )
+
+    def _set_column_default_value(self, path: Union[str, Tuple[str, ...]], default_value: Any) -> None:
+        path = (path,) if isinstance(path, str) else path
+        name = ".".join(path)
+
+        field = self._schema.find_field(name, self._case_sensitive)
+
+        if default_value is not None:
+            try:
+                # To make sure that the value is valid for the type
+                default_value = literal(default_value).to(field.field_type).value
+            except ValueError as e:
+                raise ValueError(f"Invalid default value: {e}") from e
+
+        if field.required and default_value == field.write_default:
+            # if the change is a noop, allow it even if allowIncompatibleChanges is false
+            return
+
+        if not self._allow_incompatible_changes and field.required and default_value is None:
+            raise ValueError("Cannot change change default-value of a required column to None")
+
+        if field.field_id in self._deletes:
+            raise ValueError(f"Cannot update a column that will be deleted: {name}")
+
+        if updated := self._updates.get(field.field_id):
+            self._updates[field.field_id] = NestedField(
+                field_id=updated.field_id,
+                name=updated.name,
+                field_type=updated.field_type,
+                doc=updated.doc,
+                required=updated.required,
+                initial_default=updated.initial_default,
+                write_default=default_value,
+            )
+        else:
+            self._updates[field.field_id] = NestedField(
+                field_id=field.field_id,
+                name=field.name,
+                field_type=field.field_type,
+                doc=field.doc,
+                required=field.required,
+                initial_default=field.initial_default,
+                write_default=default_value,
             )

     def update_column(
@@ -387,6 +477,8 @@ def update_column(
                 field_type=field_type or updated.field_type,
                 doc=doc if doc is not None else updated.doc,
                 required=updated.required,
+                initial_default=updated.initial_default,
+                write_default=updated.write_default,
             )
         else:
             self._updates[field.field_id] = NestedField(
@@ -395,6 +487,8 @@ def update_column(
                 field_type=field_type or field.field_type,
                 doc=doc if doc is not None else field.doc,
                 required=field.required,
+                initial_default=field.initial_default,
+                write_default=field.write_default,
             )

         if required is not None:
@@ -636,19 +730,35 @@ def struct(self, struct: StructType, field_results: List[Optional[IcebergType]])
             name = field.name
             doc = field.doc
             required = field.required
+            write_default = field.write_default

             # There is an update
             if update := self._updates.get(field.field_id):
                 name = update.name
                 doc = update.doc
                 required = update.required
-
-            if field.name == name and field.field_type == result_type and field.required == required and field.doc == doc:
+                write_default = update.write_default
+
+            if (
+                field.name == name
+                and field.field_type == result_type
+                and field.required == required
+                and field.doc == doc
+                and field.write_default == write_default
+            ):
                 new_fields.append(field)
             else:
                 has_changes = True
                 new_fields.append(
-                    NestedField(field_id=field.field_id, name=name, field_type=result_type, required=required, doc=doc)
+                    NestedField(
+                        field_id=field.field_id,
+                        name=name,
+                        field_type=result_type,
+                        required=required,
+                        doc=doc,
+                        initial_default=field.initial_default,
+                        write_default=write_default,
+                    )
                 )

         if has_changes:
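Based on the hunks above, `add_column` sets both `initial_default` and `write_default`, while `set_default_value` stages an update that only replaces `write_default`. A short sketch (assumes `table` is an already-loaded Iceberg table; the column name and value are hypothetical):

```python
with table.update_schema() as update:
    # Only the write default changes; the initial default recorded when the
    # column was added still applies to rows written before this update.
    update.set_default_value("region", "eu-west-1")
```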
