Skip to content

Commit ebdd6b6

Browse files
committed
Generalize data input to support Narwhals-compatible Series or DataFrames
1 parent 340e403 commit ebdd6b6

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

pymc/data.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@
2020
from copy import copy
2121
from typing import Union, cast
2222

23+
import narwhals as nw
2324
import numpy as np
2425
import pandas as pd
2526
import pytensor
2627
import pytensor.tensor as pt
2728
import xarray as xr
2829

30+
from narwhals.typing import IntoFrameT, IntoSeriesT
2931
from pytensor.compile.builders import OpFromGraph
3032
from pytensor.compile.sharedvalue import SharedVariable
3133
from pytensor.graph.basic import Variable
@@ -185,7 +187,7 @@ def determine_coords(
185187
if hasattr(value, "columns"):
186188
if dims is not None:
187189
dim_name = dims[1]
188-
if dim_name is None and value.columns.name is not None:
190+
if dim_name is None and nw.dependencies.is_pandas_dataframe(value) and value.columns.name is not None:
189191
dim_name = value.columns.name
190192
if dim_name is not None:
191193
coords[dim_name] = value.columns
@@ -197,12 +199,12 @@ def determine_coords(
197199
# str is applied because dim entries may be None
198200
coords[str(dim_name)] = cast(xr.DataArray, value[dim]).to_numpy()
199201

200-
if isinstance(value, np.ndarray) and dims is not None:
201-
if len(dims) != value.ndim:
202+
elif (isinstance(value, np.ndarray) or nw.dependencies.is_polars_series(value)) and dims is not None:
203+
if len(dims) != len(value.shape): # Polars objects have no .ndim ...
202204
raise ShapeError(
203205
"Invalid data shape. The rank of the dataset must match the length of `dims`.",
204206
actual=value.shape,
205-
expected=value.ndim,
207+
expected=len(value.shape),
206208
)
207209
for size, dim in zip(value.shape, dims):
208210
coord = model.coords.get(dim, None)
@@ -219,7 +221,7 @@ def determine_coords(
219221

220222
def Data(
221223
name: str,
222-
value,
224+
value: IntoFrameT | IntoSeriesT | xr.DataArray | np.ndarray,
223225
*,
224226
dims: Sequence[str] | None = None,
225227
coords: dict[str, Sequence | np.ndarray] | None = None,
@@ -248,7 +250,7 @@ def Data(
248250
----------
249251
name : str
250252
The name for this variable.
251-
value : array_like or pandas.Series, pandas.Dataframe
253+
value : array_like or Narwhals-compatible Series or DataFrame
252254
A value to associate with this variable.
253255
dims : str, tuple of str or tuple of None, optional
254256
Dimension names of the random variables (as opposed to the shapes of these

0 commit comments

Comments
 (0)