Skip to content

Commit 1afd3cc

Browse files
authored
use ruff for linting (#291)
1 parent 6a6903b commit 1afd3cc

29 files changed

+79
-108
lines changed

.pre-commit-config.yaml

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,18 @@
11
exclude: |
22
(?x)^(
3-
setup.py|
43
docs/.*|
54
)$
5+
66
repos:
7-
- repo: https://github.com/pre-commit/mirrors-isort
8-
rev: v5.10.1
9-
hooks:
10-
- id: isort
11-
- repo: https://github.com/ambv/black
12-
rev: 25.1.0
13-
hooks:
14-
- id: black
15-
language_version: python3
16-
- repo: https://github.com/pycqa/flake8
17-
rev: 7.2.0
18-
hooks:
19-
- id: flake8
20-
- repo: https://github.com/pycqa/pydocstyle
21-
rev: 6.3.0
22-
hooks:
23-
- id: pydocstyle
24-
args: ['--ignore', 'D213,D100,D203,D104']
25-
files: ^pyhgf/
7+
- repo: https://github.com/astral-sh/ruff-pre-commit
8+
# Ruff version.
9+
rev: v0.11.4
10+
hooks:
11+
# Run the linter.
12+
- id: ruff
13+
args: [ --fix ]
14+
# Run the formatter.
15+
- id: ruff-format
2616
- repo: https://github.com/pre-commit/mirrors-mypy
2717
rev: 'v1.15.0'
2818
hooks:

pyhgf/distribution.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,6 @@ def logp(
133133
134134
"""
135135
if hgf.model_type == "continuous":
136-
137136
# update this network's attributes
138137
hgf.attributes[0]["precision"] = input_precision
139138

@@ -161,7 +160,6 @@ def logp(
161160
hgf.attributes[3]["volatility_coupling_children"] = (volatility_coupling_2,)
162161

163162
elif hgf.model_type == "binary":
164-
165163
# update this network's attributes
166164
hgf.attributes[0]["mean"] = mean_1
167165
hgf.attributes[1]["mean"] = mean_2
@@ -486,23 +484,26 @@ def perform(
486484
):
487485
"""Perform node operations."""
488486
(
489-
grad_mean_1,
490-
grad_mean_2,
491-
grad_mean_3,
492-
grad_precision_1,
493-
grad_precision_2,
494-
grad_precision_3,
495-
grad_tonic_volatility_1,
496-
grad_tonic_volatility_2,
497-
grad_tonic_volatility_3,
498-
grad_tonic_drift_1,
499-
grad_tonic_drift_2,
500-
grad_tonic_drift_3,
501-
grad_volatility_coupling_1,
502-
grad_volatility_coupling_2,
503-
grad_input_precision,
504-
grad_response_function_parameters,
505-
), _ = self.grad_logp(*inputs)
487+
(
488+
grad_mean_1,
489+
grad_mean_2,
490+
grad_mean_3,
491+
grad_precision_1,
492+
grad_precision_2,
493+
grad_precision_3,
494+
grad_tonic_volatility_1,
495+
grad_tonic_volatility_2,
496+
grad_tonic_volatility_3,
497+
grad_tonic_drift_1,
498+
grad_tonic_drift_2,
499+
grad_tonic_drift_3,
500+
grad_volatility_coupling_1,
501+
grad_volatility_coupling_2,
502+
grad_input_precision,
503+
grad_response_function_parameters,
504+
),
505+
_,
506+
) = self.grad_logp(*inputs)
506507

507508
outputs[0][0] = np.asarray(grad_mean_1, dtype=node.outputs[0].dtype)
508509
outputs[1][0] = np.asarray(grad_mean_2, dtype=node.outputs[1].dtype)

pyhgf/model/add_nodes.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,6 @@ def add_ef_state(
159159

160160
# loop over the indexes of nodes created in the previous step
161161
for node_idx in range(network.n_nodes - 1, network.n_nodes - n_nodes - 1, -1):
162-
163162
# create the sufficient statistic function and store in the side parameters
164163
if network.attributes[node_idx]["distribution"] == "normal":
165164
sufficient_stats_fn = Normal().sufficient_statistics_from_observations
@@ -178,7 +177,6 @@ def add_ef_state(
178177
] = sufficient_stats_fn
179178

180179
if "hgf" in network.attributes[node_idx]["learning"]:
181-
182180
# create a collection of continuous state nodes
183181
# to track the sufficient statistics of the implied distribution
184182
for i in range(n_suff_stats):

pyhgf/model/hgf.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,6 @@ def __init__(
156156
)
157157

158158
elif model_type == "binary":
159-
160159
if binary_precision == jnp.inf:
161160
# X - 0
162161
self.add_nodes(

pyhgf/model/network.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,6 @@ def create_belief_propagation_fn(
142142

143143
# Create the generative scan function if it doesn't exist, and if requested.
144144
if (self.scan_fn_sample is None) and sampling_fn:
145-
146145
self.sample_scan_fn = Partial(
147146
beliefs_propagation,
148147
update_sequence=self.update_sequence,

pyhgf/plots/graphviz/plot_network.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ def plot_network(network: "Network") -> "Source":
3838

3939
# create the rest of nodes
4040
for idx in range(len(network.edges)):
41-
4241
style = "filled" if idx in network.input_idxs else ""
4342

4443
if network.edges[idx].node_type == 1:
@@ -89,7 +88,6 @@ def plot_network(network: "Network") -> "Source":
8988

9089
if value_parents is not None:
9190
for value_parents_idx in value_parents:
92-
9391
# get the coupling function from the value parent
9492
child_idx = network.edges[value_parents_idx].value_children.index(i)
9593
coupling_fn = network.edges[value_parents_idx].coupling_fn[child_idx]

pyhgf/plots/matplotlib/plot_nodes.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,9 @@ def plot_nodes(
148148
# plotting standard deviation - in the case of a binary input node, the
149149
# CI should be read from the value parent using the sigmoid transform
150150
if ci is True:
151-
152151
# get parent nodes and sum predictions
153152
mean_parent, precision_parent = 0.0, 0.0
154153
for idx in network.edges[node_idx].value_parents: # type: ignore
155-
156154
# compute mu +/- sd at time t-1
157155
# and use the sigmoid transform before plotting
158156
mean_parent += trajectories_df[f"x_{idx}_expected_mean"]
@@ -164,7 +162,6 @@ def plot_nodes(
164162
y2 = 1 / (1 + np.exp(-mean_parent - sd))
165163

166164
if ci is True:
167-
168165
axs[i].fill_between(
169166
x=trajectories_df["time"],
170167
y1=y1,
@@ -180,7 +177,6 @@ def plot_nodes(
180177
# plotting state nodes
181178
# --------------------
182179
else:
183-
184180
axs[i].set_title(
185181
f"State Node {node_idx}",
186182
loc="left",

pyhgf/plots/matplotlib/plot_trajectories.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from pyhgf.plots.matplotlib.plot_nodes import plot_nodes
99

1010
if TYPE_CHECKING:
11-
1211
from pyhgf.model import Network
1312

1413

@@ -128,7 +127,6 @@ def plot_trajectories(
128127
# --------------
129128
ax_i = n_nodes - 1
130129
for node_idx in range(n_nodes):
131-
132130
if node_idx in network.input_idxs:
133131
_show_posterior = True
134132
color = "#4c72b0"

pyhgf/plots/networkx/plot_network.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import matplotlib.pyplot as plt
77

88
if TYPE_CHECKING:
9-
109
from pyhgf.model import Network
1110

1211

pyhgf/updates/posterior/continuous/posterior_update_mean_continuous_node.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -147,13 +147,11 @@ def posterior_update_mean_continuous_node(
147147
# sum the precision weigthed prediction errors over all children
148148
value_precision_weigthed_prediction_error += (
149149
(
150-
(
151-
value_coupling
152-
* attributes[value_child_idx]["expected_precision"]
153-
* coupling_fn_prime
154-
)
155-
/ node_precision
150+
value_coupling
151+
* attributes[value_child_idx]["expected_precision"]
152+
* coupling_fn_prime
156153
)
154+
/ node_precision
157155
) * value_prediction_error
158156

159157
# Volatility coupling updates - update the mean of a volatility parent

0 commit comments

Comments
 (0)