Skip to content

Commit e8f6208

Browse files
Narwhals compatibility layer in pm.Data
1 parent 88ee6a1 commit e8f6208

12 files changed

+236
-57
lines changed

conda-envs/environment-alternative-backends.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ dependencies:
1818
- jaxlib>=0.4.28
1919
- libblas=*=*mkl
2020
- mkl-service
21+
- narwhals>=2.11.0
2122
- numpy>=1.25.0
2223
- numpyro>=0.8.0
2324
- pandas>=0.24.0

conda-envs/environment-dev.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ dependencies:
99
- blas
1010
- cachetools>=4.2.1
1111
- cloudpickle
12+
- narwhals>=2.11.0
1213
- numpy>=1.25.0
1314
- pandas>=0.24.0
1415
- pip

conda-envs/environment-docs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ dependencies:
88
- arviz>=0.13.0
99
- cachetools>=4.2.1
1010
- cloudpickle
11+
- narwhals>=2.11.0
1112
- numpy>=1.25.0
1213
- pandas>=0.24.0
1314
- pip

conda-envs/environment-test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ dependencies:
1010
- cachetools>=4.2.1
1111
- cloudpickle
1212
- jax
13+
- narwhals>=2.11.0
1314
- numpy>=1.25.0
1415
- pandas>=0.24.0
1516
- pip

conda-envs/windows-environment-dev.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ dependencies:
99
- blas
1010
- cachetools>=4.2.1
1111
- cloudpickle
12+
- narwhals>=2.11.0
1213
- numpy>=1.25.0
1314
- pandas>=0.24.0
1415
- pip

conda-envs/windows-environment-test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dependencies:
1111
- cloudpickle
1212
- libpython
1313
- mkl-service>=2.3.0
14+
- narwhals>=2.11.0
1415
- numpy>=1.25.0
1516
- pandas>=0.24.0
1617
- pip

pymc/data.py

Lines changed: 147 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,25 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
import importlib
1515
import io
1616
import typing
1717
import urllib.request
1818

1919
from collections.abc import Sequence
2020
from copy import copy
21+
from functools import singledispatch
2122
from typing import Union, cast
2223

2324
import narwhals as nw
2425
import numpy as np
25-
import pandas as pd
2626
import pytensor
2727
import pytensor.tensor as pt
2828
import xarray as xr
2929

3030
from narwhals.typing import IntoFrameT, IntoSeriesT
31+
from pytensor.compile import SharedVariable
3132
from pytensor.compile.builders import OpFromGraph
32-
from pytensor.compile.sharedvalue import SharedVariable
3333
from pytensor.graph.basic import Variable
3434
from pytensor.raise_op import Assert
3535
from pytensor.tensor.random.basic import IntegersRV
@@ -163,60 +163,159 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:
163163
return mb_tensors if len(variables) else mb_tensors[0]
164164

165165

166+
def _handle_none_dims(dims: Sequence[str] | None, ndim: int) -> Sequence[str] | Sequence[None]:
167+
if dims is None:
168+
return [None] * ndim
169+
else:
170+
return dims
171+
172+
173+
@singledispatch
166174
def determine_coords(
167-
model,
168-
value: pd.DataFrame | pd.Series | xr.DataArray,
175+
value,
176+
model: "Model",
177+
dims: Sequence[str | None] | None = None,
178+
coords: dict[str, Sequence | np.ndarray] | None = None,
179+
):
180+
"""Determine coordinate values from data or the model (via ``dims``)."""
181+
raise NotImplementedError(
182+
f"Cannot determine coordinates for data of type {type(value)}, please provide `coords` explicitly or "
183+
f"convert the data to a supported type"
184+
)
185+
186+
187+
@determine_coords.register(np.ndarray)
188+
def determine_array_coords(
189+
value: np.ndarray,
190+
model: "Model",
169191
dims: Sequence[str] | None = None,
170192
coords: dict[str, Sequence | np.ndarray] | None = None,
193+
):
194+
if coords is None:
195+
coords = {}
196+
197+
if len(dims) != value.ndim:
198+
raise ShapeError(
199+
"Invalid data shape. The rank of the dataset must match the length of `dims`.",
200+
actual=value.shape,
201+
expected=len(value.shape),
202+
)
203+
204+
for size, dim in zip(value.shape, dims):
205+
coord = model.coords.get(dim, None)
206+
if coord is None and dim is not None:
207+
coords[dim] = range(size)
208+
209+
return coords, _handle_none_dims(dims, value.ndim)
210+
211+
212+
@determine_coords.register(xr.DataArray)
213+
def determine_xarray_coords(
214+
value: xr.DataArray,
215+
model: "Model",
216+
dims: Sequence[str | None] | None = None,
217+
coords: dict[str, Sequence | np.ndarray] | None = None,
171218
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str] | Sequence[None]]:
172-
"""Determine coordinate values from data or the model (via ``dims``)."""
173219
if coords is None:
174220
coords = {}
175221

176-
dim_name = None
177-
# If value is a df or a series, we interpret the index as coords:
178-
if hasattr(value, "index"):
179-
if dims is not None:
180-
dim_name = dims[0]
181-
if dim_name is None and value.index.name is not None:
182-
dim_name = value.index.name
183-
if dim_name is not None:
184-
coords[dim_name] = value.index
185-
186-
# If value is a df, we also interpret the columns as coords:
187-
if hasattr(value, "columns"):
188-
if dims is not None:
189-
dim_name = dims[1]
190-
if dim_name is None and nw.dependencies.is_pandas_dataframe(value) and value.columns.name is not None:
191-
dim_name = value.columns.name
192-
if dim_name is not None:
193-
coords[dim_name] = value.columns
194-
195-
if isinstance(value, xr.DataArray):
196-
if dims is not None:
197-
for dim in dims:
198-
dim_name = dim
199-
# str is applied because dim entries may be None
200-
coords[str(dim_name)] = cast(xr.DataArray, value[dim]).to_numpy()
201-
202-
elif (isinstance(value, np.ndarray) or nw.dependencies.is_polars_series(value)) and dims is not None:
203-
if len(dims) != len(value.shape): # Polars objects have no .ndim ...
204-
raise ShapeError(
205-
"Invalid data shape. The rank of the dataset must match the length of `dims`.",
206-
actual=value.shape,
207-
expected=len(value.shape),
208-
)
209-
for size, dim in zip(value.shape, dims):
210-
coord = model.coords.get(dim, None)
211-
if coord is None and dim is not None:
212-
coords[dim] = range(size)
222+
if dims is None:
223+
return coords, _handle_none_dims(dims, value.ndim)
224+
225+
for dim in dims:
226+
dim_name = dim
227+
# str is applied because dim entries may be None
228+
coords[str(dim_name)] = cast(xr.DataArray, value[dim]).to_numpy()
229+
230+
return coords, _handle_none_dims(dims, value.ndim)
231+
232+
233+
def _dataframe_agnostic_coords(
234+
value: IntoFrameT,
235+
model: "Model",
236+
ndim_in: int = 2,
237+
dims: Sequence[str | None] | None = None,
238+
coords: dict[str, Sequence | np.ndarray] | None = None,
239+
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str] | Sequence[None]]:
240+
if coords is None:
241+
coords = {}
242+
243+
value = nw.from_native(value, allow_series=False)
244+
if isinstance(value, nw.LazyFrame):
245+
value = value.collect()
246+
247+
index = nw.maybe_get_index(value)
248+
if index is not None:
249+
value = value.with_columns(**{index.name: index.to_numpy()})
213250

214251
if dims is None:
215-
# TODO: Also determine dim names from the index
216-
new_dims: Sequence[str] | Sequence[None] = [None] * np.ndim(value)
252+
return coords, _handle_none_dims(dims, ndim_in)
253+
254+
if len(dims) != ndim_in:
255+
raise ShapeError(
256+
"Invalid data shape. The rank of the dataset must match the length of `dims`.",
257+
actual=value.shape,
258+
expected=len(dims),
259+
)
260+
261+
index_dim = dims[0]
262+
if index_dim is not None and index_dim in value.columns:
263+
coords[index_dim] = value.select(nw.col(index_dim)).to_numpy()
264+
elif index_dim in model.coords:
265+
coords[index_dim] = model.coords[index_dim]
217266
else:
218-
new_dims = dims
219-
return coords, new_dims
267+
raise ValueError(
268+
f"Dimension '{index_dim}' not found in DataFrame columns or model coordinates. Cannot infer "
269+
"index coordinates."
270+
)
271+
272+
if len(dims) > 1:
273+
column_dim = dims[1]
274+
if column_dim is not None:
275+
coords[column_dim] = value.select(nw.exclude(index_dim)).columns
276+
277+
return coords, _handle_none_dims(dims, ndim_in)
278+
279+
280+
def _series_agnostic_coords(
281+
value: IntoSeriesT,
282+
model: "Model",
283+
dims: Sequence[str | None] | None = None,
284+
coords: dict[str, Sequence | np.ndarray] | None = None,
285+
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str] | Sequence[None]]:
286+
value = nw.from_native(value, series_only=True).to_frame()
287+
return _dataframe_agnostic_coords(value, ndim_in=1, model=model, dims=dims, coords=coords)
288+
289+
290+
def _register_dataframe_backend(library_name: str):
291+
try:
292+
library = importlib.import_module(library_name)
293+
294+
@determine_coords.register(library.Series)
295+
def determine_series_coords(
296+
value: library.DataFrame | library.Series,
297+
model: "Model",
298+
dims: Sequence[str] | None = None,
299+
coords: dict[str, Sequence | np.ndarray] | None = None,
300+
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str] | Sequence[None]]:
301+
return _series_agnostic_coords(value, model=model, dims=dims, coords=coords)
302+
303+
@determine_coords.register(library.DataFrame)
304+
def determine_dataframe_coords(
305+
value: library.DataFrame | library.Series,
306+
model: "Model",
307+
dims: Sequence[str] | None = None,
308+
coords: dict[str, Sequence | np.ndarray] | None = None,
309+
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str] | Sequence[None]]:
310+
return _dataframe_agnostic_coords(value, model=model, dims=dims, coords=coords)
311+
312+
except ImportError:
313+
pass
314+
315+
316+
_register_dataframe_backend("pandas")
317+
_register_dataframe_backend("polars")
318+
_register_dataframe_backend("dask.dataframe")
220319

221320

222321
def Data(
@@ -337,7 +436,7 @@ def Data(
337436

338437
new_dims: Sequence[str] | Sequence[None] | None
339438
if infer_dims_and_coords:
340-
coords, new_dims = determine_coords(model, value, dims)
439+
coords, new_dims = determine_coords(value, model, dims)
341440
else:
342441
new_dims = dims
343442

pymc/pytensorf.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import importlib
1415
import warnings
1516

1617
from collections.abc import Iterable, Sequence
1718
from typing import cast
1819

20+
import narwhals as nw
1921
import numpy as np
20-
import pandas as pd
2122
import pytensor
2223
import pytensor.tensor as pt
2324
import scipy.sparse as sps
@@ -128,11 +129,32 @@ def convert_data(data) -> np.ndarray | Variable:
128129
return smarttypeX(ret)
129130

130131

131-
@_as_tensor_variable.register(pd.Series)
132-
@_as_tensor_variable.register(pd.DataFrame)
133-
def dataframe_to_tensor_variable(df: pd.DataFrame, *args, **kwargs) -> TensorVariable:
134-
return pt.as_tensor_variable(df.to_numpy(), *args, **kwargs)
132+
# Optional registrations for DataFrame packages
133+
def _register_dataframe_backend(library_name: str):
134+
try:
135+
library = importlib.import_module(library_name)
136+
137+
@_as_tensor_variable.register(library.Series)
138+
def series_to_tensor_variable(s: library.Series, *args, **kwargs) -> TensorVariable:
139+
s = nw.from_native(s, allow_series=False)
140+
if isinstance(s, nw.LazyFrame):
141+
s = s.collect()
142+
return pt.as_tensor_variable(s.to_numpy(), *args, **kwargs)
143+
144+
@_as_tensor_variable.register(library.DataFrame)
145+
def dataframe_to_tensor_variable(df: library.DataFrame, *args, **kwargs) -> TensorVariable:
146+
df = nw.from_native(df, allow_series=False)
147+
if isinstance(df, nw.LazyFrame):
148+
df = df.collect()
149+
return pt.as_tensor_variable(df.to_numpy(), *args, **kwargs)
150+
151+
except ImportError:
152+
pass
153+
135154

155+
_register_dataframe_backend("pandas")
156+
_register_dataframe_backend("polars")
157+
_register_dataframe_backend("dask.dataframe")
136158

137159
_cheap_eval_mode = Mode(linker="py", optimizer="minimum_compile")
138160

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ jupyter-sphinx
99
mcbackend>=0.4.0
1010
mypy==1.15.0
1111
myst-nb<=1.0.0
12+
narwhals>=2.11.0
1213
numdifftools>=0.9.40
1314
numpy>=1.25.0
1415
numpydoc

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
arviz>=0.13.0
22
cachetools>=4.2.1
33
cloudpickle
4+
narwhals>=2.11.0
45
numpy>=1.25.0
56
pandas>=0.24.0
67
pytensor>=2.35.0,<2.36

0 commit comments

Comments
 (0)