Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ ci:

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.3
rev: v0.8.4
hooks:
- id: ruff
args: ["--fix", "--output-format=full"]
- id: ruff-format
args: ["--line-length=100"]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.13.0
rev: v1.14.0
hooks:
- id: mypy
args: [--ignore-missing-imports]
Expand Down
15 changes: 0 additions & 15 deletions mypy.ini

This file was deleted.

10 changes: 5 additions & 5 deletions pymc_bart/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def _update(
m_2 += delta * delta2

std = (m_2 / count) ** 0.5
return mean, m_2, std
return mean.astype(np.float64), m_2.astype(np.float64), std.astype(np.float64)


class SampleSplittingVariable:
Expand Down Expand Up @@ -556,12 +556,12 @@ def draw_leaf_value(
) -> tuple[npt.NDArray[np.float64], Optional[npt.NDArray[np.float64]]]:
"""Draw Gaussian distributed leaf values."""
linear_params = None
mu_mean = np.empty(shape)
mu_mean: npt.NDArray[np.float64]
if y_mu_pred.size == 0:
return np.zeros(shape), linear_params

if y_mu_pred.size == 1:
mu_mean = np.full(shape, y_mu_pred.item() / m) + norm
mu_mean = (np.full(shape, y_mu_pred.item() / m) + norm).astype(np.float64)
elif y_mu_pred.size < 3 or response == "constant":
mu_mean = fast_mean(y_mu_pred) / m + norm
else:
Expand All @@ -585,7 +585,7 @@ def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float
for j in range(ari.shape[0]):
for i in range(count):
res[j] += ari[j, i]
return res / count
return (res / count).astype(np.float64)


@njit
Expand All @@ -596,7 +596,7 @@ def fast_linear_fit(
norm: npt.NDArray[np.float64],
) -> tuple[npt.NDArray[np.float64], list[npt.NDArray[np.float64]]]:
n = len(x)
y = y / m + np.expand_dims(norm, axis=1)
y = (y / m + np.expand_dims(norm, axis=1)).astype(np.float64)

xbar = np.sum(x) / n
ybar = np.sum(y, axis=1) / n
Expand Down
14 changes: 12 additions & 2 deletions pymc_bart/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,19 @@ def _traverse_tree(
)
if excluded is not None and idx_split_variable in excluded:
prop_nvalue_left = self.get_node(left_node_index).nvalue / node.nvalue
stack.append((left_node_index, weights * prop_nvalue_left, idx_split_variable))
stack.append(
(right_node_index, weights * (1 - prop_nvalue_left), idx_split_variable)
(
left_node_index,
(weights * prop_nvalue_left).astype(np.float64),
idx_split_variable,
)
)
stack.append(
(
right_node_index,
(weights * (1 - prop_nvalue_left)).astype(np.float64),
idx_split_variable,
)
)
else:
to_left = (
Expand Down
8 changes: 4 additions & 4 deletions pymc_bart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,9 +826,9 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
else:
labels = np.arange(n_vars).astype(str)

r2_mean = np.zeros(n_vars)
r2_hdi = np.zeros((n_vars, 2))
preds = np.zeros((n_vars, samples, *bartrv.eval().T.shape))
r2_mean: npt.NDArray[np.float64] = np.zeros(n_vars)
r2_hdi: npt.NDArray[np.float64] = np.zeros((n_vars, 2))
preds: npt.NDArray[np.float64] = np.zeros((n_vars, samples, *bartrv.eval().T.shape))

if method == "backward_VI":
if fixed >= n_vars:
Expand All @@ -848,7 +848,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
idxs = np.argsort(
idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
)
subsets = [idxs[:-i].tolist() for i in range(1, len(idxs))]
subsets: list[list[int]] = [list(idxs[:-i]) for i in range(1, len(idxs))]
subsets.append(None) # type: ignore

if method == "backward_VI":
Expand Down
17 changes: 17 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,20 @@ exclude_lines = [
isort = 1
black = 1
pyupgrade = 1


[tool.mypy]
files = "pymc_bart/*.py"
plugins = "numpy.typing.mypy_plugin"

[tool.mypy-matplotlib]
ignore_missing_imports = true

[tool.mypy-numba]
ignore_missing_imports = true

[tool.mypy-pymc]
ignore_missing_imports = true

[tool.mypy-scipy]
ignore_missing_imports = true
Comment on lines +38 to +52
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This way we can remove mypy.ini

Loading