Skip to content

Commit 6f70cd2

Browse files
rename sampling_fn variables to sample_fn, precompute topological sort, remove nesting in node_samples computation
1 parent c5044c1 commit 6f70cd2

File tree

1 file changed

+14
-15
lines changed

1 file changed

+14
-15
lines changed

bayesflow/experimental/graphical_simulator/graphical_simulator.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414
class GraphicalSimulator(Simulator):
1515
"""
1616
A graph-based simulator that generates samples by traversing a DAG
17-
and calling user-defined sampling functions at each node.
17+
and calling user-defined sample functions at each node.
1818
1919
Parameters
2020
----------
2121
meta_fn : Optional[Callable[[], dict[str, Any]]]
2222
A callable that returns a dictionary of meta data.
23-
This meta data can be used to dynamically vary the number of sampling repetitions (`reps`)
23+
This meta data can be used to dynamically vary the number of sample repetitions (`reps`)
2424
for nodes added via `add_node`.
2525
"""
2626

@@ -39,7 +39,7 @@ def add_edge(self, from_node: str, to_node: str):
3939
def sample(self, batch_shape: Shape | int, **kwargs) -> dict[str, np.ndarray]:
4040
"""
4141
Generates samples by topologically traversing the DAG.
42-
For each node, the sampling function is called based on parent values.
42+
For each node, the sample function is called based on parent values.
4343
4444
Parameters
4545
----------
@@ -57,19 +57,21 @@ def sample(self, batch_shape: Shape | int, **kwargs) -> dict[str, np.ndarray]:
5757
for node in self.graph.nodes:
5858
samples_by_node[node] = np.empty(batch_shape, dtype="object")
5959

60+
ordered_nodes = list(nx.topological_sort(self.graph))
61+
6062
for batch_idx in np.ndindex(batch_shape):
61-
for node in nx.topological_sort(self.graph):
63+
for node in ordered_nodes:
6264
node_samples = []
6365

6466
parent_nodes = list(self.graph.predecessors(node))
65-
sampling_fn = self.graph.nodes[node]["sample_fn"]
67+
sample_fn = self.graph.nodes[node]["sample_fn"]
6668
reps_field = self.graph.nodes[node]["reps"]
6769
reps = reps_field if isinstance(reps_field, int) else meta_dict[reps_field]
6870

6971
if not parent_nodes:
7072
# root node: generate independent samples
7173
node_samples = [
72-
{"__batch_idx": batch_idx, f"__{node}_idx": i} | self._call_sample_fn(sampling_fn, {})
74+
{"__batch_idx": batch_idx, f"__{node}_idx": i} | self._call_sample_fn(sample_fn, {})
7375
for i in range(reps)
7476
]
7577
else:
@@ -81,15 +83,12 @@ def sample(self, batch_shape: Shape | int, **kwargs) -> dict[str, np.ndarray]:
8183
index_entries = {k: v for k, v in merged.items() if k.startswith("__")}
8284
variable_entries = {k: v for k, v in merged.items() if not k.startswith("__")}
8385

84-
sampling_fn_input = variable_entries | meta_dict
85-
node_samples.extend(
86-
[
87-
index_entries
88-
| {f"__{node}_idx": i}
89-
| self._call_sample_fn(sampling_fn, sampling_fn_input)
90-
for i in range(reps)
91-
]
92-
)
86+
sample_fn_input = variable_entries | meta_dict
87+
samples = [
88+
index_entries | {f"__{node}_idx": i} | self._call_sample_fn(sample_fn, sample_fn_input)
89+
for i in range(reps)
90+
]
91+
node_samples.extend(samples)
9392

9493
samples_by_node[node][batch_idx] = node_samples
9594

0 commit comments

Comments
 (0)