Skip to content

Commit dd6a6e8

Browse files
author
Juan Orduz
committed
remove reference np.float64
1 parent 7c6b462 commit dd6a6e8

File tree

4 files changed

+71
-73
lines changed

4 files changed

+71
-73
lines changed

pymc_bart/bart.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def __new__(
132132
alpha: float = 0.95,
133133
beta: float = 2.0,
134134
response: str = "constant",
135-
split_prior: Optional[npt.NDArray[np.float64]] = None,
135+
split_prior: Optional[npt.NDArray] = None,
136136
split_rules: Optional[list[SplitRule]] = None,
137137
separate_trees: Optional[bool] = False,
138138
**kwargs,
@@ -203,9 +203,7 @@ def get_moment(cls, rv, size, *rv_inputs):
203203
return mean
204204

205205

206-
def preprocess_xy(
207-
X: TensorLike, Y: TensorLike
208-
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
206+
def preprocess_xy(X: TensorLike, Y: TensorLike) -> tuple[npt.NDArray, npt.NDArray]:
209207
if isinstance(Y, (Series, DataFrame)):
210208
Y = Y.to_numpy()
211209
if isinstance(X, (Series, DataFrame)):

pymc_bart/pgbart.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def normalize(self, particles: list[ParticleTree]) -> float:
325325
return wei / wei.sum()
326326

327327
def resample(
328-
self, particles: list[ParticleTree], normalized_weights: npt.NDArray[np.float64]
328+
self, particles: list[ParticleTree], normalized_weights: npt.NDArray
329329
) -> list[ParticleTree]:
330330
"""
331331
Use systematic resample for all but the first particle
@@ -347,7 +347,7 @@ def resample(
347347
return particles
348348

349349
def get_particle_tree(
350-
self, particles: list[ParticleTree], normalized_weights: npt.NDArray[np.float64]
350+
self, particles: list[ParticleTree], normalized_weights: npt.NDArray
351351
) -> tuple[ParticleTree, Tree]:
352352
"""
353353
Sample a new particle and associated tree
@@ -359,7 +359,7 @@ def get_particle_tree(
359359

360360
return new_particle, new_particle.tree
361361

362-
def systematic(self, normalized_weights: npt.NDArray[np.float64]) -> npt.NDArray[np.int_]:
362+
def systematic(self, normalized_weights: npt.NDArray) -> npt.NDArray[np.int_]:
363363
"""
364364
Systematic resampling.
365365
@@ -411,7 +411,7 @@ def __init__(self, shape: tuple) -> None:
411411
self.mean = np.zeros(shape) # running mean
412412
self.m_2 = np.zeros(shape) # running second moment
413413

414-
def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
414+
def update(self, new_value: npt.NDArray) -> Union[float, npt.NDArray]:
415415
self.count = self.count + 1
416416
self.mean, self.m_2, std = _update(self.count, self.mean, self.m_2, new_value)
417417
return fast_mean(std)
@@ -420,21 +420,21 @@ def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray
420420
@njit
421421
def _update(
422422
count: int,
423-
mean: npt.NDArray[np.float64],
424-
m_2: npt.NDArray[np.float64],
425-
new_value: npt.NDArray[np.float64],
426-
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], Union[float, npt.NDArray[np.float64]]]:
423+
mean: npt.NDArray,
424+
m_2: npt.NDArray,
425+
new_value: npt.NDArray,
426+
) -> tuple[npt.NDArray, npt.NDArray, Union[float, npt.NDArray]]:
427427
delta = new_value - mean
428428
mean += delta / count
429429
delta2 = new_value - mean
430430
m_2 += delta * delta2
431431

432432
std = (m_2 / count) ** 0.5
433-
return mean.astype(np.float64), m_2.astype(np.float64), std.astype(np.float64)
433+
return mean, m_2, std
434434

435435

436436
class SampleSplittingVariable:
437-
def __init__(self, alpha_vec: npt.NDArray[np.float64]) -> None:
437+
def __init__(self, alpha_vec: npt.NDArray) -> None:
438438
"""
439439
Sample splitting variables proportional to `alpha_vec`.
440440
@@ -547,16 +547,16 @@ def filter_missing_values(available_splitting_values, idx_data_points, missing_d
547547

548548

549549
def draw_leaf_value(
550-
y_mu_pred: npt.NDArray[np.float64],
551-
x_mu: npt.NDArray[np.float64],
550+
y_mu_pred: npt.NDArray,
551+
x_mu: npt.NDArray,
552552
m: int,
553-
norm: npt.NDArray[np.float64],
553+
norm: npt.NDArray,
554554
shape: int,
555555
response: str,
556-
) -> tuple[npt.NDArray[np.float64], Optional[npt.NDArray[np.float64]]]:
556+
) -> tuple[npt.NDArray, Optional[npt.NDArray]]:
557557
"""Draw Gaussian distributed leaf values."""
558558
linear_params = None
559-
mu_mean: npt.NDArray[np.float64]
559+
mu_mean: npt.NDArray
560560
if y_mu_pred.size == 0:
561561
return np.zeros(shape), linear_params
562562

@@ -571,7 +571,7 @@ def draw_leaf_value(
571571

572572

573573
@njit
574-
def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
574+
def fast_mean(ari: npt.NDArray) -> Union[float, npt.NDArray]:
575575
"""Use Numba to speed up the computation of the mean."""
576576
if ari.ndim == 1:
577577
count = ari.shape[0]
@@ -590,11 +590,11 @@ def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float
590590

591591
@njit
592592
def fast_linear_fit(
593-
x: npt.NDArray[np.float64],
594-
y: npt.NDArray[np.float64],
593+
x: npt.NDArray,
594+
y: npt.NDArray,
595595
m: int,
596-
norm: npt.NDArray[np.float64],
597-
) -> tuple[npt.NDArray[np.float64], list[npt.NDArray[np.float64]]]:
596+
norm: npt.NDArray,
597+
) -> tuple[npt.NDArray, list[npt.NDArray]]:
598598
n = len(x)
599599
y = (y / m + np.expand_dims(norm, axis=1)).astype(np.float64)
600600

@@ -678,17 +678,17 @@ def update(self):
678678

679679
@njit
680680
def inverse_cdf(
681-
single_uniform: npt.NDArray[np.float64], normalized_weights: npt.NDArray[np.float64]
681+
single_uniform: npt.NDArray, normalized_weights: npt.NDArray
682682
) -> npt.NDArray[np.int_]:
683683
"""
684684
Inverse CDF algorithm for a finite distribution.
685685
686686
Parameters
687687
----------
688-
single_uniform: npt.NDArray[np.float64]
688+
single_uniform: npt.NDArray
689689
Ordered points in [0,1]
690690
691-
normalized_weights: npt.NDArray[np.float64])
691+
normalized_weights: npt.NDArray)
692692
Normalized weights
693693
694694
Returns
@@ -711,7 +711,7 @@ def inverse_cdf(
711711

712712

713713
@njit
714-
def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray[np.float64]:
714+
def jitter_duplicated(array: npt.NDArray, std: float) -> npt.NDArray:
715715
"""
716716
Jitter duplicated values.
717717
"""
@@ -727,7 +727,7 @@ def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray
727727

728728

729729
@njit
730-
def are_whole_number(array: npt.NDArray[np.float64]) -> np.bool_:
730+
def are_whole_number(array: npt.NDArray) -> np.bool_:
731731
"""Check if all values in array are whole numbers"""
732732
return np.all(np.mod(array[~np.isnan(array)], 1) == 0)
733733

pymc_bart/tree.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class Node:
2828
2929
Attributes
3030
----------
31-
value : npt.NDArray[np.float64]
31+
value : npt.NDArray
3232
idx_data_points : Optional[npt.NDArray[np.int_]]
3333
idx_split_variable : int
3434
linear_params: Optional[list[float]] = None
@@ -38,11 +38,11 @@ class Node:
3838

3939
def __init__(
4040
self,
41-
value: npt.NDArray[np.float64] = np.array([-1.0]),
41+
value: npt.NDArray = np.array([-1.0]),
4242
nvalue: int = 0,
4343
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
4444
idx_split_variable: int = -1,
45-
linear_params: Optional[list[npt.NDArray[np.float64]]] = None,
45+
linear_params: Optional[list[npt.NDArray]] = None,
4646
) -> None:
4747
self.value = value
4848
self.nvalue = nvalue
@@ -53,11 +53,11 @@ def __init__(
5353
@classmethod
5454
def new_leaf_node(
5555
cls,
56-
value: npt.NDArray[np.float64],
56+
value: npt.NDArray,
5757
nvalue: int = 0,
5858
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
5959
idx_split_variable: int = -1,
60-
linear_params: Optional[list[npt.NDArray[np.float64]]] = None,
60+
linear_params: Optional[list[npt.NDArray]] = None,
6161
) -> "Node":
6262
return cls(
6363
value=value,
@@ -101,7 +101,7 @@ class Tree:
101101
The dictionary's keys are integers that represent the nodes position.
102102
The dictionary's values are objects of type Node that represent the split and leaf nodes
103103
of the tree itself.
104-
output: Optional[npt.NDArray[np.float64]]
104+
output: Optional[npt.NDArray]
105105
Array of shape number of observations, shape
106106
split_rules : list[SplitRule]
107107
List of SplitRule objects, one per column in input data.
@@ -122,7 +122,7 @@ class Tree:
122122
def __init__(
123123
self,
124124
tree_structure: dict[int, Node],
125-
output: npt.NDArray[np.float64],
125+
output: npt.NDArray,
126126
split_rules: list[SplitRule],
127127
idx_leaf_nodes: Optional[list[int]] = None,
128128
) -> None:
@@ -134,7 +134,7 @@ def __init__(
134134
@classmethod
135135
def new_tree(
136136
cls,
137-
leaf_node_value: npt.NDArray[np.float64],
137+
leaf_node_value: npt.NDArray,
138138
idx_data_points: Optional[npt.NDArray[np.int_]],
139139
num_observations: int,
140140
shape: int,
@@ -190,7 +190,7 @@ def grow_leaf_node(
190190
self,
191191
current_node: Node,
192192
selected_predictor: int,
193-
split_value: npt.NDArray[np.float64],
193+
split_value: npt.NDArray,
194194
index_leaf_node: int,
195195
) -> None:
196196
current_node.value = split_value
@@ -222,7 +222,7 @@ def get_split_variables(self) -> Generator[int, None, None]:
222222
if node.is_split_node():
223223
yield node.idx_split_variable
224224

225-
def _predict(self) -> npt.NDArray[np.float64]:
225+
def _predict(self) -> npt.NDArray:
226226
output = self.output
227227

228228
if self.idx_leaf_nodes is not None:
@@ -233,23 +233,23 @@ def _predict(self) -> npt.NDArray[np.float64]:
233233

234234
def predict(
235235
self,
236-
x: npt.NDArray[np.float64],
236+
x: npt.NDArray,
237237
excluded: Optional[list[int]] = None,
238238
shape: int = 1,
239-
) -> npt.NDArray[np.float64]:
239+
) -> npt.NDArray:
240240
"""
241241
Predict output of tree for an (un)observed point x.
242242
243243
Parameters
244244
----------
245-
x : npt.NDArray[np.float64]
245+
x : npt.NDArray
246246
Unobserved point
247247
excluded: Optional[list[int]]
248248
Indexes of the variables to exclude when computing predictions
249249
250250
Returns
251251
-------
252-
npt.NDArray[np.float64]
252+
npt.NDArray
253253
Value of the leaf value where the unobserved point lies.
254254
"""
255255
if excluded is None:
@@ -259,16 +259,16 @@ def predict(
259259

260260
def _traverse_tree(
261261
self,
262-
X: npt.NDArray[np.float64],
262+
X: npt.NDArray,
263263
excluded: Optional[list[int]] = None,
264264
shape: Union[int, tuple[int, ...]] = 1,
265-
) -> npt.NDArray[np.float64]:
265+
) -> npt.NDArray:
266266
"""
267267
Traverse the tree starting from the root node given an (un)observed point.
268268
269269
Parameters
270270
----------
271-
X : npt.NDArray[np.float64]
271+
X : npt.NDArray
272272
(Un)observed point(s)
273273
node_index : int
274274
Index of the node to start the traversal from
@@ -279,7 +279,7 @@ def _traverse_tree(
279279
280280
Returns
281281
-------
282-
npt.NDArray[np.float64]
282+
npt.NDArray
283283
Leaf node value or mean of leaf node values
284284
"""
285285

@@ -338,14 +338,14 @@ def _traverse_tree(
338338
return p_d
339339

340340
def _traverse_leaf_values(
341-
self, leaf_values: list[npt.NDArray[np.float64]], leaf_n_values: list[int], node_index: int
341+
self, leaf_values: list[npt.NDArray], leaf_n_values: list[int], node_index: int
342342
) -> None:
343343
"""
344344
Traverse the tree appending leaf values starting from a particular node.
345345
346346
Parameters
347347
----------
348-
leaf_values : list[npt.NDArray[np.float64]]
348+
leaf_values : list[npt.NDArray]
349349
node_index : int
350350
"""
351351
node = self.get_node(node_index)

0 commit comments

Comments
 (0)