Commit 090bff4

[Fea] Support nvtx profiling (#825)
* support nvtx profiling via NVTX=1
* rename trainer to solver
* add user guide for NVTX
* refine code of paddle.framework.core and add cache for 3 context managers of Solver
* update nsys chapter in user_guide
* fix solver.py
1 parent cd21d80 commit 090bff4

5 files changed: +168, -52 lines changed


docs/zh/user_guide.md

Lines changed: 46 additions & 0 deletions
@@ -795,3 +795,49 @@ solver = ppsci.solver.Solver(
 !!! info "Scope of impact"
 
     Individual multi-task learning methods (such as weight-based methods) may change how the loss is computed during the **training** process, but the effect is limited to training; the loss computation used during model **evaluation** stays unchanged.
+
+## 3. Profiling with Nsight
+
+Nsight is NVIDIA's developer tool suite. It provides in-depth tracing, debugging, profiling, and analysis for optimizing complex compute applications across NVIDIA GPUs and CPUs. See the official documentation for details: [Nsight Systems Document](https://docs.nvidia.com/nsight-systems/index.html)
+
+PaddleScience provides preliminary support for profiling with Nsight. Taking a Linux development environment and the laplace2d case as an example, follow the steps below to generate a profiling report with the nsight tool and inspect the results.
+
+1. Install nsight-system
+
+    Download the Linux nsight-system package nsight-systems/2023.4.1 on the development machine and add nsight to the `PATH` environment variable:
+
+    Run `PATH=/path/to/nsight-systems/2023.4.1/bin:$PATH`, and install the **same version** of nsight-system on the Windows machine.
+
+2. Run the program with the nsys command to generate the profiling files
+
+    ``` sh
+    {==NVTX=1 nsys profile -t cuda,nvtx --stats=true -o==} {++laplace2d++} python laplace2d.py
+    ```
+
+3. View the analysis results
+
+    After the program finishes, the profiling statistics are printed in the terminal (as shown below), and two files, `{++laplace2d++}.nsys-rep` and `{++laplace2d++}.sqlite`, are generated at the relative path specified by the `-o` argument above.
+
+    Open `laplace2d.nsys-rep` with the NVIDIA Nsight Systems application on Windows to view the profiling data in a graphical interface.
+
+    ``` log
+    ...
+    ...
+    Only run 25 steps when 'NVTX' is set in environment for nsight analysis. Exit now ......
+
+    Generating '/tmp/nsys-report-18e4.qdstrm'
+    [1/7] [========================100%] laplace2d.nsys-rep
+    [2/7] [========================100%] laplace2d.sqlite
+    [3/7] Executing 'nvtx_sum' stats report
+
+    Time (%)  Total Time (ns)  Instances    Avg (ns)       Med (ns)      Min (ns)     Max (ns)    StdDev (ns)   Style    Range
+    --------  ---------------  ---------  -------------  -------------  -----------  -----------  -------------  -------  ------------------------------------
+        15.1      794,212,341         25   31,768,493.6    5,446,410.0    5,328,471  661,841,104  131,265,333.9  PushPop  Loss computation
+        14.5      766,452,142         25   30,658,085.7    4,369,873.0    4,281,927  659,795,434  131,070,475.4  PushPop  Constraint EQ
+        13.0      687,324,359      1,300      528,711.0       32,567.5       21,218  641,625,892   17,794,532.4  PushPop  matmul dygraph
+        12.9      678,475,194          1  678,475,194.0  678,475,194.0  678,475,194  678,475,194            0.0  PushPop  Training iteration 1
+        12.8      673,614,062      1,300      518,164.7       19,802.5       14,499  641,525,121   17,792,027.2  PushPop  matmul compute
+         3.9      203,945,648         25    8,157,825.9    8,029,819.0    7,797,185    9,119,496      359,173.3  PushPop  Loss backward
+    ...
+    ...
+    ```
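The user guide above gates everything behind the `NVTX` environment variable, and the code changes below implement that switch. For reference, the following is a minimal, self-contained sketch of the same pattern (the toy loop, tensor sizes, and range names are illustrative only, and a CUDA-enabled Paddle build is assumed); it uses only the `paddle.framework.core` calls that appear in the diffs below:

``` py
import os

import paddle
from paddle.framework import core

# Same switch as the Solver: ranges are emitted only when NVTX is set,
# so the hooks cost nothing when Nsight is not attached.
nvtx_flag = os.getenv("NVTX", None) is not None

if nvtx_flag:
    core.nvprof_start()
    core.nvprof_enable_record_event()

x = paddle.randn([512, 512])
for step in range(5):
    if nvtx_flag:
        core.nvprof_nvtx_push(f"Toy iteration {step}")  # open a named range
    y = x @ x  # any workload; it appears under the range in the timeline
    if nvtx_flag:
        core.nvprof_nvtx_pop()  # close "Toy iteration"

if nvtx_flag:
    core.nvprof_stop()
```

Running such a script under `nsys profile -t cuda,nvtx ...` as in step 2 above should show the named ranges in the nvtx_sum report.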

ppsci/solver/printer.py

Lines changed: 46 additions & 46 deletions
@@ -30,59 +30,59 @@


 def update_train_loss(
-    trainer: "solver.Solver", loss_dict: Dict[str, float], batch_size: int
+    solver: "solver.Solver", loss_dict: Dict[str, float], batch_size: int
 ):
     for key in loss_dict:
-        if key not in trainer.train_output_info:
-            trainer.train_output_info[key] = misc.AverageMeter(key, "7.5f")
-        trainer.train_output_info[key].update(float(loss_dict[key]), batch_size)
-        if key not in trainer.train_loss_info:
-            trainer.train_loss_info[key] = misc.AverageMeter(key, ".5f")
-        trainer.train_loss_info[key].update(float(loss_dict[key]))
+        if key not in solver.train_output_info:
+            solver.train_output_info[key] = misc.AverageMeter(key, "7.5f")
+        solver.train_output_info[key].update(float(loss_dict[key]), batch_size)
+        if key not in solver.train_loss_info:
+            solver.train_loss_info[key] = misc.AverageMeter(key, ".5f")
+        solver.train_loss_info[key].update(float(loss_dict[key]))


 def update_eval_loss(
-    trainer: "solver.Solver", loss_dict: Dict[str, float], batch_size: int
+    solver: "solver.Solver", loss_dict: Dict[str, float], batch_size: int
 ):
     for key in loss_dict:
-        if key not in trainer.eval_output_info:
-            trainer.eval_output_info[key] = misc.AverageMeter(key, "7.5f")
-        trainer.eval_output_info[key].update(float(loss_dict[key]), batch_size)
+        if key not in solver.eval_output_info:
+            solver.eval_output_info[key] = misc.AverageMeter(key, "7.5f")
+        solver.eval_output_info[key].update(float(loss_dict[key]), batch_size)


 def log_train_info(
-    trainer: "solver.Solver", batch_size: int, epoch_id: int, iter_id: int
+    solver: "solver.Solver", batch_size: int, epoch_id: int, iter_id: int
 ):
-    lr_msg = f"lr: {trainer.optimizer.get_lr():.5f}"
+    lr_msg = f"lr: {solver.optimizer.get_lr():.5f}"

     metric_msg = ", ".join(
         [
-            f"{key}: {trainer.train_output_info[key].avg:.5f}"
-            for key in trainer.train_output_info
+            f"{key}: {solver.train_output_info[key].avg:.5f}"
+            for key in solver.train_output_info
         ]
     )

     time_msg = ", ".join(
-        [trainer.train_time_info[key].mean for key in trainer.train_time_info]
+        [solver.train_time_info[key].mean for key in solver.train_time_info]
     )

-    ips_msg = f"ips: {batch_size / trainer.train_time_info['batch_cost'].avg:.2f}"
-    if trainer.benchmark_flag:
+    ips_msg = f"ips: {batch_size / solver.train_time_info['batch_cost'].avg:.2f}"
+    if solver.benchmark_flag:
         ips_msg += " samples/s"

     eta_sec = (
-        (trainer.epochs - epoch_id + 1) * trainer.iters_per_epoch - iter_id
-    ) * trainer.train_time_info["batch_cost"].avg
+        (solver.epochs - epoch_id + 1) * solver.iters_per_epoch - iter_id
+    ) * solver.train_time_info["batch_cost"].avg
     eta_msg = f"eta: {str(datetime.timedelta(seconds=int(eta_sec)))}"

-    epoch_width = len(str(trainer.epochs))
-    iters_width = len(str(trainer.iters_per_epoch))
+    epoch_width = len(str(solver.epochs))
+    iters_width = len(str(solver.iters_per_epoch))
     log_str = (
-        f"[Train][Epoch {epoch_id:>{epoch_width}}/{trainer.epochs}]"
-        f"[Iter {iter_id:>{iters_width}}/{trainer.iters_per_epoch}] {lr_msg}, "
+        f"[Train][Epoch {epoch_id:>{epoch_width}}/{solver.epochs}]"
+        f"[Iter {iter_id:>{iters_width}}/{solver.iters_per_epoch}] {lr_msg}, "
         f"{metric_msg}, {time_msg}, {ips_msg}, {eta_msg}"
     )
-    if trainer.benchmark_flag:
+    if solver.benchmark_flag:
         max_mem_reserved_msg = (
             f"max_mem_reserved: {device.cuda.max_memory_reserved() // (1 << 20)} MB"
         )
@@ -94,57 +94,57 @@ def log_train_info(

     logger.scalar(
         {
-            "train/lr": trainer.optimizer.get_lr(),
+            "train/lr": solver.optimizer.get_lr(),
             **{
-                f"train/{key}": trainer.train_output_info[key].avg
-                for key in trainer.train_output_info
+                f"train/{key}": solver.train_output_info[key].avg
+                for key in solver.train_output_info
             },
         },
-        step=trainer.global_step,
-        vdl_writer=trainer.vdl_writer,
-        wandb_writer=trainer.wandb_writer,
-        tbd_writer=trainer.tbd_writer,
+        step=solver.global_step,
+        vdl_writer=solver.vdl_writer,
+        wandb_writer=solver.wandb_writer,
+        tbd_writer=solver.tbd_writer,
     )


 def log_eval_info(
-    trainer: "solver.Solver",
+    solver: "solver.Solver",
     batch_size: int,
     epoch_id: int,
     iters_per_epoch: int,
     iter_id: int,
 ):
     metric_msg = ", ".join(
         [
-            f"{key}: {trainer.eval_output_info[key].avg:.5f}"
-            for key in trainer.eval_output_info
+            f"{key}: {solver.eval_output_info[key].avg:.5f}"
+            for key in solver.eval_output_info
         ]
     )

     time_msg = ", ".join(
-        [trainer.eval_time_info[key].mean for key in trainer.eval_time_info]
+        [solver.eval_time_info[key].mean for key in solver.eval_time_info]
     )

-    ips_msg = f"ips: {batch_size / trainer.eval_time_info['batch_cost'].avg:.2f}"
+    ips_msg = f"ips: {batch_size / solver.eval_time_info['batch_cost'].avg:.2f}"

-    eta_sec = (iters_per_epoch - iter_id) * trainer.eval_time_info["batch_cost"].avg
+    eta_sec = (iters_per_epoch - iter_id) * solver.eval_time_info["batch_cost"].avg
     eta_msg = f"eta: {str(datetime.timedelta(seconds=int(eta_sec)))}"

-    epoch_width = len(str(trainer.epochs))
+    epoch_width = len(str(solver.epochs))
     iters_width = len(str(iters_per_epoch))
     logger.info(
-        f"[Eval][Epoch {epoch_id:>{epoch_width}}/{trainer.epochs}]"
+        f"[Eval][Epoch {epoch_id:>{epoch_width}}/{solver.epochs}]"
         f"[Iter {iter_id:>{iters_width}}/{iters_per_epoch}] "
         f"{metric_msg}, {time_msg}, {ips_msg}, {eta_msg}"
     )

     logger.scalar(
         {
-            f"eval/{key}": trainer.eval_output_info[key].avg
-            for key in trainer.eval_output_info
+            f"eval/{key}": solver.eval_output_info[key].avg
+            for key in solver.eval_output_info
         },
-        step=trainer.global_step,
-        vdl_writer=trainer.vdl_writer,
-        wandb_writer=trainer.wandb_writer,
-        tbd_writer=trainer.tbd_writer,
+        step=solver.global_step,
+        vdl_writer=solver.vdl_writer,
+        wandb_writer=solver.wandb_writer,
+        tbd_writer=solver.tbd_writer,
     )

ppsci/solver/solver.py

Lines changed: 13 additions & 0 deletions
@@ -15,6 +15,7 @@
 from __future__ import annotations

 import contextlib
+import functools
 import importlib
 import itertools
 import os
@@ -39,6 +40,7 @@
 from paddle import nn
 from paddle import optimizer as optim
 from paddle.distributed import fleet
+from paddle.framework import core
 from paddle.static import InputSpec
 from typing_extensions import Literal

@@ -444,11 +446,19 @@ def convert_expr(
         # set up benchmark flag, will print memory stat if enabled
         self.benchmark_flag: bool = os.getenv("BENCHMARK_ROOT", None) is not None

+        # set up nvtx flag for nsight analysis
+        self.nvtx_flag: bool = os.getenv("NVTX", None) is not None
+        self.forward_helper.nvtx_flag = self.nvtx_flag
+
     def train(self) -> None:
         """Training."""
         self.global_step = self.best_metric["epoch"] * self.iters_per_epoch
         start_epoch = self.best_metric["epoch"] + 1

+        if self.nvtx_flag:
+            core.nvprof_start()
+            core.nvprof_enable_record_event()
+
         for epoch_id in range(start_epoch, self.epochs + 1):
             self.train_epoch_func(self, epoch_id, self.log_freq)
             self.train_output_info.clear()
@@ -764,6 +774,7 @@ def export(
         )
         logger.message(f"ONNX model has been exported to: {export_path}.onnx")

+    @functools.lru_cache()
     def autocast_context_manager(
         self, enable: bool, level: Literal["O0", "O1", "O2", "OD"] = "O1"
     ) -> contextlib.AbstractContextManager:
@@ -786,6 +797,7 @@ def autocast_context_manager(
         )
         return ctx_manager

+    @functools.lru_cache()
     def no_grad_context_manager(
         self, enable: bool
     ) -> contextlib.AbstractContextManager:
@@ -807,6 +819,7 @@ def no_grad_context_manager(
         )
         return ctx_manager

+    @functools.lru_cache()
     def no_sync_context_manager(
         self,
         enable: bool,
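The `@functools.lru_cache()` decorators above are what the commit message calls "add cache for 3 context managers of Solver": these factory methods are invoked on every training iteration, and caching them avoids rebuilding the same context object each step. Below is a minimal, hypothetical sketch of the effect (the `Toy` class is illustrative only, not PaddleScience code):

``` py
import contextlib
import functools


class Toy:
    """Stand-in for Solver, only to show the caching behaviour."""

    @functools.lru_cache()
    def no_grad_context_manager(self, enable: bool) -> contextlib.AbstractContextManager:
        # The real method would return paddle.no_grad() or a null context
        # depending on `enable`; a null context suffices to demonstrate the cache.
        return contextlib.nullcontext()


t = Toy()
# Repeated calls with the same arguments now return the same cached object
# instead of constructing a new context manager on every iteration.
assert t.no_grad_context_manager(True) is t.no_grad_context_manager(True)
```

Note that `lru_cache` on an instance method keys the cache on `self` as well as the arguments, so each Solver instance keeps its own cached context managers.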

ppsci/solver/train.py

Lines changed: 45 additions & 0 deletions
@@ -14,11 +14,13 @@

 from __future__ import annotations

+import sys
 import time
 from typing import TYPE_CHECKING

 import paddle
 from paddle.distributed.fleet.utils import hybrid_parallel_util as hpu
+from paddle.framework import core

 from ppsci.solver import printer
 from ppsci.utils import misc
@@ -38,6 +40,11 @@ def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
     batch_tic = time.perf_counter()

     for iter_id in range(1, solver.iters_per_epoch + 1):
+        if solver.nvtx_flag:  # only for nsight analysis
+            core.nvprof_nvtx_push(
+                f"Training iteration {solver.global_step + 1}"
+            )  # Training iteration
+
         total_loss = 0.0
         total_batch_size = 0
         reader_cost = 0.0
@@ -77,6 +84,9 @@ def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
         # forward for every constraint, including model and equation expression
         with solver.no_sync_context_manager(solver.world_size > 1, solver.model):
             with solver.autocast_context_manager(solver.use_amp, solver.amp_level):
+                if solver.nvtx_flag:  # only for nsight analysis
+                    core.nvprof_nvtx_push("Loss computation")
+
                 constraint_losses = solver.forward_helper.train_forward(
                     tuple(
                         _constraint.output_expr
@@ -88,17 +98,31 @@ def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
                     label_dicts,
                     weight_dicts,
                 )
+
+                if solver.nvtx_flag:  # only for nsight analysis
+                    core.nvprof_nvtx_pop()  # Loss computation
+
                 # accumulate all losses
+                if solver.nvtx_flag:  # only for nsight analysis
+                    core.nvprof_nvtx_push("Loss aggregator")
+
                 for i, _constraint in enumerate(solver.constraint.values()):
                     total_loss += constraint_losses[i]
                     loss_dict[_constraint.name] += (
                         float(constraint_losses[i]) / solver.update_freq
                     )
                 if solver.update_freq > 1:
                     total_loss = total_loss / solver.update_freq
+
+                if solver.nvtx_flag:  # only for nsight analysis
+                    core.nvprof_nvtx_pop()  # Loss aggregator
+
                 loss_dict["loss"] = float(total_loss)

             # backward
+            if solver.nvtx_flag:  # only for nsight analysis
+                core.nvprof_nvtx_push("Loss backward")
+
             if solver.loss_aggregator is None:
                 if solver.use_amp:
                     total_loss_scaled = solver.scaler.scale(total_loss)
@@ -108,8 +132,14 @@ def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
             else:
                 solver.loss_aggregator(constraint_losses, solver.global_step).backward()

+            if solver.nvtx_flag:  # only for nsight analysis
+                core.nvprof_nvtx_pop()  # Loss backward
+
         # update parameters
         if iter_id % solver.update_freq == 0 or iter_id == solver.iters_per_epoch:
+            if solver.nvtx_flag:  # only for nsight analysis
+                core.nvprof_nvtx_push("Optimizer update")
+
             if solver.world_size > 1:
                 # fuse + allreduce manually before optimization if use DDP + no_sync
                 # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622
@@ -118,6 +148,10 @@ def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
                 solver.scaler.minimize(solver.optimizer, total_loss_scaled)
             else:
                 solver.optimizer.step()
+
+            if solver.nvtx_flag:  # only for nsight analysis
+                core.nvprof_nvtx_pop()  # Optimizer update
+
             solver.optimizer.clear_grad()

         # update learning rate by step
@@ -138,6 +172,17 @@ def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):

         batch_tic = time.perf_counter()

+        if solver.nvtx_flag:  # only for nsight analysis
+            core.nvprof_nvtx_pop()  # Training iteration
+            NVTX_STOP_ITER = 25
+            if solver.global_step >= NVTX_STOP_ITER:
+                print(
+                    f"Only run {NVTX_STOP_ITER} steps when 'NVTX' is set in environment"
+                    " for nsight analysis. Exit now ......\n"
+                )
+                core.nvprof_stop()
+                sys.exit(0)
+

 def train_LBFGS_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
     """Train function for one epoch with L-BFGS optimizer.
