Skip to content

Commit 88f0e07

Browse files
authored
change var to variable (#512)
1 parent 009b5ac commit 88f0e07

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

pymc_extras/prior.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def create_variable(self, name: str) -> pt.TensorVariable:
278278
def sample_prior(
279279
factory: VariableFactory,
280280
coords=None,
281-
name: str = "var",
281+
name: str = "variable",
282282
wrap: bool = False,
283283
**sample_prior_predictive_kwargs,
284284
) -> xr.Dataset:
@@ -292,7 +292,7 @@ def sample_prior(
292292
The coordinates for the variable, by default None.
293293
Only required if the dims are specified.
294294
name : str, optional
295-
The name of the variable, by default "var".
295+
The name of the variable, by default "variable".
296296
wrap : bool, optional
297297
Whether to wrap the variable in a `pm.Deterministic` node, by default False.
298298
sample_prior_predictive_kwargs : dict
@@ -900,7 +900,7 @@ def __eq__(self, other) -> bool:
900900
def sample_prior(
901901
self,
902902
coords=None,
903-
name: str = "var",
903+
name: str = "variable",
904904
**sample_prior_predictive_kwargs,
905905
) -> xr.Dataset:
906906
"""Sample the prior distribution for the variable.
@@ -911,7 +911,7 @@ def sample_prior(
911911
The coordinates for the variable, by default None.
912912
Only required if the dims are specified.
913913
name : str, optional
914-
The name of the variable, by default "var".
914+
The name of the variable, by default "variable".
915915
sample_prior_predictive_kwargs : dict
916916
Additional arguments to pass to `pm.sample_prior_predictive`.
917917

tests/test_prior.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,10 @@ def test_custom_transform() -> None:
616616
prior = dist.sample_prior(draws=10)
617617
df_prior = prior.to_dataframe()
618618

619-
np.testing.assert_array_equal(df_prior["var"].to_numpy(), df_prior["var_raw"].to_numpy() ** 2)
619+
np.testing.assert_array_equal(
620+
df_prior.variable.to_numpy(),
621+
df_prior.variable_raw.to_numpy() ** 2,
622+
)
620623

621624

622625
def test_custom_transform_comes_first() -> None:
@@ -627,7 +630,10 @@ def test_custom_transform_comes_first() -> None:
627630
prior = dist.sample_prior(draws=10)
628631
df_prior = prior.to_dataframe()
629632

630-
np.testing.assert_array_equal(df_prior["var"].to_numpy(), 2 * df_prior["var_raw"].to_numpy())
633+
np.testing.assert_array_equal(
634+
df_prior.variable.to_numpy(),
635+
2 * df_prior.variable_raw.to_numpy(),
636+
)
631637

632638
clear_custom_transforms()
633639

@@ -686,7 +692,7 @@ def test_sample_prior_arbitrary_no_name() -> None:
686692
prior = sample_prior(var, coords={"channel": ["A", "B", "C"]}, draws=25)
687693

688694
assert isinstance(prior, xr.Dataset)
689-
assert "var" not in prior
695+
assert "variable" not in prior
690696

691697
prior_with = sample_prior(
692698
var,
@@ -696,7 +702,7 @@ def test_sample_prior_arbitrary_no_name() -> None:
696702
)
697703

698704
assert isinstance(prior_with, xr.Dataset)
699-
assert "var" in prior_with
705+
assert "variable" in prior_with
700706

701707

702708
def test_create_prior_with_arbitrary() -> None:

0 commit comments

Comments
 (0)