Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions pymc/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
from copy import copy
from typing import Union, cast

import narwhals as nw
import numpy as np
import pandas as pd
import pytensor
import pytensor.tensor as pt
import xarray as xr

from narwhals.typing import IntoFrameT, IntoSeriesT
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Variable
Expand Down Expand Up @@ -185,7 +187,7 @@ def determine_coords(
if hasattr(value, "columns"):
if dims is not None:
dim_name = dims[1]
if dim_name is None and value.columns.name is not None:
if dim_name is None and nw.dependencies.is_pandas_dataframe(value) and value.columns.name is not None:
dim_name = value.columns.name
if dim_name is not None:
coords[dim_name] = value.columns
Expand All @@ -197,12 +199,12 @@ def determine_coords(
# str is applied because dim entries may be None
coords[str(dim_name)] = cast(xr.DataArray, value[dim]).to_numpy()

if isinstance(value, np.ndarray) and dims is not None:
if len(dims) != value.ndim:
elif (isinstance(value, np.ndarray) or nw.dependencies.is_polars_series(value)) and dims is not None:
if len(dims) != len(value.shape): # Polars objects have no .ndim ...
raise ShapeError(
"Invalid data shape. The rank of the dataset must match the length of `dims`.",
actual=value.shape,
expected=value.ndim,
expected=len(value.shape),
)
for size, dim in zip(value.shape, dims):
coord = model.coords.get(dim, None)
Expand All @@ -219,7 +221,7 @@ def determine_coords(

def Data(
name: str,
value,
value: IntoFrameT | IntoSeriesT | xr.DataArray | np.ndarray,
*,
dims: Sequence[str] | None = None,
coords: dict[str, Sequence | np.ndarray] | None = None,
Expand Down Expand Up @@ -248,7 +250,7 @@ def Data(
----------
name : str
The name for this variable.
value : array_like or pandas.Series, pandas.Dataframe
value : array_like or Narwhals-compatible Series or DataFrame
A value to associate with this variable.
dims : str, tuple of str or tuple of None, optional
Dimension names of the random variables (as opposed to the shapes of these
Expand Down