Skip to content

Commit e3f16b9

Browse files
mypy T___T
1 parent 9bcd9de commit e3f16b9

File tree

1 file changed

+30
-16
lines changed

1 file changed

+30
-16
lines changed

pymc/data.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import pytensor.tensor as pt
2828
import xarray as xr
2929

30-
from narwhals.typing import IntoFrameT, IntoSeriesT
30+
from narwhals.typing import IntoFrameT, IntoLazyFrameT, IntoSeriesT
3131
from pytensor.compile import SharedVariable
3232
from pytensor.compile.builders import OpFromGraph
3333
from pytensor.graph.basic import Variable
@@ -163,7 +163,9 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:
163163
return mb_tensors if len(variables) else mb_tensors[0]
164164

165165

166-
def _handle_none_dims(dims: Sequence[str] | None, ndim: int) -> Sequence[str] | Sequence[None]:
166+
def _handle_none_dims(
167+
dims: Sequence[str | None] | None, ndim: int
168+
) -> Sequence[str | None] | Sequence[None]:
167169
if dims is None:
168170
return [None] * ndim
169171
else:
@@ -176,7 +178,7 @@ def determine_coords(
176178
model: "Model",
177179
dims: Sequence[str | None] | None = None,
178180
coords: dict[str, Sequence | np.ndarray] | None = None,
179-
):
181+
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
180182
"""Determine coordinate values from data or the model (via ``dims``)."""
181183
raise NotImplementedError(
182184
f"Cannot determine coordinates for data of type {type(value)}, please provide `coords` explicitly or "
@@ -190,10 +192,13 @@ def determine_array_coords(
190192
model: "Model",
191193
dims: Sequence[str] | None = None,
192194
coords: dict[str, Sequence | np.ndarray] | None = None,
193-
):
195+
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
194196
if coords is None:
195197
coords = {}
196198

199+
if dims is None:
200+
return coords, _handle_none_dims(dims, value.ndim)
201+
197202
if len(dims) != value.ndim:
198203
raise ShapeError(
199204
"Invalid data shape. The rank of the dataset must match the length of `dims`.",
@@ -215,7 +220,7 @@ def determine_xarray_coords(
215220
model: "Model",
216221
dims: Sequence[str | None] | None = None,
217222
coords: dict[str, Sequence | np.ndarray] | None = None,
218-
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str] | Sequence[None]]:
223+
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
219224
if coords is None:
220225
coords = {}
221226

@@ -231,16 +236,16 @@ def determine_xarray_coords(
231236

232237

233238
def _dataframe_agnostic_coords(
234-
value: IntoFrameT,
239+
value: IntoFrameT | IntoLazyFrameT | nw.DataFrame | nw.LazyFrame,
235240
model: "Model",
236241
ndim_in: int = 2,
237242
dims: Sequence[str | None] | None = None,
238243
coords: dict[str, Sequence | np.ndarray] | None = None,
239-
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str] | Sequence[None]]:
244+
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
240245
if coords is None:
241246
coords = {}
242247

243-
value = nw.from_native(value, allow_series=False)
248+
value = cast(nw.DataFrame | nw.LazyFrame, nw.from_native(value, allow_series=False)) # type: ignore[type-var]
244249
if isinstance(value, nw.LazyFrame):
245250
value = value.collect()
246251

@@ -260,9 +265,9 @@ def _dataframe_agnostic_coords(
260265

261266
index_dim = dims[0]
262267
if index_dim is not None and index_dim in value.columns:
263-
coords[index_dim] = value.select(nw.col(index_dim)).to_numpy()
268+
coords[index_dim] = tuple(value.select(nw.col(index_dim)).to_numpy())
264269
elif index_dim in model.coords:
265-
coords[index_dim] = model.coords[index_dim]
270+
coords[index_dim] = model.coords[index_dim] # type: ignore[assignment]
266271
else:
267272
raise ValueError(
268273
f"Dimension '{index_dim}' not found in DataFrame columns or model coordinates. Cannot infer "
@@ -282,9 +287,15 @@ def _series_agnostic_coords(
282287
model: "Model",
283288
dims: Sequence[str | None] | None = None,
284289
coords: dict[str, Sequence | np.ndarray] | None = None,
285-
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str] | Sequence[None]]:
286-
value = nw.from_native(value, series_only=True).to_frame()
287-
return _dataframe_agnostic_coords(value, ndim_in=1, model=model, dims=dims, coords=coords)
290+
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
291+
value = cast(nw.Series, nw.from_native(value, series_only=True)) # type: ignore[assignment]
292+
return _dataframe_agnostic_coords(
293+
cast(nw.DataFrame | nw.LazyFrame, value.to_frame()), # type: ignore[attr-defined]
294+
ndim_in=1,
295+
model=model,
296+
dims=dims,
297+
coords=coords,
298+
) # type: ignore[arg-type]
288299

289300

290301
def _register_dataframe_backend(library_name: str):
@@ -297,7 +308,7 @@ def determine_series_coords(
297308
model: "Model",
298309
dims: Sequence[str] | None = None,
299310
coords: dict[str, Sequence | np.ndarray] | None = None,
300-
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str] | Sequence[None]]:
311+
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
301312
return _series_agnostic_coords(value, model=model, dims=dims, coords=coords)
302313

303314
@determine_coords.register(library.DataFrame)
@@ -306,7 +317,7 @@ def determine_dataframe_coords(
306317
model: "Model",
307318
dims: Sequence[str] | None = None,
308319
coords: dict[str, Sequence | np.ndarray] | None = None,
309-
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str] | Sequence[None]]:
320+
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
310321
return _dataframe_agnostic_coords(value, model=model, dims=dims, coords=coords)
311322

312323
except ImportError:
@@ -366,6 +377,9 @@ def Data(
366377
infer_dims_and_coords : bool, default=False
367378
If True, the ``Data`` container will try to infer what the coordinates
368379
and dimension names should be if there is an index in ``value``.
380+
model : pymc.Model, optional
381+
Model to which to add the data variable. If not specified, the data variable
382+
will be added to the model on the context stack.
369383
**kwargs : dict, optional
370384
Extra arguments passed to :func:`pytensor.shared`.
371385
@@ -434,7 +448,7 @@ def Data(
434448
expected=x.ndim,
435449
)
436450

437-
new_dims: Sequence[str] | Sequence[None] | None
451+
new_dims: Sequence[str | None] | Sequence[None] | None
438452
if infer_dims_and_coords:
439453
coords, new_dims = determine_coords(value, model, dims)
440454
else:

0 commit comments

Comments
 (0)