Skip to content

Commit c8997c3

Browse files
authored
fix linear response bug (#101)
1 parent 8c92834 commit c8997c3

File tree

4 files changed: 23 additions (+23) and 27 deletions (-27).

pymc_bart/pgbart.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -536,15 +536,14 @@ def draw_leaf_value(
536536
return np.zeros(shape), linear_params
537537

538538
if y_mu_pred.size == 1:
539-
mu_mean = np.full(shape, y_mu_pred.item() / m)
539+
mu_mean = np.full(shape, y_mu_pred.item() / m) + norm
540540
else:
541-
if response == "constant":
542-
mu_mean = fast_mean(y_mu_pred) / m
543-
if response == "linear":
544-
mu_mean, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred, m=m)
541+
if y_mu_pred.size < 3 or response == "constant":
542+
mu_mean = fast_mean(y_mu_pred) / m + norm
543+
else:
544+
mu_mean, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred, m=m, norm=norm)
545545

546-
draw = mu_mean + norm
547-
return draw, linear_params
546+
return mu_mean, linear_params
548547

549548

550549
@njit
@@ -570,9 +569,10 @@ def fast_linear_fit(
570569
x: npt.NDArray[np.float_],
571570
y: npt.NDArray[np.float_],
572571
m: int,
572+
norm: npt.NDArray[np.float_],
573573
) -> Tuple[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]]:
574574
n = len(x)
575-
y = y / m
575+
y = y / m + np.expand_dims(norm, axis=1)
576576

577577
xbar = np.sum(x) / n
578578
ybar = np.sum(y, axis=1) / n

pymc_bart/tree.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,6 @@ def new_leaf_node(
6565
linear_params=linear_params,
6666
)
6767

68-
@classmethod
69-
def new_split_node(cls, split_value: npt.NDArray[np.float_], idx_split_variable: int) -> "Node":
70-
return cls(value=split_value, idx_split_variable=idx_split_variable)
71-
7268
def is_split_node(self) -> bool:
7369
return self.idx_split_variable >= 0
7470

@@ -282,42 +278,42 @@ def _traverse_tree(
282278
"""
283279

284280
x_shape = (1,) if len(X.shape) == 1 else X.shape[:-1]
281+
nd_dims = (...,) + (None,) * len(x_shape)
285282

286-
stack = [(0, np.ones(x_shape))] # (node_index, weight) initial state
283+
stack = [(0, np.ones(x_shape), 0)] # (node_index, weight, idx_split_variable) initial state
287284
p_d = (
288285
np.zeros(shape + x_shape) if isinstance(shape, tuple) else np.zeros((shape,) + x_shape)
289286
)
290287
while stack:
291-
node_index, weights = stack.pop()
288+
node_index, weights, idx_split_variable = stack.pop()
292289
node = self.get_node(node_index)
293290
if node.is_leaf_node():
294291
params = node.linear_params
295-
nd_dims = (...,) + (None,) * len(x_shape)
296292
if params is None:
297293
p_d += weights * node.value[nd_dims]
298294
else:
299-
# this produce nonsensical results
300295
p_d += weights * (
301-
params[0][nd_dims] + params[1][nd_dims] * X[..., node.idx_split_variable]
296+
params[0][nd_dims] + params[1][nd_dims] * X[..., idx_split_variable]
302297
)
303-
# this produce reasonable result
304-
# p_d += weight * node.value.mean()
305298
else:
306299
left_node_index, right_node_index = get_idx_left_child(
307300
node_index
308301
), get_idx_right_child(node_index)
302+
idx_split_variable = node.idx_split_variable
309303
if excluded is not None and node.idx_split_variable in excluded:
310304
prop_nvalue_left = self.get_node(left_node_index).nvalue / node.nvalue
311-
stack.append((left_node_index, weights * prop_nvalue_left))
312-
stack.append((right_node_index, weights * (1 - prop_nvalue_left)))
305+
stack.append((left_node_index, weights * prop_nvalue_left, idx_split_variable))
306+
stack.append(
307+
(right_node_index, weights * (1 - prop_nvalue_left), idx_split_variable)
308+
)
313309
else:
314310
to_left = (
315-
self.split_rules[node.idx_split_variable]
316-
.divide(X[..., node.idx_split_variable], node.value)
311+
self.split_rules[idx_split_variable]
312+
.divide(X[..., idx_split_variable], node.value)
317313
.astype("float")
318314
)
319-
stack.append((left_node_index, weights * to_left))
320-
stack.append((right_node_index, weights * (1 - to_left)))
315+
stack.append((left_node_index, weights * to_left, idx_split_variable))
316+
stack.append((right_node_index, weights * (1 - to_left), idx_split_variable))
321317

322318
if len(X.shape) == 1:
323319
p_d = p_d[..., 0]

tests/test_pgbart.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def test_fast_mean():
5454
ids=["1d-id", "1d-const"],
5555
)
5656
def test_fast_linear_fit(x, y, a_expected, b_expected):
57-
y_fit, linear_params = fast_linear_fit(x, y, m=1)
57+
y_fit, linear_params = fast_linear_fit(x, y, m=1, norm=np.zeros(1))
5858
assert linear_params[0] == a_expected
5959
assert linear_params[1] == b_expected
6060
np.testing.assert_almost_equal(

tests/test_tree.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
def test_split_node():
77
index = 5
8-
split_node = Node.new_split_node(idx_split_variable=2, split_value=3.0)
8+
split_node = Node(idx_split_variable=2, value=3.0)
99
assert get_depth(index) == 2
1010
assert split_node.value == 3.0
1111
assert split_node.idx_split_variable == 2

0 commit comments

Comments
 (0)