Commit 751ec68

Author: Juan Orduz (committed)

remove unnecessary casting

1 parent dd6a6e8 commit 751ec68
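The `.astype(np.float64)` calls removed below are redundant when the inputs are already float64, since NumPy preserves that dtype through the arithmetic. A minimal check of that behaviour (illustrative only, with made-up values; not part of the commit):

import numpy as np

weights = np.ones((3,))                  # np.ones defaults to float64
prop_nvalue_left = 0.25                  # Python float promotes to float64 in NumPy ops

print((weights * prop_nvalue_left).dtype)   # float64, so the explicit cast was a no-op

y_mu_pred = np.array([1.5, 2.0, 3.5])    # float64
m, norm = 10, 0.1
print((y_mu_pred / m + norm).dtype)      # float64 again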

File tree

2 files changed: +23 -13 lines

pymc_bart/pgbart.py
pymc_bart/tree.py

pymc_bart/pgbart.py

Lines changed: 18 additions & 10 deletions
@@ -16,6 +16,8 @@
 
 import numpy as np
 import numpy.typing as npt
+import pymc as pm
+import pytensor as pt
 from numba import njit
 from pymc.initial_point import PointType
 from pymc.model import Model, modelcontext
@@ -120,15 +122,15 @@ class PGBART(ArrayStepShared):
         "tune": (bool, []),
     }
 
-    def __init__(  # noqa: PLR0915
+    def __init__(  # noqa: PLR0912, PLR0915
         self,
-        vars=None,  # pylint: disable=redefined-builtin
+        vars: list[pm.Distribution] | None = None,
         num_particles: int = 10,
         batch: tuple[float, float] = (0.1, 0.1),
         model: Optional[Model] = None,
         initial_point: PointType | None = None,
-        compile_kwargs: dict | None = None,  # pylint: disable=unused-argument
-    ):
+        compile_kwargs: dict | None = None,
+    ) -> None:
         model = modelcontext(model)
         if initial_point is None:
             initial_point = model.initial_point()
@@ -137,6 +139,10 @@ def __init__(  # noqa: PLR0915
         else:
             vars = [model.rvs_to_values.get(var, var) for var in vars]
             vars = inputvars(vars)
+
+        if vars is None:
+            raise ValueError("Unable to find variables to sample")
+
         value_bart = vars[0]
         self.bart = model.values_to_rvs[value_bart].owner.op
 
@@ -395,7 +401,7 @@ def update_weight(self, particle: ParticleTree, odim: int) -> None:
         particle.log_weight = new_likelihood
 
     @staticmethod
-    def competence(var, has_grad):
+    def competence(var: pm.Distribution, has_grad: bool) -> Competence:
         """PGBART is only suitable for BART distributions."""
         dist = getattr(var.owner, "op", None)
         if isinstance(dist, BARTRV):
@@ -406,7 +412,7 @@ def competence(var, has_grad):
 class RunningSd:
     """Welford's online algorithm for computing the variance/standard deviation"""
 
-    def __init__(self, shape: tuple) -> None:
+    def __init__(self, shape: tuple[int, ...]) -> None:
         self.count = 0  # number of data points
         self.mean = np.zeros(shape)  # running mean
         self.m_2 = np.zeros(shape)  # running second moment
@@ -561,7 +567,7 @@ def draw_leaf_value(
         return np.zeros(shape), linear_params
 
     if y_mu_pred.size == 1:
-        mu_mean = (np.full(shape, y_mu_pred.item() / m) + norm).astype(np.float64)
+        mu_mean = np.full(shape, y_mu_pred.item() / m) + norm
     elif y_mu_pred.size < 3 or response == "constant":
         mu_mean = fast_mean(y_mu_pred) / m + norm
     else:
@@ -585,7 +591,7 @@ def fast_mean(ari: npt.NDArray) -> Union[float, npt.NDArray]:
         for j in range(ari.shape[0]):
             for i in range(count):
                 res[j] += ari[j, i]
-        return (res / count).astype(np.float64)
+        return res / count
 
 
 @njit
@@ -596,7 +602,7 @@ def fast_linear_fit(
     norm: npt.NDArray,
 ) -> tuple[npt.NDArray, list[npt.NDArray]]:
     n = len(x)
-    y = (y / m + np.expand_dims(norm, axis=1)).astype(np.float64)
+    y = y / m + np.expand_dims(norm, axis=1)
 
     xbar = np.sum(x) / n
     ybar = np.sum(y, axis=1) / n
@@ -732,7 +738,9 @@ def are_whole_number(array: npt.NDArray) -> np.bool_:
     return np.all(np.mod(array[~np.isnan(array)], 1) == 0)
 
 
-def logp(point, out_vars, vars, shared):  # pylint: disable=redefined-builtin
+def logp(
+    point, out_vars: list[pm.Distribution], vars: list[pm.Distribution], shared: list[pt.Tensor]
+):
     """Compile PyTensor function of the model and the input and output variables.
 
     Parameters
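
For reference, the RunningSd class touched above documents Welford's online algorithm for the running mean and variance. A standalone sketch of that update rule (illustrative; the class and method names here are made up, not the library's API):

import numpy as np
import numpy.typing as npt


class WelfordSd:
    """Sketch of Welford's online mean/variance update."""

    def __init__(self, shape: tuple[int, ...]) -> None:
        self.count = 0                  # number of data points seen
        self.mean = np.zeros(shape)     # running mean
        self.m_2 = np.zeros(shape)      # running sum of squared deviations

    def update(self, new_value: npt.NDArray) -> npt.NDArray:
        self.count += 1
        delta = new_value - self.mean
        self.mean += delta / self.count
        delta_2 = new_value - self.mean
        self.m_2 += delta * delta_2
        # population standard deviation of everything seen so far
        return np.sqrt(self.m_2 / self.count)

Each update call folds one observation into the running estimate without storing the full history, e.g. sd = WelfordSd((2,)); sd.update(np.array([1.0, 2.0])).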

pymc_bart/tree.py

Lines changed: 5 additions & 3 deletions
@@ -286,7 +286,9 @@ def _traverse_tree(
     x_shape = (1,) if len(X.shape) == 1 else X.shape[:-1]
     nd_dims = (...,) + (None,) * len(x_shape)
 
-    stack = [(0, np.ones(x_shape), 0)]  # (node_index, weight, idx_split_variable) initial state
+    stack: list[tuple[int, npt.NDArray, int]] = [
+        (0, np.ones(x_shape), 0)
+    ]  # (node_index, weight, idx_split_variable) initial state
     p_d = (
         np.zeros(shape + x_shape) if isinstance(shape, tuple) else np.zeros((shape,) + x_shape)
     )
@@ -312,14 +314,14 @@ def _traverse_tree(
             stack.append(
                 (
                     left_node_index,
-                    (weights * prop_nvalue_left).astype(np.float64),
+                    weights * prop_nvalue_left,
                     idx_split_variable,
                 )
             )
             stack.append(
                 (
                     right_node_index,
-                    (weights * (1 - prop_nvalue_left)).astype(np.float64),
+                    weights * (1 - prop_nvalue_left),
                     idx_split_variable,
                 )
             )
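
For context on where the competence signature typed in pgbart.py is used: PyMC queries each step method's competence during pm.sample, and PGBART flags BART variables as suitable, so the step is assigned to them automatically. A minimal usage sketch (assuming pymc-bart's public pmb.BART API; X, Y, and m=20 are made-up example inputs):

import numpy as np
import pymc as pm
import pymc_bart as pmb

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
Y = rng.normal(size=100)

with pm.Model():
    mu = pmb.BART("mu", X, Y, m=20)
    pm.Normal("y", mu=mu, sigma=1.0, observed=Y)
    # PGBART.competence() marks BART variables as suitable, so pm.sample()
    # assigns the PGBART step to `mu` without an explicit step= argument.
    idata = pm.sample()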
