Skip to content

Commit 9c3ae1d

Browse files
authored
[Cherry-pick] Cherry-pick from fleety (#11047)
* add timer log in trainer (#10880) * add layer norm backward (#10886) * add memory usage message in tensorboard (#10887)
1 parent da9ddf7 commit 9c3ae1d

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

paddlenlp/trainer/trainer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import paddle.amp.auto_cast as autocast
4040
import paddle.distributed as dist
4141
import paddle.nn as nn
42+
import psutil
4243
from packaging import version
4344
from paddle import framework
4445
from paddle.distributed.fleet.meta_parallel import PipelineLayer
@@ -3201,6 +3202,14 @@ def log(self, logs: Dict[str, float], **kwargs) -> None:
32013202

32023203
if self.state.epoch is not None:
32033204
logs["progress_or_epoch"] = round(self.state.epoch, 4)
3205+
3206+
if self.timers:
3207+
logs.update(self.timers.info(self.timers.timers.keys()))
3208+
3209+
mem_info = psutil.virtual_memory()
3210+
logs["cpu_used_memory"] = round(mem_info.used / (1024**3), 2)
3211+
logs["cpu_available_memory"] = round(mem_info.available / (1024**3), 2)
3212+
32043213
self.state.log_history = []
32053214
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs, **kwargs)
32063215

slm/model_zoo/gpt-3/external_ops/fused_ln/layer_norm_cuda.cu

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,5 +237,16 @@ PD_BUILD_GRAD_OP(fused_rms_norm)
237237
#endif
238238
;
239239

240+
PD_BUILD_OP(fused_rms_norm_grad_func)
241+
.Inputs({"x", "scale", "invvar", "dy"})
242+
.Outputs({"dx", "d_scale"})
243+
.Attrs({"epsilon: float"})
244+
.SetKernelFn(PD_KERNEL(RMSLnBwd))
245+
.SetInferShapeFn(PD_INFER_SHAPE(RMSLnBwdInferShape))
246+
#ifdef CUSTOM_OP_WITH_SPMD
247+
.SetInferSpmdFn(PD_INFER_SPMD_RULE(phi::distributed::RmsNormGradInferSpmd))
248+
#endif
249+
;
250+
240251

241252
// https://github.com/NVIDIA/apex/blob/85e9eddece9d4ac72b48c2407f8162f2173e1bf4/csrc/layer_norm_cuda_kernel.cu#L679

0 commit comments

Comments
 (0)