Skip to content

Commit 2392c01

Browse files
committed
Add narwhals result types
Signed-off-by: Pascal Tomecek <pascal.tomecek@cubistsystematic.com>
1 parent d4b2286 commit 2392c01

File tree

3 files changed

+102
-1
lines changed

3 files changed

+102
-1
lines changed

ccflow/result/narwhals.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import narwhals.stable.v1 as nw
2+
from pydantic import Field
3+
4+
from ..base import ResultBase
5+
from ..exttypes.narwhals import DataFrameT, FrameT
6+
7+
__all__ = (
8+
"NarwhalsFrameResult",
9+
"NarwhalsDataFrameResult",
10+
)
11+
12+
13+
class NarwhalsFrameResult(ResultBase):
14+
df: FrameT = Field(description="Narwhals DataFrame or LazyFrame")
15+
16+
def collect(self) -> "NarwhalsDataFrameResult":
17+
"""Collects the result into a NarwhalsDataFrameResult."""
18+
if isinstance(self.df, nw.LazyFrame):
19+
return NarwhalsDataFrameResult(df=self.df.collect(), **self.model_dump(exclude={"df", "type_"}))
20+
return NarwhalsDataFrameResult(df=self.df, **self.model_dump(exclude={"df", "type_"}))
21+
22+
23+
class NarwhalsDataFrameResult(NarwhalsFrameResult):
24+
df: DataFrameT = Field(description="Narwhals eager Dataframe")
25+
26+
def collect(self) -> "NarwhalsDataFrameResult":
27+
"""Collects the result into a NarwhalsDataFrameResult."""
28+
return self
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from typing import Annotated
2+
3+
import narwhals.stable.v1 as nw
4+
import polars as pl
5+
import pytest
6+
7+
from ccflow.exttypes.narwhals import (
8+
DataFrameT,
9+
SchemaValidator,
10+
)
11+
from ccflow.result.narwhals import NarwhalsDataFrameResult, NarwhalsFrameResult
12+
13+
14+
@pytest.fixture
15+
def data():
16+
return {
17+
"a": [1.0, 2.0, 3.0],
18+
"b": [4, 5, 6],
19+
"c": ["foo", "bar", "baz"],
20+
"d": [0, 0, 0],
21+
}
22+
23+
24+
@pytest.fixture
25+
def schema():
26+
return {
27+
"a": nw.Float64,
28+
"b": nw.Int64,
29+
"c": nw.String,
30+
"d": nw.Float64,
31+
}
32+
33+
34+
def test_narwhals_frame_result(data):
35+
df = pl.DataFrame(data)
36+
result = NarwhalsFrameResult(df=df)
37+
assert isinstance(result.df, nw.DataFrame)
38+
assert result.df.to_native() is df
39+
40+
df = pl.DataFrame(data).lazy()
41+
result = NarwhalsFrameResult(df=df)
42+
assert isinstance(result.df, nw.LazyFrame)
43+
assert result.df.to_native() is df
44+
45+
46+
def test_narwhals_dataframe_result(data):
47+
df = pl.DataFrame(data)
48+
result = NarwhalsDataFrameResult(df=df)
49+
assert isinstance(result.df, nw.DataFrame)
50+
assert result.df.to_native() is df
51+
52+
df = pl.DataFrame(data).lazy()
53+
result = NarwhalsDataFrameResult(df=df)
54+
assert isinstance(result.df, nw.DataFrame)
55+
56+
57+
def test_collect(data):
58+
df = pl.DataFrame(data)
59+
result = NarwhalsFrameResult(df=df)
60+
result2 = result.collect()
61+
assert isinstance(result2, NarwhalsDataFrameResult)
62+
assert isinstance(result2.df, nw.DataFrame)
63+
64+
65+
def test_custom(data, schema):
66+
class MyNarwhalsResult(NarwhalsDataFrameResult):
67+
df: Annotated[DataFrameT, SchemaValidator(schema, cast=True)]
68+
69+
df = pl.DataFrame(data)
70+
result = MyNarwhalsResult(df=df)
71+
assert result.df.schema["d"] == nw.Float64()

docs/wiki/Key-Features.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,13 @@ A Result is an object that holds the results from a callable model. It provides
134134
The following table summarizes the "result" models.
135135

136136
| Name | Path | Description |
137-
| :-- | :-- | :--- |
137+
|:---------------------------|:-------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
138138
| `GenericResult` | `ccflow.result` | A generic result (holds anything). |
139139
| `DictResult` | `ccflow.result` | A generic dict (key/value) result. |
140140
| `ArrowResult` | `ccflow.result.pyarrow` | Holds an arrow table. |
141141
| `ArrowDateRangeResult` | `ccflow.result.pyarrow` | Extension of `ArrowResult` for representing a table over a date range that can be divided by date, such that generation of any sub-range of dates gives the same results as the original table filtered for those dates. |
142+
| `NarwhalsResult` | `ccflow.result.narwhals` | Holds a narwhals `DataFrame` or `LazyFrame`. |
143+
| `NarwhalsDataFrameResult` | `ccflow.result.narwhals` | Holds a narwhals eager `DataFrame`. |
142144
| `NumpyResult` | `ccflow.result.numpy` | Holds a numpy array. |
143145
| `PandasResult` | `ccflow.result.pandas` | Holds a pandas dataframe. |
144146
| `XArrayResult` | `ccflow.result.xarray` | Holds an xarray. |

0 commit comments

Comments
 (0)