Skip to content

Commit 592c6b2

Browse files
author
Juan Orduz
authored
small improvements (#108)
1 parent 05467b3 commit 592c6b2

File tree

3 files changed

+10
-5
lines changed

3 files changed

+10
-5
lines changed

pymc_bart/bart.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import warnings
1718
from multiprocessing import Manager
1819
from typing import List, Optional, Tuple
19-
import warnings
2020

2121
import numpy as np
2222
import numpy.typing as npt
@@ -26,9 +26,9 @@
2626
from pymc.logprob.abstract import _logprob
2727
from pytensor.tensor.random.op import RandomVariable
2828

29+
from .split_rules import SplitRule
2930
from .tree import Tree
3031
from .utils import TensorLike, _sample_posterior
31-
from .split_rules import SplitRule
3232

3333
__all__ = ["BART"]
3434

@@ -93,7 +93,7 @@ class BART(Distribution):
9393
Each element of split_prior should be in the [0, 1] interval and the elements should sum to
9494
1. Otherwise they will be normalized.
9595
Defaults to 0, i.e. all covariates have the same prior probability to be selected.
96-
split_rules : Optional[SplitRule], default None
96+
split_rules : Optional[List[SplitRule]], default None
9797
List of SplitRule objects, one per column in input data.
9898
Allows using different split rules for different columns. Default is ContinuousSplitRule.
9999
Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables.
@@ -127,7 +127,7 @@ def __new__(
127127
beta: float = 2.0,
128128
response: str = "constant",
129129
split_prior: Optional[List[float]] = None,
130-
split_rules: Optional[SplitRule] = None,
130+
split_rules: Optional[List[SplitRule]] = None,
131131
separate_trees: Optional[bool] = False,
132132
**kwargs,
133133
):

pymc_bart/tree.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import numpy as np
1919
import numpy.typing as npt
2020
from pytensor import config
21+
2122
from .split_rules import SplitRule
2223

2324

@@ -101,6 +102,10 @@ class Tree:
101102
of the tree itself.
102103
output: Optional[npt.NDArray[np.float_]]
103104
Array of shape number of observations, shape
105+
split_rules : List[SplitRule]
106+
List of SplitRule objects, one per column in input data.
107+
Allows using different split rules for different columns. Default is ContinuousSplitRule.
108+
Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables.
104109
idx_leaf_nodes : Optional[List[int]], by default None.
105110
Array with the index of the leaf nodes of the tree.
106111

pymc_bart/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def plot_pdp(
357357
func : Optional[Callable], by default None.
358358
Arbitrary function to apply to the predictions. Defaults to the identity function.
359359
samples : int
360-
Number of posterior samples used in the predictions. Defaults to 400
360+
Number of posterior samples used in the predictions. Defaults to 200
361361
random_seed : Optional[int], by default None.
362362
Seed used to sample from the posterior. Defaults to None.
363363
sharey : bool

0 commit comments

Comments
 (0)