Skip to content

Commit 703306c

Browse files
committed
fix Collective multi dataset_name
1 parent 70b6bbe commit 703306c

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

core/trainers/framework/network.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,7 @@ def __init__(self, context):
363363
def build_network(self, context):
364364
context["model"] = {}
365365
if len(context["env"]["phase"]) > 1:
366+
print("CollectiveNetwork phase:{}".format(context["env"]["phase"]))
366367
warnings.warn(
367368
"Cluster Train Only Support One Phase.",
368369
category=UserWarning,
@@ -407,16 +408,17 @@ def build_network(self, context):
407408
context["model"][model_dict["name"]]["compiled_program"] = None
408409

409410
context["dataset"] = {}
410-
for dataset in context["env"]["dataset"]:
411-
type = envs.get_global_env("dataset." + dataset["name"] + ".type")
411+
for phase in context["env"]["phase"]:
412+
type = envs.get_global_env("dataset." + phase["dataset_name"] +
413+
".type")
412414
if type == "QueueDataset":
413415
raise ValueError(
414416
"Collective don't support QueueDataset training, please use DataLoader."
415417
)
416418
dataset_class = QueueDataset(context)
417-
context["dataset"][dataset[
418-
"name"]] = dataset_class.create_dataset(dataset["name"],
419-
context)
419+
context["dataset"][phase[
420+
"dataset_name"]] = dataset_class.create_dataset(
421+
phase["dataset_name"], context)
420422
context["status"] = "startup_pass"
421423

422424
def _build_strategy(self, context):

0 commit comments

Comments
 (0)