2020from copy import copy
2121from typing import Union , cast
2222
23+ import narwhals as nw
2324import numpy as np
2425import pandas as pd
2526import pytensor
2627import pytensor .tensor as pt
2728import xarray as xr
2829
30+ from narwhals .typing import IntoFrameT , IntoSeriesT
2931from pytensor .compile .builders import OpFromGraph
3032from pytensor .compile .sharedvalue import SharedVariable
3133from pytensor .graph .basic import Variable
@@ -185,7 +187,7 @@ def determine_coords(
185187 if hasattr (value , "columns" ):
186188 if dims is not None :
187189 dim_name = dims [1 ]
188- if dim_name is None and value .columns .name is not None :
190+ if dim_name is None and nw . dependencies . is_pandas_dataframe ( value ) and value .columns .name is not None :
189191 dim_name = value .columns .name
190192 if dim_name is not None :
191193 coords [dim_name ] = value .columns
@@ -197,12 +199,12 @@ def determine_coords(
197199 # str is applied because dim entries may be None
198200 coords [str (dim_name )] = cast (xr .DataArray , value [dim ]).to_numpy ()
199201
200- if isinstance (value , np .ndarray ) and dims is not None :
201- if len (dims ) != value .ndim :
202+ elif ( isinstance (value , np .ndarray ) or nw . dependencies . is_polars_series ( value ) ) and dims is not None :
203+ if len (dims ) != len ( value .shape ): # Polars objects have no .ndim ...
202204 raise ShapeError (
203205 "Invalid data shape. The rank of the dataset must match the length of `dims`." ,
204206 actual = value .shape ,
205- expected = value .ndim ,
207+ expected = len ( value .shape ) ,
206208 )
207209 for size , dim in zip (value .shape , dims ):
208210 coord = model .coords .get (dim , None )
@@ -219,7 +221,7 @@ def determine_coords(
219221
220222def Data (
221223 name : str ,
222- value ,
224+ value : IntoFrameT | IntoSeriesT | xr . DataArray | np . ndarray ,
223225 * ,
224226 dims : Sequence [str ] | None = None ,
225227 coords : dict [str , Sequence | np .ndarray ] | None = None ,
@@ -248,7 +250,7 @@ def Data(
248250 ----------
249251 name : str
250252 The name for this variable.
251- value : array_like or pandas. Series, pandas.Dataframe
253+ value : array_like or Narwhals-compatible Series or DataFrame
252254 A value to associate with this variable.
253255 dims : str, tuple of str or tuple of None, optional
254256 Dimension names of the random variables (as opposed to the shapes of these
0 commit comments