Skip to content

Commit 4b15fb6

Browse files
authored
Pass data type as string representation to NestedField (#1860)
<!-- Thanks for opening a pull request! --> <!-- In the case this PR will resolve an issue, please replace ${GITHUB_ISSUE_ID} below with the actual Github issue id. --> <!-- Closes #${GITHUB_ISSUE_ID} --> # Rationale for this change Closes #1851 # Are these changes tested? test_types.py # Are there any user-facing changes? N/A <!-- In the case of user-facing changes, please add the changelog label. -->
1 parent 96e6d54 commit 4b15fb6

File tree

2 files changed

+61
-1
lines changed

2 files changed

+61
-1
lines changed

pyiceberg/types.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
Field,
4848
PrivateAttr,
4949
SerializeAsAny,
50+
field_validator,
5051
model_serializer,
5152
model_validator,
5253
)
@@ -310,6 +311,14 @@ class NestedField(IcebergType):
310311
... doc="Just a long"
311312
... ))
312313
'2: bar: required long (Just a long)'
314+
>>> str(NestedField(
315+
... field_id=3,
316+
... name='baz',
317+
... field_type="string",
318+
... required=True,
319+
... doc="A string field"
320+
... ))
321+
'3: baz: required string (A string field)'
313322
"""
314323

315324
field_id: int = Field(alias="id")
@@ -320,11 +329,21 @@ class NestedField(IcebergType):
320329
initial_default: Optional[Any] = Field(alias="initial-default", default=None, repr=False)
321330
write_default: Optional[L] = Field(alias="write-default", default=None, repr=False) # type: ignore
322331

332+
@field_validator("field_type", mode="before")
333+
def convert_field_type(cls, v: Any) -> IcebergType:
334+
"""Convert string values into IcebergType instances."""
335+
if isinstance(v, str):
336+
try:
337+
return IcebergType.handle_primitive_type(v, None)
338+
except ValueError as e:
339+
raise ValueError(f"Unsupported field type: '{v}'") from e
340+
return v
341+
323342
def __init__(
324343
self,
325344
field_id: Optional[int] = None,
326345
name: Optional[str] = None,
327-
field_type: Optional[IcebergType] = None,
346+
field_type: Optional[IcebergType | str] = None,
328347
required: bool = False,
329348
doc: Optional[str] = None,
330349
initial_default: Optional[Any] = None,

tests/test_types.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,21 @@
6262
(12, BinaryType),
6363
]
6464

65+
primitive_types = {
66+
"boolean": BooleanType,
67+
"int": IntegerType,
68+
"long": LongType,
69+
"float": FloatType,
70+
"double": DoubleType,
71+
"date": DateType,
72+
"time": TimeType,
73+
"timestamp": TimestampType,
74+
"timestamptz": TimestamptzType,
75+
"string": StringType,
76+
"uuid": UUIDType,
77+
"binary": BinaryType,
78+
}
79+
6580

6681
@pytest.mark.parametrize("input_index, input_type", non_parameterized_types)
6782
def test_repr_primitive_types(input_index: int, input_type: Type[PrimitiveType]) -> None:
@@ -231,6 +246,32 @@ def test_nested_field() -> None:
231246
assert "validation errors for NestedField" in str(exc_info.value)
232247

233248

249+
def test_nested_field_complex_type_as_str_unsupported() -> None:
250+
unsupported_types = ["list", "map", "struct"]
251+
for type_str in unsupported_types:
252+
with pytest.raises(ValueError) as exc_info:
253+
_ = NestedField(1, "field", type_str, required=True)
254+
assert f"Unsupported field type: '{type_str}'" in str(exc_info.value)
255+
256+
257+
def test_nested_field_primitive_type_as_str() -> None:
258+
for type_str, type_class in primitive_types.items():
259+
field_var = NestedField(
260+
1,
261+
"field",
262+
type_str,
263+
required=True,
264+
)
265+
assert isinstance(
266+
field_var.field_type, type_class
267+
), f"Expected {type_class.__name__}, got {field_var.field_type.__class__.__name__}"
268+
269+
# Test that passing 'bool' raises a ValueError, as it should be 'boolean'
270+
with pytest.raises(ValueError) as exc_info:
271+
_ = NestedField(1, "field", "bool", required=True)
272+
assert "Unsupported field type: 'bool'" in str(exc_info.value)
273+
274+
234275
@pytest.mark.parametrize("input_index,input_type", non_parameterized_types)
235276
@pytest.mark.parametrize("check_index,check_type", non_parameterized_types)
236277
def test_non_parameterized_type_equality(

0 commit comments

Comments
 (0)