-
Notifications
You must be signed in to change notification settings - Fork 170
feat(RFC): A richer Expr IR
#2572
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
| class FunctionFlags(enum.Flag): | |
| ALLOW_GROUP_AWARE = 1 << 0 | |
| """> Raise if use in group by | |
| Not sure where this is disabled. | |
| """ | |
| INPUT_WILDCARD_EXPANSION = 1 << 4 | |
| """Appears on all the horizontal aggs. | |
| https://github.com/pola-rs/polars/blob/e8ad1059721410e65a3d5c1d84055fb22a4d6d43/crates/polars-plan/src/plans/options.rs#L49-L58 | |
| """ | |
| RETURNS_SCALAR = 1 << 5 | |
| """Automatically explode on unit length if it ran as final aggregation.""" | |
| ROW_SEPARABLE = 1 << 8 | |
| """Not sure lol. | |
| https://github.com/pola-rs/polars/pull/22573 | |
| """ | |
| LENGTH_PRESERVING = 1 << 9 | |
| """mutually exclusive with `RETURNS_SCALAR`""" | |
| def is_elementwise(self) -> bool: | |
| return self in (FunctionFlags.ROW_SEPARABLE | FunctionFlags.LENGTH_PRESERVING) | |
| def returns_scalar(self) -> bool: | |
| return self in FunctionFlags.RETURNS_SCALAR | |
| def is_length_preserving(self) -> bool: | |
| return self in FunctionFlags.LENGTH_PRESERVING | |
| @staticmethod | |
| def default() -> FunctionFlags: | |
| return FunctionFlags.ALLOW_GROUP_AWARE |
narwhals/narwhals/_plan/options.py
Lines 52 to 108 in 0bada48
| class FunctionOptions(Immutable): | |
| """ExprMetadata` but less god object. | |
| https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs | |
| """ | |
| __slots__ = ("flags",) | |
| flags: FunctionFlags | |
| def is_elementwise(self) -> bool: | |
| return self.flags.is_elementwise() | |
| def returns_scalar(self) -> bool: | |
| return self.flags.returns_scalar() | |
| def is_length_preserving(self) -> bool: | |
| return self.flags.is_length_preserving() | |
| def with_flags(self, flags: FunctionFlags, /) -> FunctionOptions: | |
| if (FunctionFlags.RETURNS_SCALAR | FunctionFlags.LENGTH_PRESERVING) in flags: | |
| msg = "A function cannot both return a scalar and preserve length, they are mutually exclusive." | |
| raise TypeError(msg) | |
| obj = FunctionOptions.__new__(FunctionOptions) | |
| object.__setattr__(obj, "flags", self.flags | flags) | |
| return obj | |
| def with_elementwise(self) -> FunctionOptions: | |
| return self.with_flags( | |
| FunctionFlags.ROW_SEPARABLE | FunctionFlags.LENGTH_PRESERVING | |
| ) | |
| @staticmethod | |
| def default() -> FunctionOptions: | |
| obj = FunctionOptions.__new__(FunctionOptions) | |
| object.__setattr__(obj, "flags", FunctionFlags.default()) | |
| return obj | |
| @staticmethod | |
| def elementwise() -> FunctionOptions: | |
| return FunctionOptions.default().with_elementwise() | |
| @staticmethod | |
| def row_separable() -> FunctionOptions: | |
| return FunctionOptions.groupwise().with_flags(FunctionFlags.ROW_SEPARABLE) | |
| @staticmethod | |
| def length_preserving() -> FunctionOptions: | |
| return FunctionOptions.default().with_flags(FunctionFlags.LENGTH_PRESERVING) | |
| @staticmethod | |
| def groupwise() -> FunctionOptions: | |
| return FunctionOptions.default() | |
| @staticmethod | |
| def aggregation() -> FunctionOptions: | |
| return FunctionOptions.groupwise().with_flags(FunctionFlags.RETURNS_SCALAR) |
narwhals/narwhals/_plan/common.py
Lines 149 to 172 in 0bada48
| class Function(ExprIR): | |
| """Shared by expr functions and namespace functions. | |
| https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L114 | |
| """ | |
| @property | |
| def function_options(self) -> FunctionOptions: | |
| from narwhals._plan.options import FunctionOptions | |
| return FunctionOptions.default() | |
| @property | |
| def is_scalar(self) -> bool: | |
| return self.function_options.returns_scalar() | |
| def to_function_expr(self, *inputs: ExprIR) -> FunctionExpr[Self]: | |
| from narwhals._plan.expr import FunctionExpr | |
| from narwhals._plan.options import FunctionOptions | |
| # NOTE: Still need to figure out how these should be generated | |
| # Feel like it should be the union of `input` & `function` | |
| PLACEHOLDER = FunctionOptions.default() # noqa: N806 | |
| return FunctionExpr(input=inputs, function=self, options=PLACEHOLDER) |
narwhals/narwhals/_plan/expr.py
Lines 157 to 185 in 0bada48
| class FunctionExpr(ExprIR, t.Generic[_FunctionT]): | |
| """**Representing `Expr::Function`**. | |
| https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L114-L120 | |
| https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/function_expr/mod.rs#L123 | |
| """ | |
| __slots__ = ("function", "input", "options") | |
| input: Seq[ExprIR] | |
| function: _FunctionT | |
| """Enum type is named `FunctionExpr` in `polars`. | |
| Mirroring *exactly* doesn't make much sense in OOP. | |
| https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/function_expr/mod.rs#L123 | |
| """ | |
| options: FunctionOptions | |
| """Assuming this is **either**: | |
| 1. `function.function_options` | |
| 2. The union of (1) and any `FunctionOptions` in `inputs` | |
| """ | |
| def with_options(self, options: FunctionOptions, /) -> Self: | |
| options = self.options.with_flags(options.flags) | |
| return type(self)(input=self.input, function=self.function, options=options) |
- Mentioned in (#2391 (comment)) - Needed again for #2572
at the moment it looks like this adds a self-standing |
* chore(typing): Add `_typing_compat.py` - Mentioned in (#2391 (comment)) - Needed again for #2572 * refactor: Reuse `TypeVar` import * refactor: Reuse `@deprecated` import * refactor: Reuse `Protocol38` import * docs: Add module-level docstring
Still need: - reprs - fix the hierarchy issue (#2572 (comment)) - Flag summing (#2572 (comment))
- 1 step closer to the understanding for (#2572 (comment)) - There's still some magic going on when `polars` serializes - Need to track down where `'collect_groups': 'ElementWise'` and `'collect_groups': 'GroupWise'` first appear - Seems like the flags get reduced
narwhals/_plan/functions.py
Outdated
| @property | ||
| def function_options(self) -> FunctionOptions: | ||
| return FunctionOptions.length_preserving() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like FunctionOptions.length_preserving is one we need to pay more attention to
|
Thanks for peeking @MarcoGorelli
That is definitely the eventual goal! 🤞 Despite how quickly things have progressed, I still feel I'm a few steps behind being ready for that just yet. General overviewI'm trying to focus on modeling these structures and how they interact:
My thought was that So like what I have in Current
|
This comment was marked as resolved.
This comment was marked as resolved.
Can't tell if this means `FirstT` will match the entry `firstt`, but preserve the `firstt` fix (https://github.com/codespell-project/codespell#ignoring-words) (#2572 (comment))
|
I should've expected this, but it was a nice suprise to find we get hashable selectors for free 😄 from narwhals._plan import selectors as ndcs
>>> ndcs.matches("[^z]a")._ir == ndcs.matches("[^z]a")._ir
True
>>> ndcs.matches("[^z]a")._ir == ndcs.matches("abc")._ir
False@MarcoGorelli regarding (#2291) from narwhals._plan import selectors as ndcs
>>> ndcs.all()._ir == ndcs.all()._ir
True
lhs = ndcs.all()
rhs = ndcs.all().mean()
>>> lhs._ir == rhs._ir
False
>>> lhs._ir == rhs._ir.expr
TrueAnd the same holds for the non-selectors from narwhals._plan import demo as nwd
lhs = nwd.all()
rhs = nwd.all().mean()
>>> lhs._ir == rhs._ir
False
>>> lhs._ir == rhs._ir.expr
True
>>> type(rhs._ir)
narwhals._plan.aggregation.Mean |
An experiment towards (#2572 (comment))
tests/plan/expr_parsing_test.py
Outdated
| def test_valid_windows() -> None: | ||
| """Was planning to test this matched, but we seem to allow elementwise horizontal? | ||
| https://github.com/narwhals-dev/narwhals/blob/63c8e4771a1df4e0bfeea5559c303a4a447d5cc2/tests/expression_parsing_test.py#L10-L45 | ||
| """ | ||
| ELEMENTWISE_ERR = re.compile(r"cannot use.+over.+elementwise", re.IGNORECASE) # noqa: N806 | ||
| a = nwd.col("a") | ||
| assert a.cum_sum() | ||
| assert a.cum_sum().over(order_by="id") | ||
| with pytest.raises(InvalidOperationError, match=ELEMENTWISE_ERR): | ||
| assert a.cum_sum().abs().over(order_by="id") | ||
|
|
||
| assert (a.cum_sum() + 1).over(order_by="id") | ||
| assert a.cum_sum().cum_sum().over(order_by="id") | ||
| assert a.cum_sum().cum_sum() | ||
| assert nwd.sum_horizontal(a, a.cum_sum()) | ||
| with pytest.raises(InvalidOperationError, match=ELEMENTWISE_ERR): | ||
| assert nwd.sum_horizontal(a, a.cum_sum()).over(order_by="a") | ||
|
|
||
| assert nwd.sum_horizontal(a, a.cum_sum().over(order_by="i")) | ||
| assert nwd.sum_horizontal(a.diff(), a.cum_sum().over(order_by="i")) | ||
| with pytest.raises(InvalidOperationError, match=ELEMENTWISE_ERR): | ||
| assert nwd.sum_horizontal(a.diff(), a.cum_sum()).over(order_by="i") | ||
|
|
||
| with pytest.raises(InvalidOperationError, match=ELEMENTWISE_ERR): | ||
| assert nwd.sum_horizontal(a.diff().abs(), a.cum_sum()).over(order_by="i") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@MarcoGorelli quick question
This is adapted from an existing test:
tests.expression_parsing_test.test_window_kind
narwhals/tests/expression_parsing_test.py
Lines 10 to 45 in 63c8e47
| @pytest.mark.parametrize( | |
| ("expr", "expected"), | |
| [ | |
| (nw.col("a"), 0), | |
| (nw.col("a").mean(), 0), | |
| (nw.col("a").cum_sum(), 1), | |
| (nw.col("a").cum_sum().over(order_by="id"), 0), | |
| (nw.col("a").cum_sum().abs().over(order_by="id"), 1), | |
| ((nw.col("a").cum_sum() + 1).over(order_by="id"), 1), | |
| (nw.col("a").cum_sum().cum_sum().over(order_by="id"), 1), | |
| (nw.col("a").cum_sum().cum_sum(), 2), | |
| (nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum()), 1), | |
| (nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum()).over(order_by="a"), 1), | |
| (nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum().over(order_by="i")), 0), | |
| ( | |
| nw.sum_horizontal( | |
| nw.col("a").diff(), nw.col("a").cum_sum().over(order_by="i") | |
| ), | |
| 1, | |
| ), | |
| ( | |
| nw.sum_horizontal(nw.col("a").diff(), nw.col("a").cum_sum()).over( | |
| order_by="i" | |
| ), | |
| 2, | |
| ), | |
| ( | |
| nw.sum_horizontal(nw.col("a").diff().abs(), nw.col("a").cum_sum()).over( | |
| order_by="i" | |
| ), | |
| 2, | |
| ), | |
| ], | |
| ) | |
| def test_window_kind(expr: nw.Expr, expected: int) -> None: | |
| assert expr._metadata.n_orderable_ops == expected |
AFAICT, all of the expressions I've needed a InvalidOperationError for shouldn't be valid.
But they aren't raising in current narwhals 🤔
1
import narwhals as nw
a = nw.col("a")
a.cum_sum().abs().over(order_by="id")This error explicitly mentions abs
narwhals/narwhals/_expression_parsing.py
Lines 357 to 362 in 9bd10ad
| if self.is_elementwise or self.is_filtration: | |
| msg = ( | |
| "Cannot use `over` on expressions which are elementwise\n" | |
| "(e.g. `abs`) or which change length (e.g. `drop_nulls`)." | |
| ) | |
| raise InvalidOperationError(msg) |
2, 3, 4
These are all raising the same as (1), but the issue seems to be that horizontal functions aren't being treated as elementwise
import narwhals as nw
a = nw.col("a")
nw.sum_horizontal(a, a.cum_sum()).over(order_by="a")
nw.sum_horizontal(a.diff(), a.cum_sum()).over(order_by="i")
nw.sum_horizontal(a.diff().abs(), a.cum_sum()).over(order_by="i")In polars, they all seem to be elementwise but with an additional flag
I've done the same in this PR, but I don't think that flag would factor into this?
narwhals/narwhals/_plan/functions.py
Lines 291 to 299 in 9bd10ad
| class SumHorizontal(Function): | |
| @property | |
| def function_options(self) -> FunctionOptions: | |
| return FunctionOptions.elementwise().with_flags( | |
| FunctionFlags.INPUT_WILDCARD_EXPANSION | |
| ) | |
| def __repr__(self) -> str: | |
| return "sum_horizontal" |
Very rough, but identifies the edge cases at least Child of (#2572)
…}`, `Expr.is_{first,last}_distinct` (#3173)
- Related #2572 (comment) - Child of #2572
Will close #2571
What type of PR is this? (check all applicable)
Related issues
Exprinternal representation #2571(sort:updated-desc "(expr-ir)/" in:title)Checklist
If you have comments or can explain your changes, please do so below
Important
See (#2571) for detail!!!!!!!
Very open to feedback
Tasks
Show 2025 May-July
pl.Expr.metapl.Expr.meta)ExprIRmetamethods_typing_compatmodule #2578Merge another PR with (perf: Avoid module-levelimportlib.util.find_spec#2391 (comment)) firstTypeVardefaults moreTypeVar("T", bound=Thing, default=Thing), instead of an opaqueExprIRSelector(s)narwhals/narwhals/_plan/expr.py
Lines 336 to 337 in 0bada48
BinaryExprthat describes the restricted set of operators that are allowedpolars, since they wrappl.colinternallyIntoExprin more places (including and beyond whatnarwhalsallows now)demo.py*_horizontalconcat_str)dummy.pyover,sort_by)FunctionOptions+ friends see commentWhere does the{flags: ...}->{collect_groups: ..., flags: ...}expansion happen?polars>=1.3.0fixed the issue (see comment)Ternarywhen-then-otherwise🥳)Metais_*,has_*,output_namemeta methodsroot_namesundo_aliases,popNamename.py(Expr::KeepName,Expr::RenameAlias)polarswill help with themetamethodsCat,Struct,List(a3e29d1)String(72c33ce)DateTime(aee0a7e)_expression_parsing.pyrulesrustversion worksExpansionFlagsexpand_function_inputsrewrite_projectionsreplace_selectorexpand_selectorreplace_selector_innerreplace_and_add_to_resultsreplace_nthprepare_excludedexpand_columnsexpand_dtypesreplace_dtype_or_index_with_columndtypes_match(probably can solve w/ existingnarwhals)expand_indicesreplace_index_with_columnreplace_wildcardrewrite_special_aliasesreplace_wildcard_with_columnreplace_regexexpand_regexExprIR.map_irExprIR #2572 (comment))ExprIR.map_irfor most nodesWindowExpr.map_irFunctionExpr.map_irRollingExpr,AnonymousExprinheritselectorsExprIR(main) #3066ExprIR(main) #3066 (comment))_planpackage #3122group_by, utilizepyarrow.acero#3143{Expr,Series}.{first,last}#2528)protocols.py#3166order_by,hashjoin,DataFrame.{filter,join},Expr.is_{first,last}_distinct#3173__dict__appearing onImmutablesubclasses (thread)__slots__, and not__dict__too #3201ExpansionFlags.from_ir#3206LogicalPlan(see thread)pyarrow.aceroLogicalPlanfrom narwhals opspyarrow.acero.Declarationpyarrow.Tableto lean on things only supported thereColumnNameOrSelectorcan be addedparse_into_selector_irpc.ExpressionDataFrame.dropshould support itExpr.meta.serialize,Expr.deserializepc.Expressionover(*partition_by)#3224