Skip to content

Commit 9ec4de8

Browse files
authored
lint updates (#199)
* lint updates * use built-in types
1 parent 40f1220 commit 9ec4de8

File tree

6 files changed

+101
-101
lines changed

6 files changed

+101
-101
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ ci:
1212

1313
repos:
1414
- repo: https://github.com/astral-sh/ruff-pre-commit
15-
rev: v0.7.4
15+
rev: v0.8.0
1616
hooks:
1717
- id: ruff
1818
args: ["--fix", "--output-format=full"]

pymc_bart/bart.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import warnings
1818
from multiprocessing import Manager
19-
from typing import List, Optional, Tuple
19+
from typing import Optional
2020

2121
import numpy as np
2222
import numpy.typing as npt
@@ -39,8 +39,8 @@ class BARTRV(RandomVariable):
3939
name: str = "BART"
4040
signature = "(m,n),(m),(),(),() -> (m)"
4141
dtype: str = "floatX"
42-
_print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}")
43-
all_trees = List[List[List[Tree]]]
42+
_print_name: tuple[str, str] = ("BART", "\\operatorname{BART}")
43+
all_trees = list[list[list[Tree]]]
4444

4545
def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): # pylint: disable=arguments-renamed
4646
idx = dist_params[0].ndim - 2
@@ -92,10 +92,10 @@ class BART(Distribution):
9292
beta : float
9393
Controls the prior probability over the number of leaves of the trees.
9494
Should be positive.
95-
split_prior : Optional[List[float]], default None.
95+
split_prior : Optional[list[float]], default None.
9696
List of positive numbers, one per column in input data.
9797
Defaults to None, all covariates have the same prior probability to be selected.
98-
split_rules : Optional[List[SplitRule]], default None
98+
split_rules : Optional[list[SplitRule]], default None
9999
List of SplitRule objects, one per column in input data.
100100
Allows using different split rules for different columns. Default is ContinuousSplitRule.
101101
Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables.
@@ -126,7 +126,7 @@ def __new__(
126126
beta: float = 2.0,
127127
response: str = "constant",
128128
split_prior: Optional[npt.NDArray[np.float64]] = None,
129-
split_rules: Optional[List[SplitRule]] = None,
129+
split_rules: Optional[list[SplitRule]] = None,
130130
separate_trees: Optional[bool] = False,
131131
**kwargs,
132132
):
@@ -198,7 +198,7 @@ def get_moment(cls, rv, size, *rv_inputs):
198198

199199
def preprocess_xy(
200200
X: TensorLike, Y: TensorLike
201-
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
201+
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
202202
if isinstance(Y, (Series, DataFrame)):
203203
Y = Y.to_numpy()
204204
if isinstance(X, (Series, DataFrame)):

pymc_bart/pgbart.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import List, Optional, Tuple, Union
15+
from typing import Optional, Union
1616

1717
import numpy as np
1818
import numpy.typing as npt
@@ -43,7 +43,7 @@ class ParticleTree:
4343

4444
def __init__(self, tree: Tree):
4545
self.tree: Tree = tree.copy()
46-
self.expansion_nodes: List[int] = [0]
46+
self.expansion_nodes: list[int] = [0]
4747
self.log_weight: float = 0
4848

4949
def copy(self) -> "ParticleTree":
@@ -123,7 +123,7 @@ def __init__( # noqa: PLR0915
123123
self,
124124
vars=None, # pylint: disable=redefined-builtin
125125
num_particles: int = 10,
126-
batch: Tuple[float, float] = (0.1, 0.1),
126+
batch: tuple[float, float] = (0.1, 0.1),
127127
model: Optional[Model] = None,
128128
):
129129
model = modelcontext(model)
@@ -310,7 +310,7 @@ def astep(self, _):
310310
stats = {"variable_inclusion": variable_inclusion, "tune": self.tune}
311311
return self.sum_trees, [stats]
312312

313-
def normalize(self, particles: List[ParticleTree]) -> float:
313+
def normalize(self, particles: list[ParticleTree]) -> float:
314314
"""
315315
Use softmax to get normalized_weights.
316316
"""
@@ -321,16 +321,16 @@ def normalize(self, particles: List[ParticleTree]) -> float:
321321
return wei / wei.sum()
322322

323323
def resample(
324-
self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float64]
325-
) -> List[ParticleTree]:
324+
self, particles: list[ParticleTree], normalized_weights: npt.NDArray[np.float64]
325+
) -> list[ParticleTree]:
326326
"""
327327
Use systematic resample for all but the first particle
328328
329329
Ensure particles are copied only if needed.
330330
"""
331331
new_indices = self.systematic(normalized_weights) + 1
332-
seen: List[int] = []
333-
new_particles: List[ParticleTree] = []
332+
seen: list[int] = []
333+
new_particles: list[ParticleTree] = []
334334
for idx in new_indices:
335335
if idx in seen:
336336
new_particles.append(particles[idx].copy())
@@ -343,8 +343,8 @@ def resample(
343343
return particles
344344

345345
def get_particle_tree(
346-
self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float64]
347-
) -> Tuple[ParticleTree, Tree]:
346+
self, particles: list[ParticleTree], normalized_weights: npt.NDArray[np.float64]
347+
) -> tuple[ParticleTree, Tree]:
348348
"""
349349
Sample a new particle and associated tree
350350
"""
@@ -367,12 +367,12 @@ def systematic(self, normalized_weights: npt.NDArray[np.float64]) -> npt.NDArray
367367
single_uniform = (self.uniform.rvs() + np.arange(lnw)) / lnw
368368
return inverse_cdf(single_uniform, normalized_weights)
369369

370-
def init_particles(self, tree_id: int, odim: int) -> List[ParticleTree]:
370+
def init_particles(self, tree_id: int, odim: int) -> list[ParticleTree]:
371371
"""Initialize particles."""
372372
p0: ParticleTree = self.all_particles[odim][tree_id]
373373
# The old tree does not grow so we update the weight only once
374374
self.update_weight(p0, odim)
375-
particles: List[ParticleTree] = [p0]
375+
particles: list[ParticleTree] = [p0]
376376

377377
particles.extend(ParticleTree(self.a_tree) for _ in self.indices)
378378
return particles
@@ -419,7 +419,7 @@ def _update(
419419
mean: npt.NDArray[np.float64],
420420
m_2: npt.NDArray[np.float64],
421421
new_value: npt.NDArray[np.float64],
422-
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], Union[float, npt.NDArray[np.float64]]]:
422+
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], Union[float, npt.NDArray[np.float64]]]:
423423
delta = new_value - mean
424424
mean += delta / count
425425
delta2 = new_value - mean
@@ -439,15 +439,15 @@ def __init__(self, alpha_vec: npt.NDArray[np.float64]) -> None:
439439
"""
440440
self.enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum())))
441441

442-
def rvs(self) -> Union[int, Tuple[int, float]]:
442+
def rvs(self) -> Union[int, tuple[int, float]]:
443443
rnd: float = np.random.random()
444444
for i, val in self.enu:
445445
if rnd <= val:
446446
return i
447447
return self.enu[-1]
448448

449449

450-
def compute_prior_probability(alpha: int, beta: int) -> List[float]:
450+
def compute_prior_probability(alpha: int, beta: int) -> list[float]:
451451
"""
452452
Calculate the probability of the node being a leaf node (1 - p(being split node)).
453453
@@ -460,7 +460,7 @@ def compute_prior_probability(alpha: int, beta: int) -> List[float]:
460460
-------
461461
list with probabilities for leaf nodes
462462
"""
463-
prior_leaf_prob: List[float] = [0]
463+
prior_leaf_prob: list[float] = [0]
464464
depth = 0
465465
while prior_leaf_prob[-1] < 0.9999:
466466
prior_leaf_prob.append(1 - (alpha * ((1 + depth) ** (-beta))))
@@ -549,7 +549,7 @@ def draw_leaf_value(
549549
norm: npt.NDArray[np.float64],
550550
shape: int,
551551
response: str,
552-
) -> Tuple[npt.NDArray[np.float64], Optional[npt.NDArray[np.float64]]]:
552+
) -> tuple[npt.NDArray[np.float64], Optional[npt.NDArray[np.float64]]]:
553553
"""Draw Gaussian distributed leaf values."""
554554
linear_params = None
555555
mu_mean = np.empty(shape)
@@ -590,7 +590,7 @@ def fast_linear_fit(
590590
y: npt.NDArray[np.float64],
591591
m: int,
592592
norm: npt.NDArray[np.float64],
593-
) -> Tuple[npt.NDArray[np.float64], List[npt.NDArray[np.float64]]]:
593+
) -> tuple[npt.NDArray[np.float64], list[npt.NDArray[np.float64]]]:
594594
n = len(x)
595595
y = y / m + np.expand_dims(norm, axis=1)
596596

pymc_bart/tree.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from collections.abc import Generator
1516
from functools import lru_cache
16-
from typing import Dict, Generator, List, Optional, Tuple, Union
17+
from typing import Optional, Union
1718

1819
import numpy as np
1920
import numpy.typing as npt
@@ -30,7 +31,7 @@ class Node:
3031
value : npt.NDArray[np.float64]
3132
idx_data_points : Optional[npt.NDArray[np.int_]]
3233
idx_split_variable : int
33-
linear_params: Optional[List[float]] = None
34+
linear_params: Optional[list[float]] = None
3435
"""
3536

3637
__slots__ = "value", "nvalue", "idx_split_variable", "idx_data_points", "linear_params"
@@ -41,7 +42,7 @@ def __init__(
4142
nvalue: int = 0,
4243
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
4344
idx_split_variable: int = -1,
44-
linear_params: Optional[List[npt.NDArray[np.float64]]] = None,
45+
linear_params: Optional[list[npt.NDArray[np.float64]]] = None,
4546
) -> None:
4647
self.value = value
4748
self.nvalue = nvalue
@@ -56,7 +57,7 @@ def new_leaf_node(
5657
nvalue: int = 0,
5758
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
5859
idx_split_variable: int = -1,
59-
linear_params: Optional[List[npt.NDArray[np.float64]]] = None,
60+
linear_params: Optional[list[npt.NDArray[np.float64]]] = None,
6061
) -> "Node":
6162
return cls(
6263
value=value,
@@ -94,19 +95,19 @@ class Tree:
9495
9596
Attributes
9697
----------
97-
tree_structure : Dict[int, Node]
98+
tree_structure : dict[int, Node]
9899
A dictionary that represents the nodes stored in breadth-first order, based in the array
99100
method for storing binary trees (https://en.wikipedia.org/wiki/Binary_tree#Arrays).
100101
The dictionary's keys are integers that represent the nodes position.
101102
The dictionary's values are objects of type Node that represent the split and leaf nodes
102103
of the tree itself.
103104
output: Optional[npt.NDArray[np.float64]]
104105
Array of shape number of observations, shape
105-
split_rules : List[SplitRule]
106+
split_rules : list[SplitRule]
106107
List of SplitRule objects, one per column in input data.
107108
Allows using different split rules for different columns. Default is ContinuousSplitRule.
108109
Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables.
109-
idx_leaf_nodes : Optional[List[int]], by default None.
110+
idx_leaf_nodes : Optional[list[int]], by default None.
110111
Array with the index of the leaf nodes of the tree.
111112
112113
Parameters
@@ -120,10 +121,10 @@ class Tree:
120121

121122
def __init__(
122123
self,
123-
tree_structure: Dict[int, Node],
124+
tree_structure: dict[int, Node],
124125
output: npt.NDArray[np.float64],
125-
split_rules: List[SplitRule],
126-
idx_leaf_nodes: Optional[List[int]] = None,
126+
split_rules: list[SplitRule],
127+
idx_leaf_nodes: Optional[list[int]] = None,
127128
) -> None:
128129
self.tree_structure = tree_structure
129130
self.idx_leaf_nodes = idx_leaf_nodes
@@ -137,7 +138,7 @@ def new_tree(
137138
idx_data_points: Optional[npt.NDArray[np.int_]],
138139
num_observations: int,
139140
shape: int,
140-
split_rules: List[SplitRule],
141+
split_rules: list[SplitRule],
141142
) -> "Tree":
142143
return cls(
143144
tree_structure={
@@ -159,7 +160,7 @@ def __setitem__(self, index, node) -> None:
159160
self.set_node(index, node)
160161

161162
def copy(self) -> "Tree":
162-
tree: Dict[int, Node] = {
163+
tree: dict[int, Node] = {
163164
k: Node(
164165
value=v.value,
165166
nvalue=v.nvalue,
@@ -199,7 +200,7 @@ def grow_leaf_node(
199200
self.idx_leaf_nodes.remove(index_leaf_node)
200201

201202
def trim(self) -> "Tree":
202-
tree: Dict[int, Node] = {
203+
tree: dict[int, Node] = {
203204
k: Node(
204205
value=v.value,
205206
nvalue=v.nvalue,
@@ -233,7 +234,7 @@ def _predict(self) -> npt.NDArray[np.float64]:
233234
def predict(
234235
self,
235236
x: npt.NDArray[np.float64],
236-
excluded: Optional[List[int]] = None,
237+
excluded: Optional[list[int]] = None,
237238
shape: int = 1,
238239
) -> npt.NDArray[np.float64]:
239240
"""
@@ -243,7 +244,7 @@ def predict(
243244
----------
244245
x : npt.NDArray[np.float64]
245246
Unobserved point
246-
excluded: Optional[List[int]]
247+
excluded: Optional[list[int]]
247248
Indexes of the variables to exclude when computing predictions
248249
249250
Returns
@@ -259,8 +260,8 @@ def predict(
259260
def _traverse_tree(
260261
self,
261262
X: npt.NDArray[np.float64],
262-
excluded: Optional[List[int]] = None,
263-
shape: Union[int, Tuple[int, ...]] = 1,
263+
excluded: Optional[list[int]] = None,
264+
shape: Union[int, tuple[int, ...]] = 1,
264265
) -> npt.NDArray[np.float64]:
265266
"""
266267
Traverse the tree starting from the root node given an (un)observed point.
@@ -273,7 +274,7 @@ def _traverse_tree(
273274
Index of the node to start the traversal from
274275
split_variable : int
275276
Index of the variable used to split the node
276-
excluded: Optional[List[int]]
277+
excluded: Optional[list[int]]
277278
Indexes of the variables to exclude when computing predictions
278279
279280
Returns
@@ -327,14 +328,14 @@ def _traverse_tree(
327328
return p_d
328329

329330
def _traverse_leaf_values(
330-
self, leaf_values: List[npt.NDArray[np.float64]], leaf_n_values: List[int], node_index: int
331+
self, leaf_values: list[npt.NDArray[np.float64]], leaf_n_values: list[int], node_index: int
331332
) -> None:
332333
"""
333334
Traverse the tree appending leaf values starting from a particular node.
334335
335336
Parameters
336337
----------
337-
leaf_values : List[npt.NDArray[np.float64]]
338+
leaf_values : list[npt.NDArray[np.float64]]
338339
node_index : int
339340
"""
340341
node = self.get_node(node_index)

0 commit comments

Comments
 (0)