|
14 | 14 | # See the License for the specific language governing permissions and
|
15 | 15 | # limitations under the License.
|
16 | 16 |
|
| 17 | +import warnings |
17 | 18 | from multiprocessing import Manager
|
18 | 19 | from typing import List, Optional, Tuple
|
19 |
| -import warnings |
20 | 20 |
|
21 | 21 | import numpy as np
|
22 | 22 | import numpy.typing as npt
|
|
26 | 26 | from pymc.logprob.abstract import _logprob
|
27 | 27 | from pytensor.tensor.random.op import RandomVariable
|
28 | 28 |
|
| 29 | +from .split_rules import SplitRule |
29 | 30 | from .tree import Tree
|
30 | 31 | from .utils import TensorLike, _sample_posterior
|
31 |
| -from .split_rules import SplitRule |
32 | 32 |
|
33 | 33 | __all__ = ["BART"]
|
34 | 34 |
|
@@ -93,7 +93,7 @@ class BART(Distribution):
|
93 | 93 | Each element of split_prior should be in the [0, 1] interval and the elements should sum to
|
94 | 94 | 1. Otherwise they will be normalized.
|
95 | 95 | Defaults to 0, i.e. all covariates have the same prior probability to be selected.
|
96 |
| - split_rules : Optional[SplitRule], default None |
| 96 | + split_rules : Optional[List[SplitRule]], default None |
97 | 97 | List of SplitRule objects, one per column in input data.
|
98 | 98 | Allows using different split rules for different columns. Default is ContinuousSplitRule.
|
99 | 99 | Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables.
|
@@ -127,7 +127,7 @@ def __new__(
|
127 | 127 | beta: float = 2.0,
|
128 | 128 | response: str = "constant",
|
129 | 129 | split_prior: Optional[List[float]] = None,
|
130 |
| - split_rules: Optional[SplitRule] = None, |
| 130 | + split_rules: Optional[List[SplitRule]] = None, |
131 | 131 | separate_trees: Optional[bool] = False,
|
132 | 132 | **kwargs,
|
133 | 133 | ):
|
|
0 commit comments