Skip to content

Commit 4ca9ccd

Browse files
committed
pass args.single to CombinedDataGenerator
1 parent 2e64c2a commit 4ca9ccd

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

Pilot1/Uno/uno_baseline_keras2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -426,8 +426,8 @@ def warmup_scheduler(epoch):
426426
train_gen = DataFeeder(filename=args.use_exported_data, batch_size=args.batch_size, shuffle=args.shuffle, single=args.single, agg_dose=args.agg_dose)
427427
val_gen = DataFeeder(partition='val', filename=args.use_exported_data, batch_size=args.batch_size, shuffle=args.shuffle, single=args.single, agg_dose=args.agg_dose)
428428
else:
429-
train_gen = CombinedDataGenerator(loader, fold=fold, batch_size=args.batch_size, shuffle=args.shuffle)
430-
val_gen = CombinedDataGenerator(loader, partition='val', fold=fold, batch_size=args.batch_size, shuffle=args.shuffle)
429+
train_gen = CombinedDataGenerator(loader, fold=fold, batch_size=args.batch_size, shuffle=args.shuffle, single=args.single)
430+
val_gen = CombinedDataGenerator(loader, partition='val', fold=fold, batch_size=args.batch_size, shuffle=args.shuffle, single=args.single)
431431

432432
df_val = val_gen.get_response(copy=True)
433433
y_val = df_val[target].values

Pilot1/Uno/uno_data.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -997,10 +997,11 @@ def close(self):
997997
class CombinedDataGenerator(keras.utils.Sequence):
998998
"""Generate training, validation or testing batches from loaded data
999999
"""
1000-
def __init__(self, data, partition='train', fold=0, source=None, batch_size=32, shuffle=True, rank=0, total_ranks=1):
1000+
def __init__(self, data, partition='train', fold=0, source=None, batch_size=32, shuffle=True, single=False, rank=0, total_ranks=1):
10011001
self.data = data
10021002
self.partition = partition
10031003
self.batch_size = batch_size
1004+
self.single = single
10041005

10051006
if partition == 'train':
10061007
index = data.train_indexes[fold]
@@ -1031,7 +1032,7 @@ def __len__(self):
10311032

10321033
def __getitem__(self, idx):
10331034
shard = self.index[idx * self.batch_size:(idx + 1) * self.batch_size]
1034-
x_list, y = self.get_slice(self.batch_size, partial_index=shard)
1035+
x_list, y = self.get_slice(self.batch_size, single=self.single, partial_index=shard)
10351036
return x_list, y
10361037

10371038
def reset(self):

0 commit comments

Comments
 (0)