Skip to content

Commit 8ee9254

Browse files
authored
Skip coords with scalar value (#868)
1 parent b7b97e7 commit 8ee9254

File tree

3 files changed

+66
-0
lines changed

3 files changed

+66
-0
lines changed

pymc_marketing/mmm/plot.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import numpy.typing as npt
2323
import xarray as xr
2424

25+
from pymc_marketing.mmm.utils import drop_scalar_coords
26+
2527
Values = Sequence[Any] | npt.NDArray[Any]
2628
Coords = dict[str, Values]
2729

@@ -100,6 +102,8 @@ def plot_hdi(
100102
Figure and the axes
101103
102104
"""
105+
curve = drop_scalar_coords(curve)
106+
103107
hdi_kwargs = hdi_kwargs or {}
104108
conf = az.hdi(curve, **hdi_kwargs)[curve.name]
105109

@@ -190,6 +194,8 @@ def plot_samples(
190194
Figure and the axes
191195
192196
"""
197+
curve = drop_scalar_coords(curve)
198+
193199
plot_coords = get_plot_coords(
194200
curve.coords,
195201
non_grid_names=non_grid_names.union({"chain", "draw"}),
@@ -262,6 +268,7 @@ def plot_curve(
262268
Figure and the axes
263269
264270
"""
271+
curve = drop_scalar_coords(curve)
265272

266273
hdi_kwargs = hdi_kwargs or {}
267274
sample_kwargs = sample_kwargs or {}

pymc_marketing/mmm/utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,3 +361,30 @@ def create_new_spend_data(
361361
spend,
362362
]
363363
)
364+
365+
366+
def drop_scalar_coords(curve: xr.DataArray) -> xr.DataArray:
367+
"""
368+
Remove scalar coordinates from an xarray DataArray.
369+
370+
This function identifies and removes scalar coordinates from the given
371+
DataArray. Scalar coordinates are those with a single value that are
372+
not part of the DataArray's indexes. The function returns a new DataArray
373+
with the scalar coordinates removed.
374+
375+
Parameters
376+
----------
377+
curve : xr.DataArray
378+
The input DataArray from which scalar coordinates will be removed.
379+
380+
Returns
381+
-------
382+
xr.DataArray
383+
A new DataArray with the identified scalar coordinates removed.
384+
"""
385+
scalar_coords_to_drop = []
386+
for coord, values in curve.coords.items():
387+
if values.size == 1 and coord not in curve.indexes:
388+
scalar_coords_to_drop.append(coord)
389+
390+
return curve.reset_coords(scalar_coords_to_drop, drop=True)

tests/mmm/test_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
apply_sklearn_transformer_across_dim,
2222
compute_sigmoid_second_derivative,
2323
create_new_spend_data,
24+
drop_scalar_coords,
2425
estimate_menten_parameters,
2526
estimate_sigmoid_parameters,
2627
find_sigmoid_inflection_point,
@@ -281,3 +282,34 @@ def test_create_new_spend_data_value_errors() -> None:
281282
one_time=True,
282283
spend_leading_up=np.array([3, 4, 5]),
283284
)
285+
286+
287+
@pytest.fixture
288+
def mock_curve_with_scalars() -> xr.DataArray:
289+
coords = {
290+
"x": [1, 2, 3],
291+
"y": [10, 20, 30],
292+
"scalar1": 42, # Scalar coordinate
293+
"scalar2": 3.14, # Another scalar coordinate
294+
}
295+
data = np.random.rand(3, 3)
296+
return xr.DataArray(data, coords=coords, dims=["x", "y"])
297+
298+
299+
def test_drop_scalar_coords(mock_curve_with_scalars) -> None:
300+
original_curve = mock_curve_with_scalars.copy(deep=True) # Make a deep copy
301+
curve = drop_scalar_coords(mock_curve_with_scalars)
302+
303+
# Ensure scalar coordinates are removed
304+
assert "scalar1" not in curve.coords
305+
assert "scalar2" not in curve.coords
306+
307+
# Ensure other coordinates are still present
308+
assert "x" in curve.coords
309+
assert "y" in curve.coords
310+
311+
# Ensure data shape is unchanged
312+
assert curve.shape == (3, 3)
313+
314+
# Ensure the original DataArray was not modified
315+
xr.testing.assert_identical(mock_curve_with_scalars, original_curve)

0 commit comments

Comments
 (0)