Skip to content

Commit 05677cf

Browse files
committed
Add fix to Trainer
1 parent 839e7f7 commit 05677cf

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

bayesflow/trainers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -765,5 +765,9 @@ def _forward_inference(self, n_sim, n_obs, summarize=True, **kwargs):
765765
if summarize and self.summary_stats is not None:
766766
# Return shape in this case is (batch_size, n_sum)
767767
sim_data = self.summary_stats(sim_data)
768-
769-
return params.astype(np.float32), sim_data.astype(np.float32)
768+
769+
if type(params) is np.ndarray:
770+
params = params.astype(np.float32)
771+
if type(sim_data) is np.ndarray:
772+
sim_data = sim_data.astype(np.float32)
773+
return params, sim_data

0 commit comments

Comments
 (0)