Skip to content

Commit fc860a3

Browse files
authored
Fix load RNG compatibility. (#8451)
1 parent debb2ad commit fc860a3

File tree

1 file changed

+6
-9
lines changed

1 file changed

+6
-9
lines changed

paddlenlp/trainer/trainer.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1591,16 +1591,13 @@ def _load_rng_state(self, checkpoint):
15911591
if os.path.isfile(rng_file):
15921592
rng_file_list = paddle.load(rng_file, return_numpy=True)
15931593
paddle.distributed.broadcast_object_list(rng_file_list, src=0)
1594-
# if rng_file_list still empty, then use old style rng_state
1594+
# if rng_file_list still empty, not log rng state.
15951595
if rng_file_list[0] is None:
1596-
rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
1597-
if not os.path.isfile(rng_file):
1598-
logger.info(
1599-
f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
1600-
"wasn't launched in a distributed fashion, reproducibility is not guaranteed."
1601-
)
1602-
return
1603-
checkpoint_rng_state = paddle.load(rng_file, return_numpy=True)
1596+
logger.info(
1597+
f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
1598+
"wasn't launched in a distributed fashion, reproducibility is not guaranteed."
1599+
)
1600+
return
16041601
else:
16051602
checkpoint_rng_state = rng_file_list[process_index]
16061603
else:

0 commit comments

Comments
 (0)