@@ -268,6 +268,13 @@ def run(params):
268
268
config .gpu_options .visible_device_list = "," .join (map (str , args .gpus ))
269
269
tf .compat .v1 .keras .backend .set_session (tf .compat .v1 .Session (config = config ))
270
270
271
+ if args .use_exported_data is not None :
272
+ if os .environ ["CANDLE_DATA_DIR" ] is not None :
273
+ datadir = os .environ ["CANDLE_DATA_DIR" ]
274
+ datafile = os .path .join (dtadir , args .use_exported_data )
275
+ else :
276
+ datafile = args .use_exported_data
277
+
271
278
loader = CombinedDataLoader (seed = args .rng_seed )
272
279
loader .load (
273
280
cache = args .cache ,
@@ -289,7 +296,7 @@ def run(params):
289
296
test_sources = args .test_sources ,
290
297
embed_feature_source = not args .no_feature_source ,
291
298
encode_response_source = not args .no_response_source ,
292
- use_exported_data = args . use_exported_data ,
299
+ use_exported_data = datafile
293
300
)
294
301
295
302
target = args .agg_dose or "Growth"
@@ -485,8 +492,14 @@ def warmup_scheduler(epoch):
485
492
callbacks .append (MultiGPUCheckpoint (args .save_weights ))
486
493
487
494
if args .use_exported_data is not None :
495
+ if os .environ ["CANDLE_DATA_DIR" ] is not None :
496
+ datadir = os .environ ["CANDLE_DATA_DIR" ]
497
+ datafile = os .path .join (dtadir , args .use_exported_data )
498
+ else :
499
+ datafile = args .use_exported_data
500
+
488
501
train_gen = DataFeeder (
489
- filename = args . use_exported_data ,
502
+ filename = datafile ,
490
503
batch_size = args .batch_size ,
491
504
shuffle = args .shuffle ,
492
505
single = args .single ,
@@ -495,7 +508,7 @@ def warmup_scheduler(epoch):
495
508
)
496
509
val_gen = DataFeeder (
497
510
partition = "val" ,
498
- filename = args . use_exported_data ,
511
+ filename = datafile ,
499
512
batch_size = args .batch_size ,
500
513
shuffle = args .shuffle ,
501
514
single = args .single ,
@@ -504,7 +517,7 @@ def warmup_scheduler(epoch):
504
517
)
505
518
test_gen = DataFeeder (
506
519
partition = "test" ,
507
- filename = args . use_exported_data ,
520
+ filename = datafile ,
508
521
batch_size = args .batch_size ,
509
522
shuffle = args .shuffle ,
510
523
single = args .single ,
0 commit comments