
Commit 132cd2b

Merge branch 'branch-25.10' of github.com:rapidsai/cudf into branch-25.10
2 parents: 1581500 + 3ae9ff8

16 files changed: +302, -140 lines


python/cudf/cudf/pandas/_wrappers/pandas.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -1132,7 +1132,8 @@ def _find_user_frame():
     frame = inspect.currentframe()
     while frame:
         modname = frame.f_globals.get("__name__", "")
-        if modname == "__main__" or not modname.startswith("cudf."):
+        # TODO: Remove "nvtx." entry once we cross nvtx-0.2.11 as minimum version
+        if modname == "__main__" or not modname.startswith(("cudf.", "nvtx.")):
             return frame
         frame = frame.f_back
     raise RuntimeError("Could not find the user's frame.")
```
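For context, the helper walks the Python call stack until it leaves wrapper code. A minimal standalone sketch of that technique (standard library only; the names printed depend on the call site):

```python
import inspect

def calling_module_names() -> list[str]:
    """Collect the module name of every frame on the current call stack."""
    names = []
    frame = inspect.currentframe()
    while frame:
        # Each frame exposes its module name via f_globals["__name__"];
        # this is what lets wrapper frames from other packages (here: nvtx)
        # be filtered out by prefix.
        names.append(frame.f_globals.get("__name__", ""))
        frame = frame.f_back
    return names

print(calling_module_names())  # e.g. ["__main__"] when run as a script
```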

python/cudf_polars/cudf_polars/containers/dataframe.py

Lines changed: 14 additions & 5 deletions

```diff
@@ -29,17 +29,26 @@
 def _create_polars_column_metadata(
     name: str, dtype: PolarsDataType
 ) -> plc.interop.ColumnMetadata:
-    """Create ColumnMetadata preserving pl.Struct field names."""
+    """Create ColumnMetadata preserving dtype attributes not supported by libcudf."""
+    children_meta = []
+    timezone = ""
+    precision: int | None = None
+
     if isinstance(dtype, pl.Struct):
         children_meta = [
             _create_polars_column_metadata(field.name, field.dtype)
             for field in dtype.fields
         ]
-    else:
-        children_meta = []
-    timezone = dtype.time_zone if isinstance(dtype, pl.Datetime) else None
+    elif isinstance(dtype, pl.Datetime):
+        timezone = dtype.time_zone or timezone
+    elif isinstance(dtype, pl.Decimal):
+        precision = dtype.precision
+
     return plc.interop.ColumnMetadata(
-        name=name, timezone=timezone or "", children_meta=children_meta
+        name=name,
+        timezone=timezone,
+        precision=precision,
+        children_meta=children_meta,
     )
```
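A hedged usage sketch of the updated helper (constructor keywords are taken from the diff above; `_create_polars_column_metadata` is module-internal, and the example dtypes are illustrative):

```python
import polars as pl

# Decimal: precision travels in the metadata, since the libcudf
# DECIMAL128 type id only carries the scale.
_create_polars_column_metadata("amount", pl.Decimal(precision=10, scale=2))

# Datetime: the timezone string is preserved; naive datetimes keep "".
_create_polars_column_metadata("ts", pl.Datetime("us", "UTC"))
```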

python/cudf_polars/cudf_polars/containers/datatype.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -81,6 +81,8 @@ def _from_polars(dtype: pl.DataType) -> plc.DataType:
         assert_never(dtype.time_unit)
     elif isinstance(dtype, pl.String):
         return plc.DataType(plc.TypeId.STRING)
+    elif isinstance(dtype, pl.Decimal):
+        return plc.DataType(plc.TypeId.DECIMAL128, scale=-dtype.scale)
     elif isinstance(dtype, pl.Null):
         # TODO: Hopefully
         return plc.DataType(plc.TypeId.EMPTY)
```
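The sign flip follows libcudf's fixed_point convention (value = mantissa × 10^scale), whereas Polars counts scale as digits after the decimal point. A quick sanity check of the mapping in plain Python:

```python
polars_scale = 2                 # pl.Decimal(scale=2): two fractional digits
libcudf_scale = -polars_scale    # fixed_point scale for the same values
mantissa = 123                   # integer representation of 1.23 at scale 2
value = mantissa * 10 ** libcudf_scale  # 123 * 10**-2 == 1.23
```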

python/cudf_polars/cudf_polars/dsl/utils/aggregations.py

Lines changed: 33 additions & 14 deletions

```diff
@@ -163,12 +163,30 @@ def decompose_single_agg(
     is_median = agg.name == "median"
     is_quantile = agg.name == "quantile"

+    # quantile agg on decimal: unsupported -> keep dtype Decimal
+    # mean/median on decimal: Polars returns float -> pre-cast
+    decimal_unsupported = False
+    if plc.traits.is_fixed_point(child_dtype):
+        if is_quantile:
+            decimal_unsupported = True
+        elif agg.name in {"mean", "median"}:
+            tid = agg.dtype.plc.id()
+            if tid in {plc.TypeId.FLOAT32, plc.TypeId.FLOAT64}:
+                cast_to = (
+                    DataType(pl.Float64)
+                    if tid == plc.TypeId.FLOAT64
+                    else DataType(pl.Float32)
+                )
+                child = expr.Cast(cast_to, child)
+                child_dtype = child.dtype.plc
+
     is_group_quantile_supported = plc.traits.is_integral(
         child_dtype
     ) or plc.traits.is_floating_point(child_dtype)

     unsupported = (
-        (is_median or is_quantile) and not is_group_quantile_supported
+        decimal_unsupported
+        or ((is_median or is_quantile) and not is_group_quantile_supported)
     ) or (not plc.aggregation.is_valid_aggregation(child_dtype, req))
     if unsupported:
         return [], named_expr.reconstruct(expr.Literal(child.dtype, None))
@@ -177,19 +195,12 @@ def decompose_single_agg(
     # The aggregation is just reconstructed with the new
     # (potentially masked) child. This is safe because we recursed
     # to ensure there are no nested aggregations.
-        return (
-            [(named_expr.reconstruct(agg.reconstruct([child])), True)],
-            named_expr.reconstruct(expr.Col(agg.dtype, name)),
-        )
-    elif agg.name in ("mean", "median", "quantile", "std", "var"):
-        # libcudf promotes these to float64; but polars
-        # keeps Float32, so cast back in post-processing.
-        named = expr.NamedExpr(name, agg)
-        post_col: expr.Expr = expr.Col(DataType(pl.Float64()), name)
-        if agg.dtype.plc.id() == plc.TypeId.FLOAT32:
-            post_col = expr.Cast(agg.dtype, post_col)
-        return [(named, True)], expr.NamedExpr(name, post_col)
-    elif agg.name == "sum":
+
+    # rebuild the agg with the transformed child
+    new_children = [child] if not is_quantile else [child, agg.children[1]]
+    named_expr = named_expr.reconstruct(agg.reconstruct(new_children))
+
+    if agg.name == "sum":
         col = (
             expr.Cast(agg.dtype, expr.Col(DataType(pl.datatypes.Int64()), name))
             if (
@@ -235,6 +246,14 @@ def decompose_single_agg(
         return [(named_expr, True), (win_len, True)], expr.NamedExpr(
             name, post_ternary_expr
         )
+    elif agg.name in {"mean", "median", "quantile", "std", "var"}:
+        post_agg_col: expr.Expr = expr.Col(
+            DataType(pl.Float64()), name
+        )  # libcudf promotes to float64
+        if agg.dtype.plc.id() == plc.TypeId.FLOAT32:
+            # Cast back to float32 to match Polars
+            post_agg_col = expr.Cast(agg.dtype, post_agg_col)
+        return [(named_expr, True)], named_expr.reconstruct(post_agg_col)
     else:
         return [(named_expr, True)], named_expr.reconstruct(
             expr.Col(agg.dtype, name)
```
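From the user's side, the dtype rules encoded above follow Polars semantics: mean/median of a Decimal column comes back as a float, which is why the engine now pre-casts. A hedged illustration (example data only):

```python
import polars as pl

lf = pl.LazyFrame({"k": [1, 1, 2], "x": [1, 2, 3]}).with_columns(
    pl.col("x").cast(pl.Decimal(9, 2))
)
out = lf.group_by("k").agg(pl.col("x").mean()).collect()
print(out.schema)  # "x" is a float dtype, matching the pre-cast above;
                   # quantile on Decimal remains unsupported and is replaced
                   # with a null literal of the original dtype.
```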

python/cudf_polars/cudf_polars/experimental/base.py

Lines changed: 76 additions & 13 deletions

```diff
@@ -90,12 +90,12 @@ class UniqueStats:

 class DataSourceInfo:
     """
-    Datasource information.
+    Table data source information.

     Notes
     -----
     This class should be sub-classed for specific
-    datasource types (e.g. Parquet, DataFrame, etc.).
+    data source types (e.g. Parquet, DataFrame, etc.).
     The required properties/methods enable lazy
     sampling of the underlying datasource.
     """
@@ -117,6 +117,70 @@ def add_unique_stats_column(self, column: str) -> None:
         """Add a column needing unique-value information."""


+class ColumnSourceInfo:
+    """
+    Source column information.
+
+    Parameters
+    ----------
+    table_source_info
+        Table data source information.
+    column_name
+        Column name in the data source.
+
+    Notes
+    -----
+    This is a thin wrapper around DataSourceInfo that provides
+    direct access to column-specific information.
+    """
+
+    __slots__ = ("_allow_unique_sampling", "column_name", "table_source_info")
+    table_source_info: DataSourceInfo
+    column_name: str
+    _allow_unique_sampling: bool
+
+    def __init__(self, table_source_info: DataSourceInfo, column_name: str) -> None:
+        self.table_source_info = table_source_info
+        self.column_name = column_name
+        self._allow_unique_sampling = False
+
+    @property
+    def row_count(self) -> ColumnStat[int]:
+        """Data source row-count estimate."""
+        return self.table_source_info.row_count
+
+    def unique_stats(self, *, force: bool = False) -> UniqueStats:
+        """
+        Return unique-value statistics for a column.
+
+        Parameters
+        ----------
+        force
+            If True, return unique-value statistics even if the column
+            wasn't marked as needing unique-value information.
+        """
+        return (
+            self.table_source_info.unique_stats(self.column_name)
+            # Avoid sampling unique-stats if this column
+            # wasn't marked as needing unique-stats.
+            if force or self._allow_unique_sampling
+            else UniqueStats()
+        )
+
+    @property
+    def storage_size(self) -> ColumnStat[int]:
+        """Return the average column size for a single file."""
+        return self.table_source_info.storage_size(self.column_name)
+
+    def add_unique_stats_column(self, column: str | None = None) -> None:
+        """Add a column needing unique-value information."""
+        if column in (None, self.column_name):
+            self._allow_unique_sampling = True
+        return self.table_source_info.add_unique_stats_column(
+            column or self.column_name
+        )
+
+
 class ColumnStats:
     """
     Column statistics.
@@ -128,34 +192,29 @@ class ColumnStats:
     children
         Child ColumnStats objects.
     source_info
-        Datasource information.
-    source_name
-        Source-column name.
+        Column source information.
     unique_stats
         Unique-value statistics.
     """

-    __slots__ = ("children", "name", "source_info", "source_name", "unique_stats")
+    __slots__ = ("children", "name", "source_info", "unique_stats")

     name: str
     children: tuple[ColumnStats, ...]
-    source_info: DataSourceInfo
-    source_name: str
+    source_info: ColumnSourceInfo
     unique_stats: UniqueStats

     def __init__(
         self,
         name: str,
         *,
         children: tuple[ColumnStats, ...] = (),
-        source_info: DataSourceInfo | None = None,
-        source_name: str | None = None,
+        source_info: ColumnSourceInfo | None = None,
         unique_stats: UniqueStats | None = None,
     ) -> None:
         self.name = name
         self.children = children
-        self.source_info = source_info or DataSourceInfo()
-        self.source_name = source_name or name
+        self.source_info = source_info or ColumnSourceInfo(DataSourceInfo(), name)
         self.unique_stats = unique_stats or UniqueStats()

     def new_parent(
@@ -184,7 +243,6 @@ def new_parent(
             children=(self,),
             # Want to reference the same DataSourceInfo
             source_info=self.source_info,
-            source_name=self.source_name,
             # Want fresh UniqueStats so we can mutate in place
             unique_stats=UniqueStats(),
         )
@@ -195,6 +253,11 @@ class StatsCollector:

     __slots__ = ("column_stats", "row_count")

+    row_count: dict[IR, ColumnStat[int]]
+    """Estimated row count for each IR node."""
+    column_stats: dict[IR, dict[str, ColumnStats]]
+    """Column statistics for each IR node."""
+
     def __init__(self) -> None:
         self.row_count: dict[IR, ColumnStat[int]] = {}
         self.column_stats: dict[IR, dict[str, ColumnStats]] = {}
```
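A short usage sketch of the sampling gate in `ColumnSourceInfo.unique_stats` (assumes `source` is some concrete `DataSourceInfo` subclass, since the base class is meant to be sub-classed):

```python
col_info = ColumnSourceInfo(source, "price")

col_info.unique_stats()             # UniqueStats(): not marked, so no sampling
col_info.add_unique_stats_column()  # mark "price" for unique-value sampling
col_info.unique_stats()             # now delegates to source.unique_stats("price")
col_info.unique_stats(force=True)   # samples regardless of the marking
```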

python/cudf_polars/cudf_polars/experimental/io.py

Lines changed: 7 additions & 8 deletions

```diff
@@ -19,6 +19,7 @@

 from cudf_polars.dsl.ir import IR, DataFrameScan, Empty, Scan, Sink, Union
 from cudf_polars.experimental.base import (
+    ColumnSourceInfo,
     ColumnStat,
     ColumnStats,
     DataSourceInfo,
@@ -118,8 +119,8 @@ def from_scan(ir: Scan, config_options: ConfigOptions) -> ScanPartitionPlan:
     blocksize: int = config_options.executor.target_partition_size
     column_stats = _extract_scan_stats(ir, config_options)
     column_sizes: list[int] = []
-    for name, cs in column_stats.items():
-        storage_size = cs.source_info.storage_size(name)
+    for cs in column_stats.values():
+        storage_size = cs.source_info.storage_size
         if storage_size.value is not None:
             column_sizes.append(storage_size.value)
@@ -821,16 +822,15 @@ def _extract_scan_stats(
 ) -> dict[str, ColumnStats]:
     """Extract base ColumnStats for a Scan node."""
     if ir.typ == "parquet":
-        source_info = _sample_pq_stats(
+        table_source_info = _sample_pq_stats(
             tuple(ir.paths),
             config_options.parquet_options.max_footer_samples,
             config_options.parquet_options.max_row_group_samples,
         )
         return {
             name: ColumnStats(
                 name=name,
-                source_info=source_info,
-                source_name=name,
+                source_info=ColumnSourceInfo(table_source_info, name),
             )
             for name in ir.schema
         }
@@ -879,12 +879,11 @@ def unique_stats(self, column: str) -> UniqueStats:

 def _extract_dataframescan_stats(ir: DataFrameScan) -> dict[str, ColumnStats]:
     """Extract base ColumnStats for a DataFrameScan node."""
-    source_info = DataFrameSourceInfo(ir.df)
+    table_source_info = DataFrameSourceInfo(ir.df)
     return {
         name: ColumnStats(
             name=name,
-            source_info=source_info,
-            source_name=name,
+            source_info=ColumnSourceInfo(table_source_info, name),
         )
         for name in ir.schema
     }
```

python/cudf_polars/cudf_polars/experimental/statistics.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -66,9 +66,8 @@ def _update_unique_stats_columns(
         if (
             name not in unique_fraction
             and (column_stats := child_column_stats.get(name)) is not None
-            and (source_stats := column_stats.source_info) is not None
         ):
-            source_stats.add_unique_stats_column(column_stats.source_name or name)
+            column_stats.source_info.add_unique_stats_column()


 @initialize_column_stats.register(IR)
```

python/cudf_polars/cudf_polars/testing/plugin.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -174,6 +174,8 @@ def pytest_configure(config: pytest.Config) -> None:
     "tests/unit/io/test_lazy_parquet.py::test_parquet_schema_arg[False-row_groups]": "allow_missing_columns argument in read_parquet not translated in IR",
     "tests/unit/io/test_lazy_parquet.py::test_parquet_schema_arg[False-prefiltered]": "allow_missing_columns argument in read_parquet not translated in IR",
     "tests/unit/io/test_lazy_parquet.py::test_parquet_schema_arg[False-none]": "allow_missing_columns argument in read_parquet not translated in IR",
+    "tests/unit/datatypes/test_decimal.py::test_decimal_aggregations": "https://github.com/pola-rs/polars/issues/23899",
+    "tests/unit/datatypes/test_decimal.py::test_decimal_arithmetic_schema": "https://github.com/pola-rs/polars/issues/23899",
 }
@@ -191,6 +193,7 @@ def pytest_configure(config: pytest.Config) -> None:
     # Tests performance difference of CPU engine
     "tests/unit/operations/test_join.py::test_join_where_eager_perf_21145": "Tests performance bug in CPU engine",
     "tests/unit/operations/namespaces/list/test_list.py::test_list_struct_field_perf": "Tests CPU Engine perf",
+    "tests/benchmark/test_with_columns.py::test_with_columns_quadratic_19503": "Tests performance bug in CPU engine",
     # The test may segfault with the legacy streaming engine. We should
     # remove this skip when all polars tests use the new streaming engine.
     "tests/unit/streaming/test_streaming_group_by.py::test_streaming_group_by_literal[1]": "May segfault w/the legacy streaming engine",
```

python/cudf_polars/docs/overview.md

Lines changed: 4 additions & 1 deletion

```diff
@@ -417,8 +417,11 @@
   datasource (e.g. a Parquet dataset or in-memory `DataFrame`).
   **aggregated** column sampling via sub-classing. For example,
   the `ParquetSourceInfo` sub-class uses caching to avoid
   redundant file-system access.
+- `ColumnSourceInfo`: This class wraps a `DataSourceInfo` object.
+  Since `DataSourceInfo` tracks information for an entire table, we use
+  `ColumnSourceInfo` to provide a single-column view of the object.
 - `ColumnStats`: This class is used to group together the "base"
-  `DataSourceInfo` reference and the current `UniqueStats` estimates
+  `ColumnSourceInfo` reference and the local `UniqueStats` estimates
   for a specific IR + column combination. We bundle these references
   together to simplify the design and maintenance of `StatsCollector`.
   **NOTE:** The current `UniqueStats` estimates are not yet populated.
```
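Putting the three classes together, a hedged sketch of how they compose (the default in `ColumnStats.__init__` builds exactly this chain; real code would pass a concrete `DataSourceInfo` subclass):

```python
table_info = DataSourceInfo()                  # table-level source information
col_info = ColumnSourceInfo(table_info, "a")   # single-column view of the table
stats = ColumnStats(name="a", source_info=col_info)

stats.source_info.row_count  # delegates through to the table-level estimate
```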
