@@ -238,18 +238,18 @@ def build_network(self, context):
238238 else :
239239 context ["fleet" ].init_worker ()
240240 context ["dataset" ] = {}
241- for dataset in context ["env" ]["dataset " ]:
242- type = envs .get_global_env ("dataset." + dataset [ "name " ] +
241+ for phase in context ["env" ]["phase " ]:
242+ type = envs .get_global_env ("dataset." + phase [ "dataset_name " ] +
243243 ".type" )
244244 if type == "DataLoader" :
245245 data_loader = DataLoader (context )
246246 data_loader .get_dataloader (context , dataset_name ,
247247 model ._data_loader )
248248 elif type == "QueueDataset" :
249249 dataset_class = QueueDataset (context )
250- context ["dataset" ][dataset [
251- "name " ]] = dataset_class .create_dataset (
252- dataset [ "name " ], context )
250+ context ["dataset" ][phase [
251+ "dataset_name " ]] = dataset_class .create_dataset (
252+ phase [ "dataset_name " ], context )
253253 context ["status" ] = "startup_pass"
254254
255255 def _build_strategy (self , context ):
@@ -336,7 +336,7 @@ def build_network(self, context):
336336 self ._server (context )
337337 else :
338338 context ["dataset" ] = {}
339- for dataset in context ["env" ]["dataset " ]:
339+ for phase in context ["env" ]["phase " ]:
340340 type = envs .get_global_env ("dataset." + dataset ["name" ] +
341341 ".type" )
342342 if type == "DataLoader" :
@@ -363,6 +363,7 @@ def __init__(self, context):
363363 def build_network (self , context ):
364364 context ["model" ] = {}
365365 if len (context ["env" ]["phase" ]) > 1 :
366+ print ("CollectiveNetwork phase:{}" .format (context ["env" ]["phase" ]))
366367 warnings .warn (
367368 "Cluster Train Only Support One Phase." ,
368369 category = UserWarning ,
@@ -407,16 +408,17 @@ def build_network(self, context):
407408 context ["model" ][model_dict ["name" ]]["compiled_program" ] = None
408409
409410 context ["dataset" ] = {}
410- for dataset in context ["env" ]["dataset" ]:
411- type = envs .get_global_env ("dataset." + dataset ["name" ] + ".type" )
411+ for phase in context ["env" ]["phase" ]:
412+ type = envs .get_global_env ("dataset." + phase ["dataset_name" ] +
413+ ".type" )
412414 if type == "QueueDataset" :
413415 raise ValueError (
414416 "Collective don't support QueueDataset training, please use DataLoader."
415417 )
416418 dataset_class = QueueDataset (context )
417- context ["dataset" ][dataset [
418- "name " ]] = dataset_class .create_dataset (dataset [ "name" ],
419- context )
419+ context ["dataset" ][phase [
420+ "dataset_name " ]] = dataset_class .create_dataset (
421+ phase [ "dataset_name" ], context )
420422 context ["status" ] = "startup_pass"
421423
422424 def _build_strategy (self , context ):
0 commit comments