Skip to content

Commit ca79444

Browse files
support bf16 loss in static (PaddlePaddle#7874)
1 parent eafa066 commit ca79444

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

llm/llama/auto_parallel/run_pretrain_auto.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
 import paddle
 import paddle.distributed as dist
 import paddle.distributed.auto_parallel as auto
+from paddle.base.data_feeder import convert_uint16_to_float
 from paddle.profiler.utils import job_schedule_profiler_range

 from paddlenlp.ops import Topology
@@ -668,7 +669,10 @@ def loss_func(loss, outputs):
         outs = engine.run(micro_batch, mode="train")

         if "loss" in outs:
-            tr_loss_step = np.sum(outs["loss"])
+            # bf16 losses come back as raw uint16 bits; convert before summing
+            if outs["loss"].dtype == np.uint16:
+                tr_loss_step = np.sum(convert_uint16_to_float(outs["loss"]))
+            else:
+                tr_loss_step = np.sum(outs["loss"])
         else:
             tr_loss_step = float(0)

0 commit comments

Comments (0)