1414class GraphicalSimulator (Simulator ):
1515 """
1616 A graph-based simulator that generates samples by traversing a DAG
17- and calling user-defined sampling functions at each node.
17+ and calling user-defined sample functions at each node.
1818
1919 Parameters
2020 ----------
2121 meta_fn : Optional[Callable[[], dict[str, Any]]]
2222 A callable that returns a dictionary of meta data.
23- This meta data can be used to dynamically vary the number of sampling repetitions (`reps`)
23+ This meta data can be used to dynamically vary the number of sample repetitions (`reps`)
2424 for nodes added via `add_node`.
2525 """
2626
@@ -39,7 +39,7 @@ def add_edge(self, from_node: str, to_node: str):
3939 def sample (self , batch_shape : Shape | int , ** kwargs ) -> dict [str , np .ndarray ]:
4040 """
4141 Generates samples by topologically traversing the DAG.
42- For each node, the sampling function is called based on parent values.
42+ For each node, the sample function is called based on parent values.
4343
4444 Parameters
4545 ----------
@@ -57,19 +57,21 @@ def sample(self, batch_shape: Shape | int, **kwargs) -> dict[str, np.ndarray]:
5757 for node in self .graph .nodes :
5858 samples_by_node [node ] = np .empty (batch_shape , dtype = "object" )
5959
60+ ordered_nodes = list (nx .topological_sort (self .graph ))
61+
6062 for batch_idx in np .ndindex (batch_shape ):
61- for node in nx . topological_sort ( self . graph ) :
63+ for node in ordered_nodes :
6264 node_samples = []
6365
6466 parent_nodes = list (self .graph .predecessors (node ))
65- sampling_fn = self .graph .nodes [node ]["sample_fn" ]
67+ sample_fn = self .graph .nodes [node ]["sample_fn" ]
6668 reps_field = self .graph .nodes [node ]["reps" ]
6769 reps = reps_field if isinstance (reps_field , int ) else meta_dict [reps_field ]
6870
6971 if not parent_nodes :
7072 # root node: generate independent samples
7173 node_samples = [
72- {"__batch_idx" : batch_idx , f"__{ node } _idx" : i } | self ._call_sample_fn (sampling_fn , {})
74+ {"__batch_idx" : batch_idx , f"__{ node } _idx" : i } | self ._call_sample_fn (sample_fn , {})
7375 for i in range (reps )
7476 ]
7577 else :
@@ -81,15 +83,12 @@ def sample(self, batch_shape: Shape | int, **kwargs) -> dict[str, np.ndarray]:
8183 index_entries = {k : v for k , v in merged .items () if k .startswith ("__" )}
8284 variable_entries = {k : v for k , v in merged .items () if not k .startswith ("__" )}
8385
84- sampling_fn_input = variable_entries | meta_dict
85- node_samples .extend (
86- [
87- index_entries
88- | {f"__{ node } _idx" : i }
89- | self ._call_sample_fn (sampling_fn , sampling_fn_input )
90- for i in range (reps )
91- ]
92- )
86+ sample_fn_input = variable_entries | meta_dict
87+ samples = [
88+ index_entries | {f"__{ node } _idx" : i } | self ._call_sample_fn (sample_fn , sample_fn_input )
89+ for i in range (reps )
90+ ]
91+ node_samples .extend (samples )
9392
9493 samples_by_node [node ][batch_idx ] = node_samples
9594
0 commit comments