@@ -363,6 +363,7 @@ def __init__(self, context):
363363 def build_network (self , context ):
364364 context ["model" ] = {}
365365 if len (context ["env" ]["phase" ]) > 1 :
366+ print ("CollectiveNetwork phase:{}" .format (context ["env" ]["phase" ]))
366367 warnings .warn (
367368 "Cluster Train Only Support One Phase." ,
368369 category = UserWarning ,
@@ -407,16 +408,17 @@ def build_network(self, context):
407408 context ["model" ][model_dict ["name" ]]["compiled_program" ] = None
408409
409410 context ["dataset" ] = {}
410- for dataset in context ["env" ]["dataset" ]:
411- type = envs .get_global_env ("dataset." + dataset ["name" ] + ".type" )
411+ for phase in context ["env" ]["phase" ]:
412+ type = envs .get_global_env ("dataset." + phase ["dataset_name" ] +
413+ ".type" )
412414 if type == "QueueDataset" :
413415 raise ValueError (
414416 "Collective don't support QueueDataset training, please use DataLoader."
415417 )
416418 dataset_class = QueueDataset (context )
417- context ["dataset" ][dataset [
418- "name " ]] = dataset_class .create_dataset (dataset [ "name" ],
419- context )
419+ context ["dataset" ][phase [
420+ "dataset_name " ]] = dataset_class .create_dataset (
421+ phase [ "dataset_name" ], context )
420422 context ["status" ] = "startup_pass"
421423
422424 def _build_strategy (self , context ):
0 commit comments