@@ -70,7 +70,7 @@ def sample(self, batch_shape: Shape | int, **kwargs) -> dict[str, np.ndarray]:
7070 # root node: generate independent samples
7171 node_samples = [
7272 {"__batch_idx" : batch_idx , f"__{ node } _idx" : i } | self ._call_sample_fn (sampling_fn , {})
73- for i in range (1 , reps + 1 )
73+ for i in range (reps )
7474 ]
7575 else :
7676 # non-root node: depends on parent samples
@@ -87,7 +87,7 @@ def sample(self, batch_shape: Shape | int, **kwargs) -> dict[str, np.ndarray]:
8787 index_entries
8888 | {f"__{ node } _idx" : i }
8989 | self ._call_sample_fn (sampling_fn , sampling_fn_input )
90- for i in range (1 , reps + 1 )
90+ for i in range (reps )
9191 ]
9292 )
9393
@@ -113,8 +113,8 @@ def _collect_output(self, samples):
113113 # build dict of node repetitions
114114 reps = {}
115115 for ancestor in ancestors :
116- reps [ancestor ] = max (s [f"__{ ancestor } _idx" ] for s in samples .flat [0 ])
117- reps [node ] = max (s [f"__{ node } _idx" ] for s in samples .flat [0 ])
116+ reps [ancestor ] = max (s [f"__{ ancestor } _idx" ] for s in samples .flat [0 ]) + 1
117+ reps [node ] = max (s [f"__{ node } _idx" ] for s in samples .flat [0 ]) + 1
118118
119119 variable_names = self ._variable_names (samples )
120120
@@ -130,11 +130,11 @@ def _collect_output(self, samples):
130130 # add index elements for ancestors
131131 for ancestor in ancestors :
132132 if reps [ancestor ] != 1 :
133- idx .append (sample [f"__{ ancestor } _idx" ] - 1 ) # -1 for 0-based indexing
133+ idx .append (sample [f"__{ ancestor } _idx" ])
134134
135135 # add index elements for node
136136 if reps [node ] != 1 :
137- idx .append (sample [f"__{ node } _idx" ] - 1 ) # -1 for 0-based indexing
137+ idx .append (sample [f"__{ node } _idx" ])
138138
139139 output_dict [variable ][tuple (idx )] = sample [variable ]
140140
@@ -154,12 +154,12 @@ def _output_shape(self, samples, variable):
154154
155155 # add ancestor reps
156156 for ancestor in ancestors :
157- node_reps = max (s [f"__{ ancestor } _idx" ] for s in samples .flat [0 ])
157+ node_reps = max (s [f"__{ ancestor } _idx" ] for s in samples .flat [0 ]) + 1
158158 if node_reps != 1 :
159159 output_shape .append (node_reps )
160160
161161 # add node reps
162- node_reps = max (s [f"__{ node } _idx" ] for s in samples .flat [0 ])
162+ node_reps = max (s [f"__{ node } _idx" ] for s in samples .flat [0 ]) + 1
163163 if node_reps != 1 :
164164 output_shape .append (node_reps )
165165
0 commit comments