Skip to content

Commit c5044c1

Browse files
use 0-based index for internal representation
1 parent a3c1fb6 commit c5044c1

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

bayesflow/experimental/graphical_simulator/graphical_simulator.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def sample(self, batch_shape: Shape | int, **kwargs) -> dict[str, np.ndarray]:
7070
# root node: generate independent samples
7171
node_samples = [
7272
{"__batch_idx": batch_idx, f"__{node}_idx": i} | self._call_sample_fn(sampling_fn, {})
73-
for i in range(1, reps + 1)
73+
for i in range(reps)
7474
]
7575
else:
7676
# non-root node: depends on parent samples
@@ -87,7 +87,7 @@ def sample(self, batch_shape: Shape | int, **kwargs) -> dict[str, np.ndarray]:
8787
index_entries
8888
| {f"__{node}_idx": i}
8989
| self._call_sample_fn(sampling_fn, sampling_fn_input)
90-
for i in range(1, reps + 1)
90+
for i in range(reps)
9191
]
9292
)
9393

@@ -113,8 +113,8 @@ def _collect_output(self, samples):
113113
# build dict of node repetitions
114114
reps = {}
115115
for ancestor in ancestors:
116-
reps[ancestor] = max(s[f"__{ancestor}_idx"] for s in samples.flat[0])
117-
reps[node] = max(s[f"__{node}_idx"] for s in samples.flat[0])
116+
reps[ancestor] = max(s[f"__{ancestor}_idx"] for s in samples.flat[0]) + 1
117+
reps[node] = max(s[f"__{node}_idx"] for s in samples.flat[0]) + 1
118118

119119
variable_names = self._variable_names(samples)
120120

@@ -130,11 +130,11 @@ def _collect_output(self, samples):
130130
# add index elements for ancestors
131131
for ancestor in ancestors:
132132
if reps[ancestor] != 1:
133-
idx.append(sample[f"__{ancestor}_idx"] - 1) # -1 for 0-based indexing
133+
idx.append(sample[f"__{ancestor}_idx"])
134134

135135
# add index elements for node
136136
if reps[node] != 1:
137-
idx.append(sample[f"__{node}_idx"] - 1) # -1 for 0-based indexing
137+
idx.append(sample[f"__{node}_idx"])
138138

139139
output_dict[variable][tuple(idx)] = sample[variable]
140140

@@ -154,12 +154,12 @@ def _output_shape(self, samples, variable):
154154

155155
# add ancestor reps
156156
for ancestor in ancestors:
157-
node_reps = max(s[f"__{ancestor}_idx"] for s in samples.flat[0])
157+
node_reps = max(s[f"__{ancestor}_idx"] for s in samples.flat[0]) + 1
158158
if node_reps != 1:
159159
output_shape.append(node_reps)
160160

161161
# add node reps
162-
node_reps = max(s[f"__{node}_idx"] for s in samples.flat[0])
162+
node_reps = max(s[f"__{node}_idx"] for s in samples.flat[0]) + 1
163163
if node_reps != 1:
164164
output_shape.append(node_reps)
165165

0 commit comments

Comments
 (0)