@@ -43,7 +43,9 @@ class BARTRV(RandomVariable):
43
43
_print_name : Tuple [str , str ] = ("BART" , "\\ operatorname{BART}" )
44
44
all_trees = List [List [List [Tree ]]]
45
45
46
- def _supp_shape_from_params (self , dist_params , rep_param_idx = 1 , param_shapes = None ):
46
+ def _supp_shape_from_params (
47
+ self , dist_params , rep_param_idx = 1 , param_shapes = None
48
+ ): # pylint: disable=arguments-renamed
47
49
return dist_params [0 ].shape [:1 ]
48
50
49
51
@classmethod
@@ -126,7 +128,7 @@ def __new__(
126
128
alpha : float = 0.95 ,
127
129
beta : float = 2.0 ,
128
130
response : str = "constant" ,
129
- split_prior : Optional [List [ float ]] = None ,
131
+ split_prior : Optional [npt . NDArray [ np . float_ ]] = None ,
130
132
split_rules : Optional [List [SplitRule ]] = None ,
131
133
separate_trees : Optional [bool ] = False ,
132
134
** kwargs ,
@@ -141,8 +143,7 @@ def __new__(
141
143
142
144
X , Y = preprocess_xy (X , Y )
143
145
144
- if split_prior is None :
145
- split_prior = []
146
+ split_prior = np .array ([]) if split_prior is None else np .asarray (split_prior )
146
147
147
148
bart_op = type (
148
149
f"BART_{ name } " ,
0 commit comments