Skip to content

Commit 7a35959

Browse files
fix split files at PY3 (#103)
* Fix split_files on PY3 * Fix Linux behavior on PY3 * Fix description error * Fix collective cards and worker_num Co-authored-by: tangwei <[email protected]>
1 parent 947395b commit 7a35959

File tree

5 files changed

+54
-12
lines changed

5 files changed

+54
-12
lines changed

core/trainers/framework/dataset.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
from __future__ import print_function
1616

1717
import os
18-
import warnings
1918

2019
import paddle.fluid as fluid
2120
from paddlerec.core.utils import envs
2221
from paddlerec.core.utils import dataloader_instance
2322
from paddlerec.core.reader import SlotReader
2423
from paddlerec.core.trainer import EngineMode
24+
from paddlerec.core.utils.util import split_files
2525

2626
__all__ = ["DatasetBase", "DataLoader", "QueueDataset"]
2727

@@ -123,7 +123,8 @@ def _get_dataset(self, dataset_name, context):
123123
for x in os.listdir(train_data_path)
124124
]
125125
if context["engine"] == EngineMode.LOCAL_CLUSTER:
126-
file_list = context["fleet"].split_files(file_list)
126+
file_list = split_files(file_list, context["fleet"].worker_index(),
127+
context["fleet"].worker_num())
127128

128129
dataset.set_filelist(file_list)
129130
for model_dict in context["phases"]:

core/utils/dataloader_instance.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from paddlerec.core.utils.envs import get_runtime_environ
2020
from paddlerec.core.reader import SlotReader
2121
from paddlerec.core.trainer import EngineMode
22+
from paddlerec.core.utils.util import split_files
2223

2324

2425
def dataloader_by_name(readerclass,
@@ -39,7 +40,8 @@ def dataloader_by_name(readerclass,
3940

4041
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
4142
if context["engine"] == EngineMode.LOCAL_CLUSTER:
42-
files = context["fleet"].split_files(files)
43+
files = split_files(files, context["fleet"].worker_index(),
44+
context["fleet"].worker_num())
4345
print("file_list : {}".format(files))
4446

4547
reader = reader_class(yaml_file)
@@ -80,7 +82,8 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context):
8082

8183
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
8284
if context["engine"] == EngineMode.LOCAL_CLUSTER:
83-
files = context["fleet"].split_files(files)
85+
files = split_files(files, context["fleet"].worker_index(),
86+
context["fleet"].worker_num())
8487
print("file_list: {}".format(files))
8588

8689
sparse = get_global_env(name + "sparse_slots", "#")
@@ -133,7 +136,8 @@ def slotdataloader(readerclass, train, yaml_file, context):
133136

134137
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
135138
if context["engine"] == EngineMode.LOCAL_CLUSTER:
136-
files = context["fleet"].split_files(files)
139+
files = split_files(files, context["fleet"].worker_index(),
140+
context["fleet"].worker_num())
137141
print("file_list: {}".format(files))
138142

139143
sparse = get_global_env("sparse_slots", "#", namespace)

core/utils/envs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import os
1919
import socket
2020
import sys
21+
import six
2122
import traceback
2223
import six
2324

@@ -102,6 +103,12 @@ def fatten_env_namespace(namespace_nests, local_envs):
102103
name = ".".join(["dataset", dataset["name"], "type"])
103104
global_envs[name] = "DataLoader"
104105

106+
if get_platform() == "LINUX" and six.PY3:
107+
print("QueueDataset can not support PY3, change to DataLoader")
108+
for dataset in envs["dataset"]:
109+
name = ".".join(["dataset", dataset["name"], "type"])
110+
global_envs[name] = "DataLoader"
111+
105112

106113
def get_global_env(env_name, default_value=None, namespace=None):
107114
"""

core/utils/util.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,8 @@
1919
import numpy as np
2020
from paddle import fluid
2121

22-
from paddlerec.core.utils import fs as fs
23-
2422

2523
def save_program_proto(path, program=None):
26-
2724
if program is None:
2825
_program = fluid.default_main_program()
2926
else:
@@ -171,6 +168,39 @@ def print_cost(cost, params):
171168
return log_str
172169

173170

171+
def split_files(files, trainer_id, trainers):
    """Split a file list across distributed trainers.

    Files are dealt out as evenly as possible; the first ``len(files) %
    trainers`` workers each receive one extra file.

    Example 1: files is [a, b, c, d, e] and trainers = 2, then trainer
    0 gets [a, b, c] and trainer 1 gets [d, e].
    Example 2: files is [a, b] and trainers = 3, then trainer 0 gets
    [a], trainer 1 gets [b], trainer 2 gets [].

    Args:
        files(list): file list need to be read.
        trainer_id(int): zero-based index of this worker.
        trainers(int): total number of workers; must be positive.

    Returns:
        list: files belongs to this worker.

    Raises:
        TypeError: if ``files`` is not a list.
    """
    if not isinstance(files, list):
        raise TypeError("files should be a list of file need to be read.")

    # Even share per worker plus the leftover files.
    blocksize, remainder = divmod(len(files), trainers)

    # Workers with index < remainder take one extra file each.
    blocks = [
        blocksize + 1 if i < remainder else blocksize
        for i in range(trainers)
    ]

    # Only compute the slice this worker actually needs; no per-worker
    # lists are materialized (the original `[[]] * trainers` also aliased
    # one shared list, a latent bug had the slots not been reassigned).
    begin = sum(blocks[:trainer_id])
    return files[begin:begin + blocks[trainer_id]]
202+
203+
174204
class CostPrinter(object):
175205
"""
176206
For count cost time && print cost log

run.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,8 @@ def get_engine(args, running_config, mode):
139139
engine = "LOCAL_CLUSTER_TRAIN"
140140

141141
if engine not in engine_choices:
142-
raise ValueError("{} can not be chosen in {}".format(engine_class,
143-
engine_choices))
142+
raise ValueError("{} can only be chosen in {}".format(engine_class,
143+
engine_choices))
144144

145145
run_engine = engines[transpiler].get(engine, None)
146146
return run_engine
@@ -439,8 +439,8 @@ def get_worker_num(run_extras, workers):
439439
if fleet_mode == "COLLECTIVE":
440440
cluster_envs["selected_gpus"] = selected_gpus
441441
gpus = selected_gpus.split(",")
442-
gpu_num = get_worker_num(run_extras, len(gpus))
443-
cluster_envs["selected_gpus"] = ','.join(gpus[:gpu_num])
442+
worker_num = get_worker_num(run_extras, len(gpus))
443+
cluster_envs["selected_gpus"] = ','.join(gpus[:worker_num])
444444

445445
cluster_envs["server_num"] = server_num
446446
cluster_envs["worker_num"] = worker_num

0 commit comments

Comments (0)