Skip to content

Commit 03c16ec

Browse files
authored
fix column with transform (#7843)
1 parent 0e7c6ca commit 03c16ec

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

src/datasets/arrow_dataset.py

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -660,7 +660,11 @@ def __init__(self, source: Union["Dataset", "Column"], column_name: str):
660660

661661
def __iter__(self) -> Iterator[Any]:
662662
if isinstance(self.source, Dataset):
663-
source = self.source._fast_select_column(self.column_name)
663+
if self.source._format_type == "custom":
664+
# the formatting transform may require all columns
665+
source = self.source
666+
else:
667+
source = self.source._fast_select_column(self.column_name)
664668
else:
665669
source = self.source
666670
for example in source:
@@ -670,7 +674,12 @@ def __getitem__(self, key: Union[int, str, list[int]]) -> Any:
670674
if isinstance(key, str):
671675
return Column(self, key)
672676
elif isinstance(self.source, Dataset):
673-
return self.source._fast_select_column(self.column_name)[key][self.column_name]
677+
if self.source._format_type == "custom":
678+
# the formatting transform may require all columns
679+
source = self.source
680+
else:
681+
source = self.source._fast_select_column(self.column_name)
682+
return source[key][self.column_name]
674683
elif isinstance(key, int):
675684
return self.source[key][self.column_name]
676685
else:

0 commit comments

Comments
 (0)