2727import pytensor .tensor as pt
2828import xarray as xr
2929
30- from narwhals .typing import IntoFrameT , IntoSeriesT
30+ from narwhals .typing import IntoFrameT , IntoLazyFrameT , IntoSeriesT
3131from pytensor .compile import SharedVariable
3232from pytensor .compile .builders import OpFromGraph
3333from pytensor .graph .basic import Variable
@@ -163,7 +163,9 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:
163163 return mb_tensors if len (variables ) else mb_tensors [0 ]
164164
165165
166- def _handle_none_dims (dims : Sequence [str ] | None , ndim : int ) -> Sequence [str ] | Sequence [None ]:
166+ def _handle_none_dims (
167+ dims : Sequence [str | None ] | None , ndim : int
168+ ) -> Sequence [str | None ] | Sequence [None ]:
167169 if dims is None :
168170 return [None ] * ndim
169171 else :
@@ -176,7 +178,7 @@ def determine_coords(
176178 model : "Model" ,
177179 dims : Sequence [str | None ] | None = None ,
178180 coords : dict [str , Sequence | np .ndarray ] | None = None ,
179- ):
181+ ) -> tuple [ dict [ str , Sequence | np . ndarray ], Sequence [ str | None ] | Sequence [ None ]] :
180182 """Determine coordinate values from data or the model (via ``dims``)."""
181183 raise NotImplementedError (
182184 f"Cannot determine coordinates for data of type { type (value )} , please provide `coords` explicitly or "
@@ -190,10 +192,13 @@ def determine_array_coords(
190192 model : "Model" ,
191193 dims : Sequence [str ] | None = None ,
192194 coords : dict [str , Sequence | np .ndarray ] | None = None ,
193- ):
195+ ) -> tuple [ dict [ str , Sequence | np . ndarray ], Sequence [ str | None ] | Sequence [ None ]] :
194196 if coords is None :
195197 coords = {}
196198
199+ if dims is None :
200+ return coords , _handle_none_dims (dims , value .ndim )
201+
197202 if len (dims ) != value .ndim :
198203 raise ShapeError (
199204 "Invalid data shape. The rank of the dataset must match the length of `dims`." ,
@@ -215,7 +220,7 @@ def determine_xarray_coords(
215220 model : "Model" ,
216221 dims : Sequence [str | None ] | None = None ,
217222 coords : dict [str , Sequence | np .ndarray ] | None = None ,
218- ) -> tuple [dict [str , Sequence | np .ndarray ], Sequence [str ] | Sequence [None ]]:
223+ ) -> tuple [dict [str , Sequence | np .ndarray ], Sequence [str | None ] | Sequence [None ]]:
219224 if coords is None :
220225 coords = {}
221226
@@ -231,16 +236,16 @@ def determine_xarray_coords(
231236
232237
233238def _dataframe_agnostic_coords (
234- value : IntoFrameT ,
239+ value : IntoFrameT | IntoLazyFrameT | nw . DataFrame | nw . LazyFrame ,
235240 model : "Model" ,
236241 ndim_in : int = 2 ,
237242 dims : Sequence [str | None ] | None = None ,
238243 coords : dict [str , Sequence | np .ndarray ] | None = None ,
239- ) -> tuple [dict [str , Sequence | np .ndarray ], Sequence [str ] | Sequence [None ]]:
244+ ) -> tuple [dict [str , Sequence | np .ndarray ], Sequence [str | None ] | Sequence [None ]]:
240245 if coords is None :
241246 coords = {}
242247
243- value = nw .from_native (value , allow_series = False )
248+ value = cast ( nw .DataFrame | nw . LazyFrame , nw . from_native (value , allow_series = False )) # type: ignore[type-var]
244249 if isinstance (value , nw .LazyFrame ):
245250 value = value .collect ()
246251
@@ -260,9 +265,9 @@ def _dataframe_agnostic_coords(
260265
261266 index_dim = dims [0 ]
262267 if index_dim is not None and index_dim in value .columns :
263- coords [index_dim ] = value .select (nw .col (index_dim )).to_numpy ()
268+ coords [index_dim ] = tuple ( value .select (nw .col (index_dim )).to_numpy () )
264269 elif index_dim in model .coords :
265- coords [index_dim ] = model .coords [index_dim ]
270+ coords [index_dim ] = model .coords [index_dim ] # type: ignore[assignment]
266271 else :
267272 raise ValueError (
268273 f"Dimension '{ index_dim } ' not found in DataFrame columns or model coordinates. Cannot infer "
@@ -282,9 +287,15 @@ def _series_agnostic_coords(
282287 model : "Model" ,
283288 dims : Sequence [str | None ] | None = None ,
284289 coords : dict [str , Sequence | np .ndarray ] | None = None ,
285- ) -> tuple [dict [str , Sequence | np .ndarray ], Sequence [str ] | Sequence [None ]]:
286- value = nw .from_native (value , series_only = True ).to_frame ()
287- return _dataframe_agnostic_coords (value , ndim_in = 1 , model = model , dims = dims , coords = coords )
290+ ) -> tuple [dict [str , Sequence | np .ndarray ], Sequence [str | None ] | Sequence [None ]]:
291+ value = cast (nw .Series , nw .from_native (value , series_only = True )) # type: ignore[assignment]
292+ return _dataframe_agnostic_coords (
293+ cast (nw .DataFrame | nw .LazyFrame , value .to_frame ()), # type: ignore[attr-defined]
294+ ndim_in = 1 ,
295+ model = model ,
296+ dims = dims ,
297+ coords = coords ,
298+ ) # type: ignore[arg-type]
288299
289300
290301def _register_dataframe_backend (library_name : str ):
@@ -297,7 +308,7 @@ def determine_series_coords(
297308 model : "Model" ,
298309 dims : Sequence [str ] | None = None ,
299310 coords : dict [str , Sequence | np .ndarray ] | None = None ,
300- ) -> tuple [dict [str , Sequence | np .ndarray ], Sequence [str ] | Sequence [None ]]:
311+ ) -> tuple [dict [str , Sequence | np .ndarray ], Sequence [str | None ] | Sequence [None ]]:
301312 return _series_agnostic_coords (value , model = model , dims = dims , coords = coords )
302313
303314 @determine_coords .register (library .DataFrame )
@@ -306,7 +317,7 @@ def determine_dataframe_coords(
306317 model : "Model" ,
307318 dims : Sequence [str ] | None = None ,
308319 coords : dict [str , Sequence | np .ndarray ] | None = None ,
309- ) -> tuple [dict [str , Sequence | np .ndarray ], Sequence [str ] | Sequence [None ]]:
320+ ) -> tuple [dict [str , Sequence | np .ndarray ], Sequence [str | None ] | Sequence [None ]]:
310321 return _dataframe_agnostic_coords (value , model = model , dims = dims , coords = coords )
311322
312323 except ImportError :
@@ -366,6 +377,9 @@ def Data(
366377 infer_dims_and_coords : bool, default=False
367378 If True, the ``Data`` container will try to infer what the coordinates
368379 and dimension names should be if there is an index in ``value``.
380+ model : pymc.Model, optional
381+ Model to which to add the data variable. If not specified, the data variable
382+ will be added to the model on the context stack.
369383 **kwargs : dict, optional
370384 Extra arguments passed to :func:`pytensor.shared`.
371385
@@ -434,7 +448,7 @@ def Data(
434448 expected = x .ndim ,
435449 )
436450
437- new_dims : Sequence [str ] | Sequence [None ] | None
451+ new_dims : Sequence [str | None ] | Sequence [None ] | None
438452 if infer_dims_and_coords :
439453 coords , new_dims = determine_coords (value , model , dims )
440454 else :
0 commit comments