20
20
from copy import copy
21
21
from typing import Union , cast
22
22
23
+ import narwhals as nw
23
24
import numpy as np
24
25
import pandas as pd
25
26
import pytensor
26
27
import pytensor .tensor as pt
27
28
import xarray as xr
28
29
30
+ from narwhals .typing import IntoFrameT , IntoSeriesT
29
31
from pytensor .compile .builders import OpFromGraph
30
32
from pytensor .compile .sharedvalue import SharedVariable
31
33
from pytensor .graph .basic import Variable
@@ -185,7 +187,7 @@ def determine_coords(
185
187
if hasattr (value , "columns" ):
186
188
if dims is not None :
187
189
dim_name = dims [1 ]
188
- if dim_name is None and value .columns .name is not None :
190
+ if dim_name is None and nw . dependencies . is_pandas_dataframe ( value ) and value .columns .name is not None :
189
191
dim_name = value .columns .name
190
192
if dim_name is not None :
191
193
coords [dim_name ] = value .columns
@@ -197,12 +199,12 @@ def determine_coords(
197
199
# str is applied because dim entries may be None
198
200
coords [str (dim_name )] = cast (xr .DataArray , value [dim ]).to_numpy ()
199
201
200
- if isinstance (value , np .ndarray ) and dims is not None :
201
- if len (dims ) != value .ndim :
202
+ elif ( isinstance (value , np .ndarray ) or nw . dependencies . is_polars_series ( value ) ) and dims is not None :
203
+ if len (dims ) != len ( value .shape ): # Polars objects have no .ndim ...
202
204
raise ShapeError (
203
205
"Invalid data shape. The rank of the dataset must match the length of `dims`." ,
204
206
actual = value .shape ,
205
- expected = value .ndim ,
207
+ expected = len ( value .shape ) ,
206
208
)
207
209
for size , dim in zip (value .shape , dims ):
208
210
coord = model .coords .get (dim , None )
@@ -219,7 +221,7 @@ def determine_coords(
219
221
220
222
def Data (
221
223
name : str ,
222
- value ,
224
+ value : IntoFrameT | IntoSeriesT | xr . DataArray | np . ndarray ,
223
225
* ,
224
226
dims : Sequence [str ] | None = None ,
225
227
coords : dict [str , Sequence | np .ndarray ] | None = None ,
@@ -248,7 +250,7 @@ def Data(
248
250
----------
249
251
name : str
250
252
The name for this variable.
251
- value : array_like or pandas. Series, pandas.Dataframe
253
+ value : array_like or Narwhals-compatible Series or DataFrame
252
254
A value to associate with this variable.
253
255
dims : str, tuple of str or tuple of None, optional
254
256
Dimension names of the random variables (as opposed to the shapes of these
0 commit comments