14 | 14 | from __future__ import print_function |
15 | 15 |
16 | 16 | import os |
| 17 | +import warnings |
17 | 18 | from paddlerec.core.utils.envs import lazy_instance_by_fliename |
18 | 19 | from paddlerec.core.utils.envs import get_global_env |
19 | 20 | from paddlerec.core.utils.envs import get_runtime_environ |
@@ -47,6 +48,16 @@ def dataloader_by_name(readerclass, |
47 | 48 |
48 | 49 |     files.sort()
49 | 50 |
| 51 | +    # for local cluster: discard some files if files cannot be divided equally between GPUs
| 52 | +    if (context["device"] == "GPU") and "PADDLEREC_GPU_NUMS" in os.environ:
| 53 | +        selected_gpu_nums = int(os.getenv("PADDLEREC_GPU_NUMS"))
| 54 | +        discard_file_nums = len(files) % selected_gpu_nums
| 55 | +        if (discard_file_nums != 0):
| 56 | +            warnings.warn(
| 57 | +                "Files cannot be divided equally between GPUs; discarding these files: {}".
| 58 | +                format(files[-discard_file_nums:]))
| 59 | +        files = files[:len(files) - discard_file_nums]
| 60 | +
50 | 61 |     need_split_files = False
51 | 62 |     if context["engine"] == EngineMode.LOCAL_CLUSTER:
52 | 63 |         # for local cluster: split files for multi process
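Taken on its own, the discard step added above is a modulo calculation: trim the file list to the nearest multiple of the GPU count so every device gets the same share. A minimal standalone sketch of that logic (the discard_indivisible_files helper name is hypothetical, not part of this patch):

    import warnings

    def discard_indivisible_files(files, gpu_nums):
        # Trim the list to the nearest multiple of gpu_nums so every
        # GPU receives the same number of files.
        discard_file_nums = len(files) % gpu_nums
        if discard_file_nums != 0:
            warnings.warn("Discarding files that cannot be divided equally "
                          "between GPUs: {}".format(files[-discard_file_nums:]))
            files = files[:len(files) - discard_file_nums]
        return files

    # 7 files on 2 GPUs: one file is dropped, leaving 3 per GPU.
    print(discard_indivisible_files(["part-%d" % i for i in range(7)], 2))

Dropping the tail of the sorted list keeps the per-GPU workload identical, at the cost of ignoring up to gpu_nums - 1 files.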
@@ -109,6 +120,16 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context): |
109 | 120 |
110 | 121 |     files.sort()
111 | 122 |
| 123 | +    # for local cluster: discard some files if files cannot be divided equally between GPUs
| 124 | +    if (context["device"] == "GPU") and "PADDLEREC_GPU_NUMS" in os.environ:
| 125 | +        selected_gpu_nums = int(os.getenv("PADDLEREC_GPU_NUMS"))
| 126 | +        discard_file_nums = len(files) % selected_gpu_nums
| 127 | +        if (discard_file_nums != 0):
| 128 | +            warnings.warn(
| 129 | +                "Files cannot be divided equally between GPUs; discarding these files: {}".
| 130 | +                format(files[-discard_file_nums:]))
| 131 | +        files = files[:len(files) - discard_file_nums]
| 132 | +
112 | 133 |     need_split_files = False
113 | 134 |     if context["engine"] == EngineMode.LOCAL_CLUSTER:
114 | 135 |         # for local cluster: split files for multi process
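After the discard step, the split_files helper used here shards the remaining list across local-cluster workers; its exact semantics are not shown in this diff. A round-robin sketch that illustrates the intended per-worker split, under that assumption (split_files_sketch is an illustrative name):

    def split_files_sketch(files, worker_index, worker_num):
        # Round-robin assignment: worker i takes every worker_num-th
        # file starting at offset i. The real split_files may differ.
        return files[worker_index::worker_num]

    files = ["part-0", "part-1", "part-2", "part-3"]
    print(split_files_sketch(files, 0, 2))  # ['part-0', 'part-2']
    print(split_files_sketch(files, 1, 2))  # ['part-1', 'part-3']

Because the GPU discard above makes len(files) a multiple of the device count, every worker ends up with the same number of files.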
@@ -153,73 +174,3 @@ def gen_batch_reader(): |
153 | 174 |     if hasattr(reader, 'generate_batch_from_trainfiles'):
154 | 175 |         return gen_batch_reader()
155 | 176 |     return gen_reader
156 | | - |
157 | | - |
158 | | -def slotdataloader(readerclass, train, yaml_file, context):
159 | | -    if train == "TRAIN":
160 | | -        reader_name = "SlotReader"
161 | | -        namespace = "train.reader"
162 | | -        data_path = get_global_env("train_data_path", None, namespace)
163 | | -    else:
164 | | -        reader_name = "SlotReader"
165 | | -        namespace = "evaluate.reader"
166 | | -        data_path = get_global_env("test_data_path", None, namespace)
167 | | -
168 | | -    if data_path.startswith("paddlerec::"):
169 | | -        package_base = get_runtime_environ("PACKAGE_BASE")
170 | | -        assert package_base is not None
171 | | -        data_path = os.path.join(package_base, data_path.split("::")[1])
172 | | -
173 | | -    hidden_file_list, files = check_filelist(
174 | | -        hidden_file_list=[], data_file_list=[], train_data_path=data_path)
175 | | -    if (hidden_file_list is not None):
176 | | -        print(
177 | | -            "Warning:please make sure there are no hidden files in the dataset folder and check these hidden files:{}".
178 | | -            format(hidden_file_list))
179 | | -
180 | | -    files.sort()
181 | | -
182 | | -    need_split_files = False
183 | | -    if context["engine"] == EngineMode.LOCAL_CLUSTER:
184 | | -        # for local cluster: split files for multi process
185 | | -        need_split_files = True
186 | | -    elif context["engine"] == EngineMode.CLUSTER and context[
187 | | -            "cluster_type"] == "K8S":
188 | | -        # for k8s mount mode, split files for every node
189 | | -        need_split_files = True
190 | | -
191 | | -    if need_split_files:
192 | | -        files = split_files(files, context["fleet"].worker_index(),
193 | | -                            context["fleet"].worker_num())
194 | | -    context["file_list"] = files
195 | | -    sparse = get_global_env("sparse_slots", "#", namespace)
196 | | -    if sparse == "":
197 | | -        sparse = "#"
198 | | -    dense = get_global_env("dense_slots", "#", namespace)
199 | | -    if dense == "":
200 | | -        dense = "#"
201 | | -    padding = get_global_env("padding", 0, namespace)
202 | | -    reader = SlotReader(yaml_file)
203 | | -    reader.init(sparse, dense, int(padding))
204 | | -
205 | | -    def gen_reader():
206 | | -        for file in files:
207 | | -            with open(file, 'r') as f:
208 | | -                for line in f:
209 | | -                    line = line.rstrip('\n')
210 | | -                    iter = reader.generate_sample(line)
211 | | -                    for parsed_line in iter():
212 | | -                        if parsed_line is None:
213 | | -                            continue
214 | | -                        else:
215 | | -                            values = []
216 | | -                            for pased in parsed_line:
217 | | -                                values.append(pased[1])
218 | | -                            yield values
219 | | -
220 | | -    def gen_batch_reader():
221 | | -        return reader.generate_batch_from_trainfiles(files)
222 | | -
223 | | -    if hasattr(reader, 'generate_batch_from_trainfiles'):
224 | | -        return gen_batch_reader()
225 | | -    return gen_reader
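The removed slotdataloader duplicated logic that lives on in slotdataloader_by_name, including the sample-generator pattern: open each file, parse one line at a time, and yield only the slot values. A compact sketch of that pattern, with parse_line standing in for SlotReader.generate_sample (both names in the sketch are illustrative, not the library API):

    def make_gen_reader(files, parse_line):
        # parse_line(line) is assumed to return a list of
        # (slot_name, value) pairs, or None for lines to skip.
        def gen_reader():
            for path in files:
                with open(path, 'r') as f:
                    for line in f:
                        parsed = parse_line(line.rstrip('\n'))
                        if parsed is None:
                            continue
                        # Yield only the values, as the deleted code did.
                        yield [value for _, value in parsed]
        return gen_reader

Streaming line by line keeps memory use flat regardless of dataset size, which is why both the old and surviving loaders return a generator rather than a materialized list.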