Skip to content

Commit c26b0e7

Browse files
vslyu and seiriosPlus authored
add support for file_list shuffle each epoch and fix float learning rate bug (#197)
* add support for file_list shuffle each epoch, test=develop
* fix float learning rate bug
* optimized code for shuffle files

Co-authored-by: tangwei12 <[email protected]>
1 parent 770693a commit c26b0e7

File tree

5 files changed

+53
-5
lines changed

5 files changed

+53
-5
lines changed

core/model.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,13 @@ def optimizer(self):
177177
opt_name = envs.get_global_env("hyper_parameters.optimizer.class")
178178
opt_lr = envs.get_global_env(
179179
"hyper_parameters.optimizer.learning_rate")
180+
if not isinstance(opt_lr, (float, Variable)):
181+
try:
182+
opt_lr = float(opt_lr)
183+
except ValueError:
184+
raise ValueError(
185+
"In your config yaml, 'learning_rate': %s must be written as a floating piont number,such as 0.001 or 1e-3"
186+
% opt_lr)
180187
opt_strategy = envs.get_global_env(
181188
"hyper_parameters.optimizer.strategy")
182189

core/trainers/framework/dataset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ def _get_dataset(self, dataset_name, context):
143143
if need_split_files:
144144
file_list = split_files(file_list, context["fleet"].worker_index(),
145145
context["fleet"].worker_num())
146+
147+
context["file_list"] = file_list
146148
print("File_list: {}".format(file_list))
147149

148150
dataset.set_filelist(file_list)

core/trainers/framework/runner.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
import time
1919
import warnings
2020
import numpy as np
21+
import random
2122
import logging
2223
import paddle.fluid as fluid
2324

2425
from paddlerec.core.utils import envs
26+
from paddlerec.core.utils.util import shuffle_files
2527
from paddlerec.core.metric import Metric
2628

2729
logging.basicConfig(
@@ -92,7 +94,6 @@ def _executor_dataset_train(self, model_dict, context):
9294
reader_name = model_dict["dataset_name"]
9395
model_name = model_dict["name"]
9496
model_class = context["model"][model_dict["name"]]["model"]
95-
9697
fetch_vars = []
9798
fetch_alias = []
9899
fetch_period = int(
@@ -395,7 +396,12 @@ def run(self, context):
395396
for model_dict in context["phases"]:
396397
model_class = context["model"][model_dict["name"]]["model"]
397398
metrics = model_class._metrics
398-
399+
if "shuffle_filelist" in model_dict:
400+
need_shuffle_files = model_dict.get("shuffle_filelist",
401+
None)
402+
filelist = context["file_list"]
403+
context["file_list"] = shuffle_files(need_shuffle_files,
404+
filelist)
399405
begin_time = time.time()
400406
result = self._run(context, model_dict)
401407
end_time = time.time()
@@ -439,6 +445,11 @@ def run(self, context):
439445
model_class = context["model"][model_dict["name"]]["model"]
440446
metrics = model_class._metrics
441447
for epoch in range(epochs):
448+
if "shuffle_filelist" in model_dict:
449+
need_shuffle_files = model_dict.get("shuffle_filelist", None)
450+
filelist = context["file_list"]
451+
context["file_list"] = shuffle_files(need_shuffle_files,
452+
filelist)
442453
begin_time = time.time()
443454
result = self._run(context, model_dict)
444455
end_time = time.time()
@@ -484,6 +495,11 @@ def run(self, context):
484495
".epochs"))
485496
model_dict = context["env"]["phase"][0]
486497
for epoch in range(epochs):
498+
if "shuffle_filelist" in model_dict:
499+
need_shuffle_files = model_dict.get("shuffle_filelist", None)
500+
filelist = context["file_list"]
501+
context["file_list"] = shuffle_files(need_shuffle_files,
502+
filelist)
487503
begin_time = time.time()
488504
self._run(context, model_dict)
489505
end_time = time.time()
@@ -512,6 +528,11 @@ def run(self, context):
512528
envs.get_global_env("runner." + context["runner_name"] +
513529
".epochs"))
514530
for epoch in range(epochs):
531+
if "shuffle_filelist" in model_dict:
532+
need_shuffle_files = model_dict.get("shuffle_filelist", None)
533+
filelist = context["file_list"]
534+
context["file_list"] = shuffle_files(need_shuffle_files,
535+
filelist)
515536
begin_time = time.time()
516537
self._run(context, model_dict)
517538
end_time = time.time()
@@ -574,6 +595,12 @@ def run(self, context):
574595
metrics = model_class._infer_results
575596
self._load(context, model_dict,
576597
self.epoch_model_path_list[index])
598+
if "shuffle_filelist" in model_dict:
599+
need_shuffle_files = model_dict.get("shuffle_filelist",
600+
None)
601+
filelist = context["file_list"]
602+
context["file_list"] = shuffle_files(need_shuffle_files,
603+
filelist)
577604
begin_time = time.time()
578605
result = self._run(context, model_dict)
579606
end_time = time.time()

core/utils/dataloader_instance.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def dataloader_by_name(readerclass,
5959
if need_split_files:
6060
files = split_files(files, context["fleet"].worker_index(),
6161
context["fleet"].worker_num())
62-
62+
context["file_list"] = files
6363
reader = reader_class(yaml_file)
6464
reader.init()
6565

@@ -121,7 +121,7 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context):
121121
if need_split_files:
122122
files = split_files(files, context["fleet"].worker_index(),
123123
context["fleet"].worker_num())
124-
124+
context["file_list"] = files
125125
sparse = get_global_env(name + "sparse_slots", "#")
126126
if sparse == "":
127127
sparse = "#"
@@ -191,7 +191,7 @@ def slotdataloader(readerclass, train, yaml_file, context):
191191
if need_split_files:
192192
files = split_files(files, context["fleet"].worker_index(),
193193
context["fleet"].worker_num())
194-
194+
context["file_list"] = files
195195
sparse = get_global_env("sparse_slots", "#", namespace)
196196
if sparse == "":
197197
sparse = "#"

core/utils/util.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import os
1717
import sys
1818
import time
19+
import warnings
20+
import random
1921
import numpy as np
2022
from paddle import fluid
2123

@@ -223,6 +225,16 @@ def check_filelist(hidden_file_list, data_file_list, train_data_path):
223225
return hidden_file_list, data_file_list
224226

225227

228+
def shuffle_files(need_shuffle_files, filelist):
    """Optionally shuffle ``filelist`` in place and return it.

    Args:
        need_shuffle_files: the value of the ``shuffle_filelist`` config
            option. Must be a real ``bool``; any other type is rejected.
        filelist: mutable sequence of file names. When shuffling is
            requested it is reordered in place with ``random.shuffle``.

    Returns:
        The same ``filelist`` object that was passed in (shuffled in place
        when ``need_shuffle_files`` is True, untouched otherwise).

    Raises:
        ValueError: if ``need_shuffle_files`` is not a ``bool``.
    """
    if not isinstance(need_shuffle_files, bool):
        # Wrap the offending value in a 1-tuple: with a bare ``% value`` a
        # tuple would be unpacked as the format-argument list and raise an
        # unrelated TypeError instead of this ValueError.
        raise ValueError(
            "In your config yaml, 'shuffle_filelist': %s must be written as a boolean type,such as True or False"
            % (need_shuffle_files, ))

    if need_shuffle_files:
        random.shuffle(filelist)
    return filelist
236+
237+
226238
class CostPrinter(object):
227239
"""
228240
For count cost time && print cost log

0 commit comments

Comments
 (0)