Skip to content

Commit 70b6bbe

Browse files
committed
fix PSRunner multi dataset_name
1 parent 921d6c0 commit 70b6bbe

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

core/trainers/framework/network.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -238,18 +238,18 @@ def build_network(self, context):
238238
else:
239239
context["fleet"].init_worker()
240240
context["dataset"] = {}
241-
for dataset in context["env"]["dataset"]:
242-
type = envs.get_global_env("dataset." + dataset["name"] +
241+
for phase in context["env"]["phase"]:
242+
type = envs.get_global_env("dataset." + phase["dataset_name"] +
243243
".type")
244244
if type == "DataLoader":
245245
data_loader = DataLoader(context)
246246
data_loader.get_dataloader(context, dataset_name,
247247
model._data_loader)
248248
elif type == "QueueDataset":
249249
dataset_class = QueueDataset(context)
250-
context["dataset"][dataset[
251-
"name"]] = dataset_class.create_dataset(
252-
dataset["name"], context)
250+
context["dataset"][phase[
251+
"dataset_name"]] = dataset_class.create_dataset(
252+
phase["dataset_name"], context)
253253
context["status"] = "startup_pass"
254254

255255
def _build_strategy(self, context):
@@ -336,7 +336,7 @@ def build_network(self, context):
336336
self._server(context)
337337
else:
338338
context["dataset"] = {}
339-
for dataset in context["env"]["dataset"]:
339+
for phase in context["env"]["phase"]:
340340
type = envs.get_global_env("dataset." + dataset["name"] +
341341
".type")
342342
if type == "DataLoader":

0 commit comments

Comments
 (0)