Skip to content

Commit e04ca5c

Browse files
authored
update pre-commit (#171)
1 parent b759dff commit e04ca5c

File tree

5 files changed

+68
-68
lines changed

5 files changed

+68
-68
lines changed

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@ ci:
1212

1313
repos:
1414
- repo: https://github.com/astral-sh/ruff-pre-commit
15-
rev: v0.4.7
15+
rev: v0.5.0
1616
hooks:
1717
- id: ruff
18-
args: ["--fix", "--show-source"]
18+
args: ["--fix", "--output-format=full"]
1919
- id: ruff-format
2020
args: ["--line-length=100"]
2121
- repo: https://github.com/pre-commit/mirrors-mypy
22-
rev: v1.10.0
22+
rev: v1.10.1
2323
hooks:
2424
- id: mypy
2525
args: [--ignore-missing-imports]

pymc_bart/bart.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def __new__(
125125
alpha: float = 0.95,
126126
beta: float = 2.0,
127127
response: str = "constant",
128-
split_prior: Optional[npt.NDArray[np.float_]] = None,
128+
split_prior: Optional[npt.NDArray[np.float64]] = None,
129129
split_rules: Optional[List[SplitRule]] = None,
130130
separate_trees: Optional[bool] = False,
131131
**kwargs,
@@ -198,7 +198,7 @@ def get_moment(cls, rv, size, *rv_inputs):
198198

199199
def preprocess_xy(
200200
X: TensorLike, Y: TensorLike
201-
) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_]]:
201+
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
202202
if isinstance(Y, (Series, DataFrame)):
203203
Y = Y.to_numpy()
204204
if isinstance(X, (Series, DataFrame)):

pymc_bart/pgbart.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def normalize(self, particles: List[ParticleTree]) -> float:
313313
return wei / wei.sum()
314314

315315
def resample(
316-
self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float_]
316+
self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float64]
317317
) -> List[ParticleTree]:
318318
"""
319319
Use systematic resample for all but the first particle
@@ -335,7 +335,7 @@ def resample(
335335
return particles
336336

337337
def get_particle_tree(
338-
self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float_]
338+
self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float64]
339339
) -> Tuple[ParticleTree, Tree]:
340340
"""
341341
Sample a new particle and associated tree
@@ -347,7 +347,7 @@ def get_particle_tree(
347347

348348
return new_particle, new_particle.tree
349349

350-
def systematic(self, normalized_weights: npt.NDArray[np.float_]) -> npt.NDArray[np.int_]:
350+
def systematic(self, normalized_weights: npt.NDArray[np.float64]) -> npt.NDArray[np.int_]:
351351
"""
352352
Systematic resampling.
353353
@@ -399,7 +399,7 @@ def __init__(self, shape: tuple) -> None:
399399
self.mean = np.zeros(shape) # running mean
400400
self.m_2 = np.zeros(shape) # running second moment
401401

402-
def update(self, new_value: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[np.float_]]:
402+
def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
403403
self.count = self.count + 1
404404
self.mean, self.m_2, std = _update(self.count, self.mean, self.m_2, new_value)
405405
return fast_mean(std)
@@ -408,10 +408,10 @@ def update(self, new_value: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[
408408
@njit
409409
def _update(
410410
count: int,
411-
mean: npt.NDArray[np.float_],
412-
m_2: npt.NDArray[np.float_],
413-
new_value: npt.NDArray[np.float_],
414-
) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], Union[float, npt.NDArray[np.float_]]]:
411+
mean: npt.NDArray[np.float64],
412+
m_2: npt.NDArray[np.float64],
413+
new_value: npt.NDArray[np.float64],
414+
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], Union[float, npt.NDArray[np.float64]]]:
415415
delta = new_value - mean
416416
mean += delta / count
417417
delta2 = new_value - mean
@@ -422,7 +422,7 @@ def _update(
422422

423423

424424
class SampleSplittingVariable:
425-
def __init__(self, alpha_vec: npt.NDArray[np.float_]) -> None:
425+
def __init__(self, alpha_vec: npt.NDArray[np.float64]) -> None:
426426
"""
427427
Sample splitting variables proportional to `alpha_vec`.
428428
@@ -535,13 +535,13 @@ def filter_missing_values(available_splitting_values, idx_data_points, missing_d
535535

536536

537537
def draw_leaf_value(
538-
y_mu_pred: npt.NDArray[np.float_],
539-
x_mu: npt.NDArray[np.float_],
538+
y_mu_pred: npt.NDArray[np.float64],
539+
x_mu: npt.NDArray[np.float64],
540540
m: int,
541-
norm: npt.NDArray[np.float_],
541+
norm: npt.NDArray[np.float64],
542542
shape: int,
543543
response: str,
544-
) -> Tuple[npt.NDArray[np.float_], Optional[npt.NDArray[np.float_]]]:
544+
) -> Tuple[npt.NDArray[np.float64], Optional[npt.NDArray[np.float64]]]:
545545
"""Draw Gaussian distributed leaf values."""
546546
linear_params = None
547547
mu_mean = np.empty(shape)
@@ -559,7 +559,7 @@ def draw_leaf_value(
559559

560560

561561
@njit
562-
def fast_mean(ari: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[np.float_]]:
562+
def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
563563
"""Use Numba to speed up the computation of the mean."""
564564
if ari.ndim == 1:
565565
count = ari.shape[0]
@@ -578,11 +578,11 @@ def fast_mean(ari: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[np.float_
578578

579579
@njit
580580
def fast_linear_fit(
581-
x: npt.NDArray[np.float_],
582-
y: npt.NDArray[np.float_],
581+
x: npt.NDArray[np.float64],
582+
y: npt.NDArray[np.float64],
583583
m: int,
584-
norm: npt.NDArray[np.float_],
585-
) -> Tuple[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]]:
584+
norm: npt.NDArray[np.float64],
585+
) -> Tuple[npt.NDArray[np.float64], List[npt.NDArray[np.float64]]]:
586586
n = len(x)
587587
y = y / m + np.expand_dims(norm, axis=1)
588588

@@ -666,17 +666,17 @@ def update(self):
666666

667667
@njit
668668
def inverse_cdf(
669-
single_uniform: npt.NDArray[np.float_], normalized_weights: npt.NDArray[np.float_]
669+
single_uniform: npt.NDArray[np.float64], normalized_weights: npt.NDArray[np.float64]
670670
) -> npt.NDArray[np.int_]:
671671
"""
672672
Inverse CDF algorithm for a finite distribution.
673673
674674
Parameters
675675
----------
676-
single_uniform: npt.NDArray[np.float_]
676+
single_uniform: npt.NDArray[np.float64]
677677
Ordered points in [0,1]
678678
679-
normalized_weights: npt.NDArray[np.float_])
679+
normalized_weights: npt.NDArray[np.float64])
680680
Normalized weights
681681
682682
Returns
@@ -699,7 +699,7 @@ def inverse_cdf(
699699

700700

701701
@njit
702-
def jitter_duplicated(array: npt.NDArray[np.float_], std: float) -> npt.NDArray[np.float_]:
702+
def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray[np.float64]:
703703
"""
704704
Jitter duplicated values.
705705
"""
@@ -715,7 +715,7 @@ def jitter_duplicated(array: npt.NDArray[np.float_], std: float) -> npt.NDArray[
715715

716716

717717
@njit
718-
def are_whole_number(array: npt.NDArray[np.float_]) -> np.bool_:
718+
def are_whole_number(array: npt.NDArray[np.float64]) -> np.bool_:
719719
"""Check if all values in array are whole numbers"""
720720
return np.all(np.mod(array[~np.isnan(array)], 1) == 0)
721721

pymc_bart/tree.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class Node:
2727
2828
Attributes
2929
----------
30-
value : npt.NDArray[np.float_]
30+
value : npt.NDArray[np.float64]
3131
idx_data_points : Optional[npt.NDArray[np.int_]]
3232
idx_split_variable : int
3333
linear_params: Optional[List[float]] = None
@@ -37,11 +37,11 @@ class Node:
3737

3838
def __init__(
3939
self,
40-
value: npt.NDArray[np.float_] = np.array([-1.0]),
40+
value: npt.NDArray[np.float64] = np.array([-1.0]),
4141
nvalue: int = 0,
4242
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
4343
idx_split_variable: int = -1,
44-
linear_params: Optional[List[npt.NDArray[np.float_]]] = None,
44+
linear_params: Optional[List[npt.NDArray[np.float64]]] = None,
4545
) -> None:
4646
self.value = value
4747
self.nvalue = nvalue
@@ -52,11 +52,11 @@ def __init__(
5252
@classmethod
5353
def new_leaf_node(
5454
cls,
55-
value: npt.NDArray[np.float_],
55+
value: npt.NDArray[np.float64],
5656
nvalue: int = 0,
5757
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
5858
idx_split_variable: int = -1,
59-
linear_params: Optional[List[npt.NDArray[np.float_]]] = None,
59+
linear_params: Optional[List[npt.NDArray[np.float64]]] = None,
6060
) -> "Node":
6161
return cls(
6262
value=value,
@@ -100,7 +100,7 @@ class Tree:
100100
The dictionary's keys are integers that represent the nodes position.
101101
The dictionary's values are objects of type Node that represent the split and leaf nodes
102102
of the tree itself.
103-
output: Optional[npt.NDArray[np.float_]]
103+
output: Optional[npt.NDArray[np.float64]]
104104
Array of shape number of observations, shape
105105
split_rules : List[SplitRule]
106106
List of SplitRule objects, one per column in input data.
@@ -121,7 +121,7 @@ class Tree:
121121
def __init__(
122122
self,
123123
tree_structure: Dict[int, Node],
124-
output: npt.NDArray[np.float_],
124+
output: npt.NDArray[np.float64],
125125
split_rules: List[SplitRule],
126126
idx_leaf_nodes: Optional[List[int]] = None,
127127
) -> None:
@@ -133,7 +133,7 @@ def __init__(
133133
@classmethod
134134
def new_tree(
135135
cls,
136-
leaf_node_value: npt.NDArray[np.float_],
136+
leaf_node_value: npt.NDArray[np.float64],
137137
idx_data_points: Optional[npt.NDArray[np.int_]],
138138
num_observations: int,
139139
shape: int,
@@ -189,7 +189,7 @@ def grow_leaf_node(
189189
self,
190190
current_node: Node,
191191
selected_predictor: int,
192-
split_value: npt.NDArray[np.float_],
192+
split_value: npt.NDArray[np.float64],
193193
index_leaf_node: int,
194194
) -> None:
195195
current_node.value = split_value
@@ -221,7 +221,7 @@ def get_split_variables(self) -> Generator[int, None, None]:
221221
if node.is_split_node():
222222
yield node.idx_split_variable
223223

224-
def _predict(self) -> npt.NDArray[np.float_]:
224+
def _predict(self) -> npt.NDArray[np.float64]:
225225
output = self.output
226226

227227
if self.idx_leaf_nodes is not None:
@@ -232,23 +232,23 @@ def _predict(self) -> npt.NDArray[np.float_]:
232232

233233
def predict(
234234
self,
235-
x: npt.NDArray[np.float_],
235+
x: npt.NDArray[np.float64],
236236
excluded: Optional[List[int]] = None,
237237
shape: int = 1,
238-
) -> npt.NDArray[np.float_]:
238+
) -> npt.NDArray[np.float64]:
239239
"""
240240
Predict output of tree for an (un)observed point x.
241241
242242
Parameters
243243
----------
244-
x : npt.NDArray[np.float_]
244+
x : npt.NDArray[np.float64]
245245
Unobserved point
246246
excluded: Optional[List[int]]
247247
Indexes of the variables to exclude when computing predictions
248248
249249
Returns
250250
-------
251-
npt.NDArray[np.float_]
251+
npt.NDArray[np.float64]
252252
Value of the leaf value where the unobserved point lies.
253253
"""
254254
if excluded is None:
@@ -258,16 +258,16 @@ def predict(
258258

259259
def _traverse_tree(
260260
self,
261-
X: npt.NDArray[np.float_],
261+
X: npt.NDArray[np.float64],
262262
excluded: Optional[List[int]] = None,
263263
shape: Union[int, Tuple[int, ...]] = 1,
264-
) -> npt.NDArray[np.float_]:
264+
) -> npt.NDArray[np.float64]:
265265
"""
266266
Traverse the tree starting from the root node given an (un)observed point.
267267
268268
Parameters
269269
----------
270-
X : npt.NDArray[np.float_]
270+
X : npt.NDArray[np.float64]
271271
(Un)observed point(s)
272272
node_index : int
273273
Index of the node to start the traversal from
@@ -278,7 +278,7 @@ def _traverse_tree(
278278
279279
Returns
280280
-------
281-
npt.NDArray[np.float_]
281+
npt.NDArray[np.float64]
282282
Leaf node value or mean of leaf node values
283283
"""
284284

@@ -327,14 +327,14 @@ def _traverse_tree(
327327
return p_d
328328

329329
def _traverse_leaf_values(
330-
self, leaf_values: List[npt.NDArray[np.float_]], leaf_n_values: List[int], node_index: int
330+
self, leaf_values: List[npt.NDArray[np.float64]], leaf_n_values: List[int], node_index: int
331331
) -> None:
332332
"""
333333
Traverse the tree appending leaf values starting from a particular node.
334334
335335
Parameters
336336
----------
337-
leaf_values : List[npt.NDArray[np.float_]]
337+
leaf_values : List[npt.NDArray[np.float64]]
338338
node_index : int
339339
"""
340340
node = self.get_node(node_index)

0 commit comments

Comments
 (0)