|
30 | 30 |
|
31 | 31 | from pydantic import Field, conlist, field_validator, model_serializer
|
32 | 32 |
|
33 |
| -from pyiceberg.schema import Schema, SchemaVisitor, visit |
| 33 | +from pyiceberg.schema import P, PartnerAccessor, Schema, SchemaVisitor, SchemaWithPartnerVisitor, visit, visit_with_partner |
34 | 34 | from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel
|
35 |
| -from pyiceberg.types import ListType, MapType, NestedField, PrimitiveType, StructType |
| 35 | +from pyiceberg.types import IcebergType, ListType, MapType, NestedField, PrimitiveType, StructType |
| 36 | +from pyiceberg.utils.deprecated import deprecated |
36 | 37 |
|
37 | 38 |
|
38 | 39 | class MappedField(IcebergBaseModel):
|
@@ -74,6 +75,11 @@ class NameMapping(IcebergRootModel[List[MappedField]]):
|
74 | 75 | def _field_by_name(self) -> Dict[str, MappedField]:
|
75 | 76 | return visit_name_mapping(self, _IndexByName())
|
76 | 77 |
|
| 78 | + @deprecated( |
| 79 | + deprecated_in="0.8.0", |
| 80 | + removed_in="0.9.0", |
| 81 | + help_message="Please use `apply_name_mapping` instead", |
| 82 | + ) |
77 | 83 | def find(self, *names: str) -> MappedField:
|
78 | 84 | name = ".".join(names)
|
79 | 85 | try:
|
@@ -248,3 +254,127 @@ def create_mapping_from_schema(schema: Schema) -> NameMapping:
|
248 | 254 |
|
def update_mapping(mapping: NameMapping, updates: Dict[int, NestedField], adds: Dict[int, List[NestedField]]) -> NameMapping:
    """Build a new name mapping by running an ``_UpdateMapping`` visitor over ``mapping``.

    Args:
        mapping: The existing name mapping to visit.
        updates: Passed to ``_UpdateMapping``; presumably keyed by field ID -- confirm against ``_UpdateMapping``.
        adds: Passed to ``_UpdateMapping``; presumably keyed by parent field ID -- confirm against ``_UpdateMapping``.

    Returns:
        A ``NameMapping`` wrapping whatever the visitor produced.
    """
    updater = _UpdateMapping(updates, adds)
    updated_fields = visit_name_mapping(mapping, updater)
    return NameMapping(updated_fields)
|
| 257 | + |
| 258 | + |
class NameMappingAccessor(PartnerAccessor[MappedField]):
    """Partner accessor that looks up name-mapping entries by name.

    Every lookup returns the first candidate whose ``names`` contains the
    requested name, or ``None`` when there is no partner or no match.
    """

    def schema_partner(self, partner: Optional[MappedField]) -> Optional[MappedField]:
        """Return the given partner unchanged."""
        return partner

    def field_partner(
        self, partner_struct: Optional[Union[List[MappedField], MappedField]], _: int, field_name: str
    ) -> Optional[MappedField]:
        """Return the entry named ``field_name`` inside ``partner_struct``.

        ``partner_struct`` is either a ``MappedField`` (its ``fields`` are searched)
        or something directly iterable over mapped fields. The field ID argument
        is ignored: matching is purely by name.
        """
        if partner_struct is None:
            return None

        candidates = partner_struct.fields if isinstance(partner_struct, MappedField) else partner_struct
        return next((candidate for candidate in candidates if field_name in candidate.names), None)

    def list_element_partner(self, partner_list: Optional[MappedField]) -> Optional[MappedField]:
        """Return the child of ``partner_list`` named ``"element"``, if any."""
        if partner_list is None:
            return None
        return next((child for child in partner_list.fields if "element" in child.names), None)

    def map_key_partner(self, partner_map: Optional[MappedField]) -> Optional[MappedField]:
        """Return the child of ``partner_map`` named ``"key"``, if any."""
        if partner_map is None:
            return None
        return next((child for child in partner_map.fields if "key" in child.names), None)

    def map_value_partner(self, partner_map: Optional[MappedField]) -> Optional[MappedField]:
        """Return the child of ``partner_map`` named ``"value"``, if any."""
        if partner_map is None:
            return None
        return next((child for child in partner_map.fields if "value" in child.names), None)
| 296 | + |
| 297 | + |
class NameMappingProjectionVisitor(SchemaWithPartnerVisitor[MappedField, IcebergType]):
    """Rebuild a schema, taking field IDs from the partnered name-mapping entries.

    Names, types, required flags, docs and defaults are copied from the visited
    schema; only the IDs (field, list element, map key/value) come from the
    mapping. A ``ValueError`` naming the dotted path is raised whenever a
    required mapping entry is missing.
    """

    current_path: List[str]

    def __init__(self) -> None:
        # For keeping track where we are in case when a field cannot be found
        self.current_path = []

    def before_field(self, field: NestedField, field_partner: Optional[P]) -> None:
        self.current_path.append(field.name)

    def after_field(self, field: NestedField, field_partner: Optional[P]) -> None:
        self.current_path.pop()

    def before_list_element(self, element: NestedField, element_partner: Optional[P]) -> None:
        self.current_path.append("element")

    def after_list_element(self, element: NestedField, element_partner: Optional[P]) -> None:
        self.current_path.pop()

    def before_map_key(self, key: NestedField, key_partner: Optional[P]) -> None:
        self.current_path.append("key")

    def after_map_key(self, key: NestedField, key_partner: Optional[P]) -> None:
        self.current_path.pop()

    def before_map_value(self, value: NestedField, value_partner: Optional[P]) -> None:
        self.current_path.append("value")

    def after_map_value(self, value: NestedField, value_partner: Optional[P]) -> None:
        self.current_path.pop()

    def _child_field_id(self, partner: MappedField, child_name: str) -> int:
        """Return the ID of the entry under ``partner`` whose names include ``child_name``.

        Raises:
            ValueError: If ``partner`` has no such child. A bare ``next()`` would
                leak ``StopIteration`` here, which says nothing about what is missing.
        """
        for child in partner.fields:
            if child_name in child.names:
                return child.field_id

        missing_path = ".".join([*self.current_path, child_name])
        raise ValueError(f"Could not find field with name: {missing_path}")

    def schema(self, schema: Schema, schema_partner: Optional[MappedField], struct_result: StructType) -> IcebergType:
        """Wrap the rebuilt top-level struct in a ``Schema`` that keeps the original schema ID."""
        return Schema(*struct_result.fields, schema_id=schema.schema_id)

    def struct(self, struct: StructType, struct_partner: Optional[MappedField], field_results: List[NestedField]) -> IcebergType:
        """Assemble a struct from the already-rebuilt fields."""
        return StructType(*field_results)

    def field(self, field: NestedField, field_partner: Optional[MappedField], field_result: IcebergType) -> IcebergType:
        """Copy ``field`` with the ID from ``field_partner`` and the rebuilt type.

        Raises:
            ValueError: If the name mapping has no entry for this field.
        """
        if field_partner is None:
            raise ValueError(f"Field missing from NameMapping: {'.'.join(self.current_path)}")

        return NestedField(
            field_id=field_partner.field_id,
            name=field.name,
            field_type=field_result,
            required=field.required,
            doc=field.doc,
            initial_default=field.initial_default,
            # Was `initial_write=`, which is not the write-default keyword, so the
            # value never reached the rebuilt field.
            write_default=field.write_default,
        )

    def list(self, list_type: ListType, list_partner: Optional[MappedField], element_result: IcebergType) -> IcebergType:
        """Rebuild a list type, taking the element ID from the mapping's ``element`` entry.

        Raises:
            ValueError: If the list or its ``element`` entry is missing from the mapping.
        """
        if list_partner is None:
            raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}")

        element_id = self._child_field_id(list_partner, "element")
        return ListType(element_id=element_id, element=element_result, element_required=list_type.element_required)

    def map(
        self, map_type: MapType, map_partner: Optional[MappedField], key_result: IcebergType, value_result: IcebergType
    ) -> IcebergType:
        """Rebuild a map type, taking key/value IDs from the mapping's ``key`` and ``value`` entries.

        Raises:
            ValueError: If the map or its ``key`` / ``value`` entry is missing from the mapping.
        """
        if map_partner is None:
            raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}")

        key_id = self._child_field_id(map_partner, "key")
        value_id = self._child_field_id(map_partner, "value")
        return MapType(
            key_id=key_id,
            key_type=key_result,
            value_id=value_id,
            value_type=value_result,
            value_required=map_type.value_required,
        )

    def primitive(self, primitive: PrimitiveType, primitive_partner: Optional[MappedField]) -> PrimitiveType:
        """Return the primitive unchanged once its mapping entry is confirmed to exist.

        Raises:
            ValueError: If the name mapping has no entry at the current path.
        """
        if primitive_partner is None:
            raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}")

        return primitive
| 377 | + |
| 378 | + |
def apply_name_mapping(schema_without_ids: Schema, name_mapping: NameMapping) -> Schema:
    """Return a copy of ``schema_without_ids`` whose IDs are taken from ``name_mapping``.

    The schema is walked with ``visit_with_partner``, using a
    ``NameMappingAccessor`` to find each node's mapping entry by name and a
    ``NameMappingProjectionVisitor`` to rebuild the node with the mapped ID.

    Raises:
        ValueError: Raised by the visitor when a schema field has no entry in the mapping.
    """
    projection_visitor = NameMappingProjectionVisitor()
    mapping_accessor = NameMappingAccessor()
    return visit_with_partner(schema_without_ids, name_mapping, projection_visitor, mapping_accessor)  # type: ignore
0 commit comments