|
| 1 | +# 7.5 使用SwanLab可视化训练过程 |
| 2 | + |
| 3 | +在上一节中,我们使用了Wandb可视化训练过程,但是Wandb将数据存储在海外,在国内的网络环境下访问速度较慢,且容易断连。SwanLab是一个由中国团队开发的训练可视化平台,国内访问稳定流畅,在功能上支持自动记录模型训练过程中的超参数和输出指标,然后可视化和比较结果,并快速与其他人共享结果。目前它还支持监控昇腾NPU的训练情况,能够和PyTorch、Keras、MMDetection、LLaMA Factory、LightGBM、XGBoost等框架结合使用。 |
| 4 | + |
| 5 | +经过本节的学习,你将收获: |
| 6 | + |
| 7 | +- SwanLab的安装 |
| 8 | +- SwanLab的基本使用 |
| 9 | +- SwanLab跟踪MNIST案例 |
| 10 | +- SwanLab跟踪YOLO案例 |
| 11 | + |
| 12 | +## 7.5.1 SwanLab的安装 |
| 13 | + |
| 14 | +SwanLab的安装非常简单,我们只需要使用pip安装即可。 |
| 15 | + |
| 16 | +```bash |
| 17 | +pip install swanlab |
| 18 | +``` |
| 19 | + |
| 20 | +安装完成后,我们需要在[官网](https://swanlab.cn/)注册一个账号并复制下自己的API keys,然后在本地使用下面的命令登录。 |
| 21 | + |
| 22 | +```bash |
| 23 | +swanlab login |
| 24 | +``` |
| 25 | + |
| 26 | +这时,我们会看到下面的界面,只需要粘贴你的API keys即可。 |
| 27 | + |
| 28 | + |
| 29 | + |
| 30 | +## 7.5.2 SwanLab的基本使用 |
| 31 | + |
| 32 | +SwanLab的使用也非常简单,只需要在代码中添加几行代码即可,大概分为两步。 |
| 33 | + |
| 34 | +**第一步,初始化项目:** |
| 35 | + |
| 36 | +```python |
| 37 | +import swanlab |
| 38 | + |
| 39 | +swanlab.init(project="my-project", experiment_name="first_exp") |
| 40 | +``` |
| 41 | + |
| 42 | +这里的project和experiment_name是你在swanlab上创建的项目名称和实验名。 |
| 43 | + |
| 44 | +项目和实验的关系有点类似PC中的文件夹和文件的关系,你的每次训练进程都是一个实验,而项目是实验的集合,用来进行多个实验之间的对比与管理。 |
| 45 | + |
| 46 | +**第二步,记录数据:** |
| 47 | + |
| 48 | +```python |
| 49 | +for i in range(10): |
| 50 | + swanlab.log({"loss": 1-0.1*i, "acc": 0.1*i}) |
| 51 | +``` |
| 52 | + |
| 53 | +这里的log是记录指标的函数,它接收一个字典,字典的key是指标的名称,value是指标的值。 |
| 54 | + |
| 55 | + |
| 56 | +**Hello World代码** |
| 57 | + |
| 58 | +```python |
| 59 | +import swanlab |
| 60 | +import random |
| 61 | + |
| 62 | +# 创建一个SwanLab项目 |
| 63 | +swanlab.init( |
| 64 | + # 设置项目名 |
| 65 | + project="my-awesome-project", |
| 66 | + |
| 67 | + # 设置超参数 |
| 68 | + config={ |
| 69 | + "learning_rate": 0.02, |
| 70 | + "architecture": "CNN", |
| 71 | + "dataset": "CIFAR-100", |
| 72 | + "epochs": 10 |
| 73 | + } |
| 74 | +) |
| 75 | + |
| 76 | +# 模拟一次训练 |
| 77 | +epochs = 10 |
| 78 | +offset = random.random() / 5 |
| 79 | +for epoch in range(2, epochs): |
| 80 | + acc = 1 - 2 ** -epoch - random.random() / epoch - offset |
| 81 | + loss = 2 ** -epoch + random.random() / epoch + offset |
| 82 | + |
| 83 | + # 记录训练指标 |
| 84 | + swanlab.log({"acc": acc, "loss": loss}) |
| 85 | + |
| 86 | +# [可选] 完成训练,这在notebook环境中是必要的 |
| 87 | +swanlab.finish() |
| 88 | +``` |
| 89 | + |
| 90 | +当我们运行完上面的代码后,就可以在swanlab的界面看到我们的训练结果了: |
| 91 | + |
| 92 | + |
| 93 | + |
| 94 | + |
| 95 | +## 7.5.3 SwanLab跟踪MNIST案例 |
| 96 | + |
| 97 | +下面我们使用一个MNSIT手写体识别的demo来演示SwanLab的使用。 [预览链接](https://swanlab.cn/@ZeyiLin/MNIST-example/runs/4plp6w0qehoqpt0uq2tcy/chart)。 |
| 98 | + |
| 99 | + |
| 100 | +```python |
| 101 | +import os |
| 102 | +import torch |
| 103 | +from torch import nn, optim, utils |
| 104 | +import torch.nn.functional as F |
| 105 | +import torchvision |
| 106 | +from torchvision.datasets import MNIST |
| 107 | +from torchvision.transforms import ToTensor |
| 108 | +import swanlab |
| 109 | + |
| 110 | +# CNN网络构建 |
| 111 | +class ConvNet(nn.Module): |
| 112 | + def __init__(self): |
| 113 | + super().__init__() |
| 114 | + # 1,28x28 |
| 115 | + self.conv1 = nn.Conv2d(1, 10, 5) # 10, 24x24 |
| 116 | + self.conv2 = nn.Conv2d(10, 20, 3) # 128, 10x10 |
| 117 | + self.fc1 = nn.Linear(20 * 10 * 10, 500) |
| 118 | + self.fc2 = nn.Linear(500, 10) |
| 119 | + |
| 120 | + def forward(self, x): |
| 121 | + in_size = x.size(0) |
| 122 | + out = self.conv1(x) # 24 |
| 123 | + out = F.relu(out) |
| 124 | + out = F.max_pool2d(out, 2, 2) # 12 |
| 125 | + out = self.conv2(out) # 10 |
| 126 | + out = F.relu(out) |
| 127 | + out = out.view(in_size, -1) |
| 128 | + out = self.fc1(out) |
| 129 | + out = F.relu(out) |
| 130 | + out = self.fc2(out) |
| 131 | + out = F.log_softmax(out, dim=1) |
| 132 | + return out |
| 133 | + |
| 134 | + |
| 135 | +# 捕获并可视化前20张图像 |
| 136 | +def log_images(loader, num_images=16): |
| 137 | + images_logged = 0 |
| 138 | + logged_images = [] |
| 139 | + for images, labels in loader: |
| 140 | + # images: batch of images, labels: batch of labels |
| 141 | + for i in range(images.shape[0]): |
| 142 | + if images_logged < num_images: |
| 143 | + # 使用swanlab.Image将图像转换为wandb可视化格式 |
| 144 | + logged_images.append(swanlab.Image(images[i], caption=f"Label: {labels[i]}")) |
| 145 | + images_logged += 1 |
| 146 | + else: |
| 147 | + break |
| 148 | + if images_logged >= num_images: |
| 149 | + break |
| 150 | + swanlab.log({"MNIST-Preview": logged_images}) |
| 151 | + |
| 152 | + |
| 153 | +def train(model, device, train_dataloader, optimizer, criterion, epoch, num_epochs): |
| 154 | + model.train() |
| 155 | + # 1. 循环调用train_dataloader,每次取出1个batch_size的图像和标签 |
| 156 | + for iter, (inputs, labels) in enumerate(train_dataloader): |
| 157 | + inputs, labels = inputs.to(device), labels.to(device) |
| 158 | + optimizer.zero_grad() |
| 159 | + # 2. 传入到resnet18模型中得到预测结果 |
| 160 | + outputs = model(inputs) |
| 161 | + # 3. 将结果和标签传入损失函数中计算交叉熵损失 |
| 162 | + loss = criterion(outputs, labels) |
| 163 | + # 4. 根据损失计算反向传播 |
| 164 | + loss.backward() |
| 165 | + # 5. 优化器执行模型参数更新 |
| 166 | + optimizer.step() |
| 167 | + print('Epoch [{}/{}], Iteration [{}/{}], Loss: {:.4f}'.format(epoch, num_epochs, iter + 1, len(train_dataloader), |
| 168 | + loss.item())) |
| 169 | + # 6. 每20次迭代,用SwanLab记录一下loss的变化 |
| 170 | + if iter % 20 == 0: |
| 171 | + swanlab.log({"train/loss": loss.item()}) |
| 172 | + |
| 173 | +def test(model, device, val_dataloader, epoch): |
| 174 | + model.eval() |
| 175 | + correct = 0 |
| 176 | + total = 0 |
| 177 | + with torch.no_grad(): |
| 178 | + # 1. 循环调用val_dataloader,每次取出1个batch_size的图像和标签 |
| 179 | + for inputs, labels in val_dataloader: |
| 180 | + inputs, labels = inputs.to(device), labels.to(device) |
| 181 | + # 2. 传入到resnet18模型中得到预测结果 |
| 182 | + outputs = model(inputs) |
| 183 | + # 3. 获得预测的数字 |
| 184 | + _, predicted = torch.max(outputs, 1) |
| 185 | + total += labels.size(0) |
| 186 | + # 4. 计算与标签一致的预测结果的数量 |
| 187 | + correct += (predicted == labels).sum().item() |
| 188 | + |
| 189 | + # 5. 得到最终的测试准确率 |
| 190 | + accuracy = correct / total |
| 191 | + # 6. 用SwanLab记录一下准确率的变化 |
| 192 | + swanlab.log({"val/accuracy": accuracy}, step=epoch) |
| 193 | + |
| 194 | + |
| 195 | +if __name__ == "__main__": |
| 196 | + |
| 197 | + #检测是否支持mps |
| 198 | + try: |
| 199 | + use_mps = torch.backends.mps.is_available() |
| 200 | + except AttributeError: |
| 201 | + use_mps = False |
| 202 | + |
| 203 | + #检测是否支持cuda |
| 204 | + if torch.cuda.is_available(): |
| 205 | + device = "cuda" |
| 206 | + elif use_mps: |
| 207 | + device = "mps" |
| 208 | + else: |
| 209 | + device = "cpu" |
| 210 | + |
| 211 | + # 初始化swanlab |
| 212 | + run = swanlab.init( |
| 213 | + project="MNIST-example", |
| 214 | + experiment_name="PlainCNN", |
| 215 | + config={ |
| 216 | + "model": "ResNet18", |
| 217 | + "optim": "Adam", |
| 218 | + "lr": 1e-4, |
| 219 | + "batch_size": 256, |
| 220 | + "num_epochs": 10, |
| 221 | + "device": device, |
| 222 | + }, |
| 223 | + ) |
| 224 | + |
| 225 | + # 设置MNIST训练集和验证集 |
| 226 | + dataset = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()) |
| 227 | + train_dataset, val_dataset = utils.data.random_split(dataset, [55000, 5000]) |
| 228 | + |
| 229 | + train_dataloader = utils.data.DataLoader(train_dataset, batch_size=run.config.batch_size, shuffle=True) |
| 230 | + val_dataloader = utils.data.DataLoader(val_dataset, batch_size=8, shuffle=False) |
| 231 | + |
| 232 | + # (可选)看一下数据集的前16张图像 |
| 233 | + log_images(train_dataloader, 16) |
| 234 | + |
| 235 | + # 初始化模型 |
| 236 | + model = ConvNet() |
| 237 | + model.to(torch.device(device)) |
| 238 | + |
| 239 | + # 打印模型 |
| 240 | + print(model) |
| 241 | + |
| 242 | + # 定义损失函数和优化器 |
| 243 | + criterion = nn.CrossEntropyLoss() |
| 244 | + optimizer = optim.Adam(model.parameters(), lr=run.config.lr) |
| 245 | + |
| 246 | + # 开始训练和测试循环 |
| 247 | + for epoch in range(1, run.config.num_epochs+1): |
| 248 | + swanlab.log({"train/epoch": epoch}, step=epoch) |
| 249 | + train(model, device, train_dataloader, optimizer, criterion, epoch, run.config.num_epochs) |
| 250 | + if epoch % 2 == 0: |
| 251 | + test(model, device, val_dataloader, epoch) |
| 252 | + |
| 253 | + # 保存模型 |
| 254 | + # 如果不存在checkpoint文件夹,则自动创建一个 |
| 255 | + if not os.path.exists("checkpoint"): |
| 256 | + os.makedirs("checkpoint") |
| 257 | + torch.save(model.state_dict(), 'checkpoint/latest_checkpoint.pth') |
| 258 | +``` |
| 259 | + |
| 260 | +运行代码后,我们查看实验结果: |
| 261 | + |
| 262 | + |
| 263 | + |
| 264 | + |
| 265 | + |
| 266 | +## 7.5.4 SwanLab跟踪YOLO案例 |
| 267 | + |
| 268 | +下面我们使用一个Ultralytics框架训练Yolo模型的demo来演示SwanLab的使用。 [预览链接](https://swanlab.cn/@ZeyiLin/ultratest/runs/yux7vclmsmmsar9ear7u5/chart)。 |
| 269 | + |
| 270 | +```python |
| 271 | +from ultralytics import YOLO |
| 272 | +from swanlab.integration.ultralytics import add_swanlab_callback |
| 273 | + |
| 274 | + |
| 275 | +if __name__ == "__main__": |
| 276 | + model = YOLO("yolov8n.yaml") |
| 277 | + model.load() |
| 278 | + # 添加swanlab回调 |
| 279 | + add_swanlab_callback(model) |
| 280 | + |
| 281 | + model.train( |
| 282 | + data="./coco128.yaml", |
| 283 | + epochs=3, |
| 284 | + imgsz=320, |
| 285 | + ) |
| 286 | +``` |
| 287 | + |
| 288 | + |
| 289 | + |
| 290 | + |
| 291 | + |
| 292 | + |
| 293 | + |
| 294 | +我们可以发现,使用swanlab可以很方便地可视化训练过程和在线查看实验进展。更多功能请见[官方文档](https://docs.swanlab.cn/guide_cloud/general/what-is-swanlab.html)。 |
| 295 | + |
0 commit comments