Skip to content

Commit 99672d9

Browse files
Change to Distribution (#15)
* change to Distribution * fix test Co-authored-by: aloctavodia <[email protected]>
1 parent 6529075 commit 99672d9

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

pymc_bart/bart.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from aesara.tensor.random.op import RandomVariable
2222
from pandas import DataFrame, Series
2323

24-
from pymc.distributions.distribution import NoDistribution, _moment
24+
from pymc.distributions.distribution import Distribution, _moment
2525

2626
__all__ = ["BART"]
2727

@@ -50,7 +50,7 @@ def rng_fn(cls, rng, X, Y, m, alpha, split_prior, size):
5050
bart = BARTRV()
5151

5252

53-
class BART(NoDistribution):
53+
class BART(Distribution):
5454
"""
5555
Bayesian Additive Regression Tree distribution.
5656
@@ -104,7 +104,7 @@ def __new__(
104104
),
105105
)()
106106

107-
NoDistribution.register(BARTRV)
107+
Distribution.register(BARTRV)
108108

109109
@_moment.register(BARTRV)
110110
def get_moment(rv, size, *rv_inputs):

tests/test_bart.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
import pytest
44
from numpy.random import RandomState
55
from numpy.testing import assert_almost_equal, assert_array_equal
6-
from pymc.tests.test_distributions_moments import assert_moment_is_expected
6+
from pymc.tests.distributions.util import assert_moment_is_expected
77

88
import pymc_bart as pmb
99

10+
1011
def test_split_node():
1112
split_node = pmb.tree.SplitNode(index=5, idx_split_variable=2, split_value=3.0)
1213
assert split_node.index == 5

0 commit comments

Comments
 (0)