Skip to content

Commit b94ed4b

Browse files
authored
Merge pull request #109 from Zeyi-Lin/main
docs: add swanlab visualization
2 parents 0e15a70 + 1654890 commit b94ed4b

11 files changed

+297
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ PyTorch是利用深度学习进行数据科学研究的重要工具,在灵活
6262
- 可视化CNN卷积层
6363
- 使用TensorBoard可视化训练过程
6464
- 使用wandb可视化训练过程
65+
- 使用SwanLab可视化训练过程
6566
- 第八章:PyTorch生态简介
6667
- 简介
6768
- 图像—torchvision
Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
# 7.5 使用SwanLab可视化训练过程
2+
3+
在上一节中,我们使用了Wandb可视化训练过程,但是Wandb将数据存储在海外,在国内的网络环境下访问速度较慢,且容易断连。SwanLab是一个由中国团队开发的训练可视化平台,国内访问稳定流畅,在功能上支持自动记录模型训练过程中的超参数和输出指标,然后可视化和比较结果,并快速与其他人共享结果。目前它还支持监控昇腾NPU的训练情况,能够和PyTorch、Keras、MMDetection、LLaMA Factory、LightGBM、XGBoost等框架结合使用。
4+
5+
经过本节的学习,你将收获:
6+
7+
- SwanLab的安装
8+
- SwanLab的基本使用
9+
- SwanLab跟踪MNIST案例
10+
- SwanLab跟踪YOLO案例
11+
12+
## 7.5.1 SwanLab的安装
13+
14+
SwanLab的安装非常简单,我们只需要使用pip安装即可。
15+
16+
```bash
17+
pip install swanlab
18+
```
19+
20+
安装完成后,我们需要在[官网](https://swanlab.cn/)注册一个账号并复制下自己的API keys,然后在本地使用下面的命令登录。
21+
22+
```bash
23+
swanlab login
24+
```
25+
26+
这时,我们会看到下面的界面,只需要粘贴你的API keys即可。
27+
28+
![](./figures/swanlab_login.png)
29+
30+
## 7.5.2 SwanLab的基本使用
31+
32+
SwanLab的使用也非常简单,只需要在代码中添加几行代码即可,大概分为两步。
33+
34+
**第一步,初始化项目:**
35+
36+
```python
37+
import swanlab
38+
39+
swanlab.init(project="my-project", experiment_name="first_exp")
40+
```
41+
42+
这里的project和experiment_name是你在swanlab上创建的项目名称和实验名。
43+
44+
项目和实验的关系有点类似PC中的文件夹和文件的关系,你的每次训练进程都是一个实验,而项目是实验的集合,用来进行多个实验之间的对比与管理。
45+
46+
**第二步,记录数据:**
47+
48+
```python
49+
for i in range(10):
50+
swanlab.log({"loss": 1-0.1*i, "acc": 0.1*i})
51+
```
52+
53+
这里的log是记录指标的函数,它接收一个字典,字典的key是指标的名称,value是指标的值。
54+
55+
56+
**Hello World代码**
57+
58+
```python
59+
import swanlab
60+
import random
61+
62+
# 创建一个SwanLab项目
63+
swanlab.init(
64+
# 设置项目名
65+
project="my-awesome-project",
66+
67+
# 设置超参数
68+
config={
69+
"learning_rate": 0.02,
70+
"architecture": "CNN",
71+
"dataset": "CIFAR-100",
72+
"epochs": 10
73+
}
74+
)
75+
76+
# 模拟一次训练
77+
epochs = 10
78+
offset = random.random() / 5
79+
for epoch in range(2, epochs):
80+
acc = 1 - 2 ** -epoch - random.random() / epoch - offset
81+
loss = 2 ** -epoch + random.random() / epoch + offset
82+
83+
# 记录训练指标
84+
swanlab.log({"acc": acc, "loss": loss})
85+
86+
# [可选] 完成训练,这在notebook环境中是必要的
87+
swanlab.finish()
88+
```
89+
90+
当我们运行完上面的代码后,就可以在swanlab的界面看到我们的训练结果了:
91+
92+
![swanlab hello world](./figures/swanlab_hello_world.png)
93+
94+
95+
## 7.5.3 SwanLab跟踪MNIST案例
96+
97+
下面我们使用一个MNSIT手写体识别的demo来演示SwanLab的使用。 [预览链接](https://swanlab.cn/@ZeyiLin/MNIST-example/runs/4plp6w0qehoqpt0uq2tcy/chart)
98+
99+
100+
```python
101+
import os
102+
import torch
103+
from torch import nn, optim, utils
104+
import torch.nn.functional as F
105+
import torchvision
106+
from torchvision.datasets import MNIST
107+
from torchvision.transforms import ToTensor
108+
import swanlab
109+
110+
# CNN网络构建
111+
class ConvNet(nn.Module):
112+
def __init__(self):
113+
super().__init__()
114+
# 1,28x28
115+
self.conv1 = nn.Conv2d(1, 10, 5) # 10, 24x24
116+
self.conv2 = nn.Conv2d(10, 20, 3) # 128, 10x10
117+
self.fc1 = nn.Linear(20 * 10 * 10, 500)
118+
self.fc2 = nn.Linear(500, 10)
119+
120+
def forward(self, x):
121+
in_size = x.size(0)
122+
out = self.conv1(x) # 24
123+
out = F.relu(out)
124+
out = F.max_pool2d(out, 2, 2) # 12
125+
out = self.conv2(out) # 10
126+
out = F.relu(out)
127+
out = out.view(in_size, -1)
128+
out = self.fc1(out)
129+
out = F.relu(out)
130+
out = self.fc2(out)
131+
out = F.log_softmax(out, dim=1)
132+
return out
133+
134+
135+
# 捕获并可视化前20张图像
136+
def log_images(loader, num_images=16):
137+
images_logged = 0
138+
logged_images = []
139+
for images, labels in loader:
140+
# images: batch of images, labels: batch of labels
141+
for i in range(images.shape[0]):
142+
if images_logged < num_images:
143+
# 使用swanlab.Image将图像转换为wandb可视化格式
144+
logged_images.append(swanlab.Image(images[i], caption=f"Label: {labels[i]}"))
145+
images_logged += 1
146+
else:
147+
break
148+
if images_logged >= num_images:
149+
break
150+
swanlab.log({"MNIST-Preview": logged_images})
151+
152+
153+
def train(model, device, train_dataloader, optimizer, criterion, epoch, num_epochs):
154+
model.train()
155+
# 1. 循环调用train_dataloader,每次取出1个batch_size的图像和标签
156+
for iter, (inputs, labels) in enumerate(train_dataloader):
157+
inputs, labels = inputs.to(device), labels.to(device)
158+
optimizer.zero_grad()
159+
# 2. 传入到resnet18模型中得到预测结果
160+
outputs = model(inputs)
161+
# 3. 将结果和标签传入损失函数中计算交叉熵损失
162+
loss = criterion(outputs, labels)
163+
# 4. 根据损失计算反向传播
164+
loss.backward()
165+
# 5. 优化器执行模型参数更新
166+
optimizer.step()
167+
print('Epoch [{}/{}], Iteration [{}/{}], Loss: {:.4f}'.format(epoch, num_epochs, iter + 1, len(train_dataloader),
168+
loss.item()))
169+
# 6. 每20次迭代,用SwanLab记录一下loss的变化
170+
if iter % 20 == 0:
171+
swanlab.log({"train/loss": loss.item()})
172+
173+
def test(model, device, val_dataloader, epoch):
174+
model.eval()
175+
correct = 0
176+
total = 0
177+
with torch.no_grad():
178+
# 1. 循环调用val_dataloader,每次取出1个batch_size的图像和标签
179+
for inputs, labels in val_dataloader:
180+
inputs, labels = inputs.to(device), labels.to(device)
181+
# 2. 传入到resnet18模型中得到预测结果
182+
outputs = model(inputs)
183+
# 3. 获得预测的数字
184+
_, predicted = torch.max(outputs, 1)
185+
total += labels.size(0)
186+
# 4. 计算与标签一致的预测结果的数量
187+
correct += (predicted == labels).sum().item()
188+
189+
# 5. 得到最终的测试准确率
190+
accuracy = correct / total
191+
# 6. 用SwanLab记录一下准确率的变化
192+
swanlab.log({"val/accuracy": accuracy}, step=epoch)
193+
194+
195+
if __name__ == "__main__":
196+
197+
#检测是否支持mps
198+
try:
199+
use_mps = torch.backends.mps.is_available()
200+
except AttributeError:
201+
use_mps = False
202+
203+
#检测是否支持cuda
204+
if torch.cuda.is_available():
205+
device = "cuda"
206+
elif use_mps:
207+
device = "mps"
208+
else:
209+
device = "cpu"
210+
211+
# 初始化swanlab
212+
run = swanlab.init(
213+
project="MNIST-example",
214+
experiment_name="PlainCNN",
215+
config={
216+
"model": "ResNet18",
217+
"optim": "Adam",
218+
"lr": 1e-4,
219+
"batch_size": 256,
220+
"num_epochs": 10,
221+
"device": device,
222+
},
223+
)
224+
225+
# 设置MNIST训练集和验证集
226+
dataset = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
227+
train_dataset, val_dataset = utils.data.random_split(dataset, [55000, 5000])
228+
229+
train_dataloader = utils.data.DataLoader(train_dataset, batch_size=run.config.batch_size, shuffle=True)
230+
val_dataloader = utils.data.DataLoader(val_dataset, batch_size=8, shuffle=False)
231+
232+
# (可选)看一下数据集的前16张图像
233+
log_images(train_dataloader, 16)
234+
235+
# 初始化模型
236+
model = ConvNet()
237+
model.to(torch.device(device))
238+
239+
# 打印模型
240+
print(model)
241+
242+
# 定义损失函数和优化器
243+
criterion = nn.CrossEntropyLoss()
244+
optimizer = optim.Adam(model.parameters(), lr=run.config.lr)
245+
246+
# 开始训练和测试循环
247+
for epoch in range(1, run.config.num_epochs+1):
248+
swanlab.log({"train/epoch": epoch}, step=epoch)
249+
train(model, device, train_dataloader, optimizer, criterion, epoch, run.config.num_epochs)
250+
if epoch % 2 == 0:
251+
test(model, device, val_dataloader, epoch)
252+
253+
# 保存模型
254+
# 如果不存在checkpoint文件夹,则自动创建一个
255+
if not os.path.exists("checkpoint"):
256+
os.makedirs("checkpoint")
257+
torch.save(model.state_dict(), 'checkpoint/latest_checkpoint.pth')
258+
```
259+
260+
运行代码后,我们查看实验结果:
261+
262+
![](./figures/swanlab_mnist_1.png)
263+
264+
![](./figures/swanlab_mnist_2.png)
265+
266+
## 7.5.4 SwanLab跟踪YOLO案例
267+
268+
下面我们使用一个Ultralytics框架训练Yolo模型的demo来演示SwanLab的使用。 [预览链接](https://swanlab.cn/@ZeyiLin/ultratest/runs/yux7vclmsmmsar9ear7u5/chart)
269+
270+
```python
271+
from ultralytics import YOLO
272+
from swanlab.integration.ultralytics import add_swanlab_callback
273+
274+
275+
if __name__ == "__main__":
276+
model = YOLO("yolov8n.yaml")
277+
model.load()
278+
# 添加swanlab回调
279+
add_swanlab_callback(model)
280+
281+
model.train(
282+
data="./coco128.yaml",
283+
epochs=3,
284+
imgsz=320,
285+
)
286+
```
287+
288+
![](./figures/swanlab_yolo_1.png)
289+
290+
![](./figures/swanlab_yolo_2.png)
291+
292+
![](./figures/swanlab_yolo_3.png)
293+
294+
我们可以发现,使用swanlab可以很方便地可视化训练过程和在线查看实验进展。更多功能请见[官方文档](https://docs.swanlab.cn/guide_cloud/general/what-is-swanlab.html)
295+
286 KB
Loading
151 KB
Loading
25.7 KB
Loading
125 KB
Loading
74.7 KB
Loading
991 KB
Loading
136 KB
Loading
238 KB
Loading

0 commit comments

Comments
 (0)