|
18 | 18 | import time |
19 | 19 | import warnings |
20 | 20 | import numpy as np |
| 21 | +import random |
21 | 22 | import logging |
22 | 23 | import paddle.fluid as fluid |
23 | 24 |
|
24 | 25 | from paddlerec.core.utils import envs |
| 26 | +from paddlerec.core.utils.util import shuffle_files |
25 | 27 | from paddlerec.core.metric import Metric |
26 | 28 |
|
27 | 29 | logging.basicConfig( |
@@ -92,7 +94,6 @@ def _executor_dataset_train(self, model_dict, context): |
92 | 94 | reader_name = model_dict["dataset_name"] |
93 | 95 | model_name = model_dict["name"] |
94 | 96 | model_class = context["model"][model_dict["name"]]["model"] |
95 | | - |
96 | 97 | fetch_vars = [] |
97 | 98 | fetch_alias = [] |
98 | 99 | fetch_period = int( |
@@ -395,7 +396,12 @@ def run(self, context): |
395 | 396 | for model_dict in context["phases"]: |
396 | 397 | model_class = context["model"][model_dict["name"]]["model"] |
397 | 398 | metrics = model_class._metrics |
398 | | - |
| 399 | + if "shuffle_filelist" in model_dict: |
| 400 | + need_shuffle_files = model_dict.get("shuffle_filelist", |
| 401 | + None) |
| 402 | + filelist = context["file_list"] |
| 403 | + context["file_list"] = shuffle_files(need_shuffle_files, |
| 404 | + filelist) |
399 | 405 | begin_time = time.time() |
400 | 406 | result = self._run(context, model_dict) |
401 | 407 | end_time = time.time() |
@@ -439,6 +445,11 @@ def run(self, context): |
439 | 445 | model_class = context["model"][model_dict["name"]]["model"] |
440 | 446 | metrics = model_class._metrics |
441 | 447 | for epoch in range(epochs): |
| 448 | + if "shuffle_filelist" in model_dict: |
| 449 | + need_shuffle_files = model_dict.get("shuffle_filelist", None) |
| 450 | + filelist = context["file_list"] |
| 451 | + context["file_list"] = shuffle_files(need_shuffle_files, |
| 452 | + filelist) |
442 | 453 | begin_time = time.time() |
443 | 454 | result = self._run(context, model_dict) |
444 | 455 | end_time = time.time() |
@@ -484,6 +495,11 @@ def run(self, context): |
484 | 495 | ".epochs")) |
485 | 496 | model_dict = context["env"]["phase"][0] |
486 | 497 | for epoch in range(epochs): |
| 498 | + if "shuffle_filelist" in model_dict: |
| 499 | + need_shuffle_files = model_dict.get("shuffle_filelist", None) |
| 500 | + filelist = context["file_list"] |
| 501 | + context["file_list"] = shuffle_files(need_shuffle_files, |
| 502 | + filelist) |
487 | 503 | begin_time = time.time() |
488 | 504 | self._run(context, model_dict) |
489 | 505 | end_time = time.time() |
@@ -512,6 +528,11 @@ def run(self, context): |
512 | 528 | envs.get_global_env("runner." + context["runner_name"] + |
513 | 529 | ".epochs")) |
514 | 530 | for epoch in range(epochs): |
| 531 | + if "shuffle_filelist" in model_dict: |
| 532 | + need_shuffle_files = model_dict.get("shuffle_filelist", None) |
| 533 | + filelist = context["file_list"] |
| 534 | + context["file_list"] = shuffle_files(need_shuffle_files, |
| 535 | + filelist) |
515 | 536 | begin_time = time.time() |
516 | 537 | self._run(context, model_dict) |
517 | 538 | end_time = time.time() |
@@ -574,6 +595,12 @@ def run(self, context): |
574 | 595 | metrics = model_class._infer_results |
575 | 596 | self._load(context, model_dict, |
576 | 597 | self.epoch_model_path_list[index]) |
| 598 | + if "shuffle_filelist" in model_dict: |
| 599 | + need_shuffle_files = model_dict.get("shuffle_filelist", |
| 600 | + None) |
| 601 | + filelist = context["file_list"] |
| 602 | + context["file_list"] = shuffle_files(need_shuffle_files, |
| 603 | + filelist) |
577 | 604 | begin_time = time.time() |
578 | 605 | result = self._run(context, model_dict) |
579 | 606 | end_time = time.time() |
|
0 commit comments