Skip to content

Commit 9b9ed53

Browse files
authored
fix: Invert case_sensitive logic in StructType (#1147)
* fix: Invert logic in StructType * Add test for StructType.field_by_name * Remove var I forgot about. * Fix formatting post-lint
1 parent d587e67 commit 9b9ed53

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

pyiceberg/types.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -377,13 +377,13 @@ def field(self, field_id: int) -> Optional[NestedField]:
377377

378378
def field_by_name(self, name: str, case_sensitive: bool = True) -> Optional[NestedField]:
379379
if case_sensitive:
380-
name_lower = name.lower()
381380
for field in self.fields:
382-
if field.name.lower() == name_lower:
381+
if field.name == name:
383382
return field
384383
else:
384+
name_lower = name.lower()
385385
for field in self.fields:
386-
if field.name == name:
386+
if field.name.lower() == name_lower:
387387
return field
388388
return None
389389

tests/test_types.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,20 @@ def test_struct_type() -> None:
149149
assert type_var == pickle.loads(pickle.dumps(type_var))
150150

151151

152+
def test_struct_field_by_name() -> None:
153+
lower_field = NestedField(1, "lower_case_field", IntegerType(), required=True)
154+
upper_field = NestedField(2, "UPPER_CASE_FIELD", IntegerType(), required=True)
155+
type_var = StructType(lower_field, upper_field)
156+
157+
assert type_var.field_by_name("lower_case_field", case_sensitive=False) == lower_field
158+
assert type_var.field_by_name("upper_case_field", case_sensitive=False) == upper_field
159+
assert type_var.field_by_name("nonexistent_field", case_sensitive=False) is None
160+
161+
assert type_var.field_by_name("lower_case_field", case_sensitive=True) == lower_field
162+
assert type_var.field_by_name("upper_case_field", case_sensitive=True) is None
163+
assert type_var.field_by_name("nonexistent_field", case_sensitive=True) is None
164+
165+
152166
def test_list_type() -> None:
153167
type_var = ListType(
154168
1,

0 commit comments

Comments
 (0)