|
24 | 24 | if TYPE_CHECKING: |
25 | 25 | from collections.abc import Iterable, Iterator, Mapping, Sequence |
26 | 26 |
|
27 | | - from typing_extensions import Self |
| 27 | + from typing_extensions import Self, TypeAlias |
28 | 28 |
|
29 | 29 | from narwhals._arrow.typing import ChunkedArrayAny |
30 | 30 | from narwhals._plan.arrow.namespace import ArrowNamespace |
|
34 | 34 | from narwhals.dtypes import DType |
35 | 35 | from narwhals.typing import IntoSchema |
36 | 36 |
|
| 37 | +Incomplete: TypeAlias = Any |
| 38 | + |
37 | 39 |
|
38 | 40 | class ArrowDataFrame(EagerDataFrame[Series, "pa.Table", "ChunkedArrayAny"]): |
39 | 41 | implementation = Implementation.PYARROW |
@@ -173,25 +175,54 @@ def filter(self, predicate: NamedIR) -> Self: |
173 | 175 | mask = acero.lit(resolved.native) |
174 | 176 | return self._with_native(self.native.filter(mask)) |
175 | 177 |
|
176 | | - # TODO @dangotbanned: Clean this up after getting more tests in place |
177 | 178 | def partition_by(self, by: Sequence[str], *, include_key: bool = True) -> list[Self]: |
178 | | - original_names = self.columns |
179 | | - temp_name = temp.column_name(original_names) |
180 | | - native = self.native |
181 | | - composite_values = group_by.concat_str(acero.select_names_table(native, by)) |
182 | | - re_keyed = native.add_column(0, temp_name, composite_values) |
183 | | - source = acero.table_source(re_keyed) |
184 | | - if include_key: |
185 | | - keep = original_names |
186 | | - else: |
187 | | - ignore = {*by, temp_name} |
188 | | - keep = [name for name in original_names if name not in ignore] |
189 | | - select = acero.select_names(keep) |
190 | | - key = acero.col(temp_name) |
191 | | - # Need to iterate over the whole thing, so py_list first should be faster |
192 | | - partitions = ( |
193 | | - acero.declare(source, acero.filter(key == v), select) |
194 | | - for v in composite_values.unique().to_pylist() |
195 | | - ) |
196 | 179 | from_native = self._with_native |
197 | | - return [from_native(decl.to_table()) for decl in partitions] |
| 180 | + partitions = partition_by(self.native, by, include_key=include_key) |
| 181 | + return [from_native(df) for df in partitions] |
| 182 | + |
| 183 | + |
| 184 | +def partition_by( |
| 185 | + native: pa.Table, by: Sequence[str], *, include_key: bool = True |
| 186 | +) -> Iterator[pa.Table]: |
| 187 | + if len(by) == 1: |
| 188 | + yield from _partition_by_one(native, by[0], include_key=include_key) |
| 189 | + else: |
| 190 | + yield from _partition_by_many(native, by, include_key=include_key) |
| 191 | + |
| 192 | + |
| 193 | +def _partition_by_one( |
| 194 | + native: pa.Table, by: str, *, include_key: bool = True |
| 195 | +) -> Iterator[pa.Table]: |
| 196 | + """Optimized path for single-column partition.""" |
| 197 | + arr_dict: Incomplete = fn.array(native.column(by).dictionary_encode("encode")) |
| 198 | + indices: pa.Int32Array = arr_dict.indices |
| 199 | + if not include_key: |
| 200 | + native = native.remove_column(native.schema.get_field_index(by)) |
| 201 | + for idx in range(len(arr_dict.dictionary)): |
| 202 | + # NOTE: Acero filter doesn't support `null_selection_behavior="emit_null"` |
| 203 | + # Is there any reasonable way to do this in Acero? |
| 204 | + yield native.filter(pc.equal(pa.scalar(idx), indices)) |
| 205 | + |
| 206 | + |
| 207 | +def _partition_by_many( |
| 208 | + native: pa.Table, by: Sequence[str], *, include_key: bool = True |
| 209 | +) -> Iterator[pa.Table]: |
| 210 | + original_names = native.column_names |
| 211 | + temp_name = temp.column_name(original_names) |
| 212 | + key = acero.col(temp_name) |
| 213 | + composite_values = group_by.concat_str(acero.select_names_table(native, by)) |
| 214 | + # Need to iterate over the whole thing, so py_list first should be faster |
| 215 | + unique_py = composite_values.unique().to_pylist() |
| 216 | + re_keyed = native.add_column(0, temp_name, composite_values) |
| 217 | + source = acero.table_source(re_keyed) |
| 218 | + if include_key: |
| 219 | + keep = original_names |
| 220 | + else: |
| 221 | + ignore = {*by, temp_name} |
| 222 | + keep = [name for name in original_names if name not in ignore] |
| 223 | + select = acero.select_names(keep) |
| 224 | + for v in unique_py: |
| 225 | + # NOTE: May want to split the `Declaration` production iterator into it's own function |
| 226 | + # E.g, to push down column selection to *before* collection |
| 227 | + # Not needed for this task though |
| 228 | + yield acero.collect(source, acero.filter(key == v), select) |
0 commit comments