Skip to content

Commit a587a02

Browse files
committed
Attempted fix to use_exported_data
1 parent d455026 commit a587a02

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

Pilot1/Uno/uno_baseline_keras2.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,13 @@ def run(params):
268268
config.gpu_options.visible_device_list = ",".join(map(str, args.gpus))
269269
tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config))
270270

271+
if args.use_exported_data is not None:
272+
if os.environ["CANDLE_DATA_DIR"] is not None:
273+
datadir = os.environ["CANDLE_DATA_DIR"]
274+
datafile = os.path.join(dtadir, args.use_exported_data)
275+
else:
276+
datafile = args.use_exported_data
277+
271278
loader = CombinedDataLoader(seed=args.rng_seed)
272279
loader.load(
273280
cache=args.cache,
@@ -289,7 +296,7 @@ def run(params):
289296
test_sources=args.test_sources,
290297
embed_feature_source=not args.no_feature_source,
291298
encode_response_source=not args.no_response_source,
292-
use_exported_data=args.use_exported_data,
299+
use_exported_data=datafile
293300
)
294301

295302
target = args.agg_dose or "Growth"
@@ -485,8 +492,14 @@ def warmup_scheduler(epoch):
485492
callbacks.append(MultiGPUCheckpoint(args.save_weights))
486493

487494
if args.use_exported_data is not None:
495+
if os.environ["CANDLE_DATA_DIR"] is not None:
496+
datadir = os.environ["CANDLE_DATA_DIR"]
497+
datafile = os.path.join(dtadir, args.use_exported_data)
498+
else:
499+
datafile = args.use_exported_data
500+
488501
train_gen = DataFeeder(
489-
filename=args.use_exported_data,
502+
filename=datafile,
490503
batch_size=args.batch_size,
491504
shuffle=args.shuffle,
492505
single=args.single,
@@ -495,7 +508,7 @@ def warmup_scheduler(epoch):
495508
)
496509
val_gen = DataFeeder(
497510
partition="val",
498-
filename=args.use_exported_data,
511+
filename=datafile,
499512
batch_size=args.batch_size,
500513
shuffle=args.shuffle,
501514
single=args.single,
@@ -504,7 +517,7 @@ def warmup_scheduler(epoch):
504517
)
505518
test_gen = DataFeeder(
506519
partition="test",
507-
filename=args.use_exported_data,
520+
filename=datafile,
508521
batch_size=args.batch_size,
509522
shuffle=args.shuffle,
510523
single=args.single,

0 commit comments

Comments
 (0)