File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change 2121import logging
2222import os
2323from pickle import load as pickle_load
24- import tensorflow as tf
2524
2625import numpy as np
26+ import tensorflow as tf
2727from tqdm .autonotebook import tqdm
2828
2929from bayesflow .amortizers import (
@@ -737,7 +737,10 @@ def train_from_presimulation(
737737 input_dict = self .configurator (epoch_data [index ])
738738
739739 # Like the number of iterations, the batch size is inferred from presimulated dictionary or list
740- batch_size = epoch_data [index ][DEFAULT_KEYS ["sim_data" ]].shape [0 ]
740+ if isinstance (self .amortizer , AmortizedModelComparison ):
741+ batch_size = input_dict [DEFAULT_KEYS ["summary_conditions" ]].shape [0 ]
742+ else :
743+ batch_size = epoch_data [index ][DEFAULT_KEYS ["sim_data" ]].shape [0 ]
741744 loss = self ._train_step (batch_size , _backprop_step , input_dict , ** kwargs )
742745
743746 # Store returned loss
You can’t perform that action at this time.
0 commit comments