Skip to content

Commit a1a1bd6

Browse files
committed
doc: update docs/pytorch.md #649
1 parent 739b180 commit a1a1bd6

File tree

1 file changed

+79
-29
lines changed

1 file changed

+79
-29
lines changed

docs/pytorch.md

Lines changed: 79 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ Pytorch 是一种开源机器学习框架,可加速从研究原型设计到生
1313
- [Pytorch 官方备忘清单](https://pytorch.org/tutorials/beginner/ptcheat.html) _(pytorch.org)_
1414

1515
### 认识 Pytorch
16+
<!--rehype:wrap-class=row-span-2-->
1617

1718
```python
1819
from __future__ import print_function
@@ -32,6 +33,7 @@ tensor([
3233
Tensors 张量: 张量的概念类似于Numpy中的ndarray数据结构, 最大的区别在于Tensor可以利用GPU的加速功能.
3334

3435
### 创建一个全零矩阵
36+
<!--rehype:wrap-class=row-span-2-->
3537

3638
```python
3739
x = torch.zeros(5, 3, dtype=torch.long)
@@ -95,6 +97,7 @@ tensor([[ 1.6978, -1.6979, 0.3093],
9597
```
9698

9799
### 加法操作(4)
100+
<!--rehype:wrap-class=row-span-2-->
98101

99102
```python
100103
y.add_(x)
@@ -118,6 +121,7 @@ tensor([-2.0902, -0.4489, -0.1441, 0.8035, -0.8341])
118121
<!--rehype:className=wrap-text-->
119122

120123
### 张量形状
124+
<!--rehype:wrap-class=row-span-2-->
121125

122126
```python
123127
x = torch.randn(4, 4)
@@ -178,61 +182,81 @@ tensor([2., 2., 2., 2., 2.], dtype=torch.float64)
178182

179183
```python
180184
>>> x = torch.rand(1, 2, 1, 28, 1)
181-
>>> x.squeeze().shape # squeeze不加参数,默认去除所有为1的维度
185+
186+
# squeeze不加参数,默认去除所有为1的维度
187+
>>> x.squeeze().shape
182188
torch.Size([2, 28])
183-
>>> x.squeeze(dim=0).shape # squeeze加参数,去除指定为1的维度
189+
190+
# squeeze加参数,去除指定为1的维度
191+
>>> x.squeeze(dim=0).shape
184192
torch.Size([2, 1, 28, 1])
185-
>>> x.squeeze(1).shape # squeeze加参数,如果不为1,则不变
193+
194+
# squeeze加参数,如果不为1,则不变
195+
>>> x.squeeze(1).shape
186196
torch.Size([1, 2, 1, 28, 1])
187-
>>> torch.squeeze(x,-1).shape # 既可以是函数,也可以是方法
197+
198+
# 既可以是函数,也可以是方法
199+
>>> torch.squeeze(x,-1).shape
188200
torch.Size([1, 2, 1, 28])
189201
```
190202

191203
### unsqueeze函数
192204

193205
```python
194206
>>> x = torch.rand(2, 28)
195-
>>> x.unsqueeze(0).shape # unsqueeze必须加参数, _ 2 _ 28 _
196-
torch.Size([1, 2, 28]) # 参数代表在哪里添加维度 0 1 2
197-
>>> torch.unsqueeze(x, -1).shape # 既可以是函数,也可以是方法
207+
# unsqueeze必须加参数, _ 2 _ 28 _
208+
>>> x.unsqueeze(0).shape
209+
# 参数代表在哪里添加维度 0 1 2
210+
torch.Size([1, 2, 28])
211+
# 既可以是函数,也可以是方法
212+
>>> torch.unsqueeze(x, -1).shape
198213
torch.Size([2, 28, 1])
199214
```
200215

201216
Cuda 相关
202217
---
218+
203219
### 检查 Cuda 是否可用
220+
204221
```python
205222
>>> import torch.cuda
206223
>>> torch.cuda.is_available()
207224
>>> True
208225
```
226+
209227
### 列出 GPU 设备
228+
<!--rehype:wrap-class=col-span-2 row-span-2-->
229+
210230
```python
211231
import torch
232+
212233
device_count = torch.cuda.device_count()
213234
print("CUDA 设备")
235+
214236
for i in range(device_count):
215237
device_name = torch.cuda.get_device_name(i)
216238
total_memory = torch.cuda.get_device_properties(i).total_memory / (1024 ** 3)
217239
print(f"├── 设备 {i}: {device_name}, 容量: {total_memory:.2f} GiB")
240+
218241
print("└── (结束)")
219242
```
243+
220244
### 将模型、张量等数据在 GPU 和内存之间进行搬运
245+
221246
```python
222247
import torch
223-
# Replace 0 to your GPU device index. or use "cuda" directly.
248+
# 0 替换为您的 GPU 设备索引或者直接使用 "cuda"
224249
device = f"cuda:0"
225-
# Move to GPU
250+
# 移动到GPU
226251
tensor_m = torch.tensor([1, 2, 3])
227252
tensor_g = tensor_m.to(device)
228253
model_m = torch.nn.Linear(1, 1)
229254
model_g = model_m.to(device)
230-
# Move back.
255+
# 向后移动
231256
tensor_m = tensor_g.cpu()
232257
model_m = model_g.cpu()
233258
```
234259

235-
236260
导入 Imports
237261
---
238262

@@ -241,61 +265,87 @@ model_m = model_g.cpu()
241265
```python
242266
# 根包
243267
import torch
244-
# 数据集表示和加载
268+
```
269+
270+
数据集表示和加载
271+
272+
```python
245273
from torch.utils.data import Dataset, DataLoader
246274
```
247275
<!--rehype:className=wrap-text-->
248276

249277
### 神经网络 API
278+
<!--rehype:wrap-class=row-span-2-->
250279

251280
```python
252281
# 计算图
253282
import torch.autograd as autograd
254283
# 计算图中的张量节点
255284
from torch import Tensor
256-
# 神经网络
285+
```
286+
287+
神经网络
288+
289+
```python
257290
import torch.nn as nn
291+
258292
# 层、激活等
259293
import torch.nn.functional as F
260294
# 优化器,例如 梯度下降、ADAM等
261295
import torch.optim as optim
262-
# 混合前端装饰器和跟踪 jit
263-
from torch.jit import script, trace
264296
```
265297

266-
### Torchscript 和 JIT
267-
268-
```python
269-
torch.jit.trace()
270-
```
271-
272-
使用你的模块或函数和一个例子,数据输入,并追溯计算步骤,数据在模型中前进时遇到的情况
298+
混合前端装饰器和跟踪 jit
273299

274300
```python
275-
@script
301+
from torch.jit import script, trace
276302
```
277303

278-
装饰器用于指示被跟踪代码中的数据相关控制流
279-
280304
### ONNX
305+
<!--rehype:wrap-class=row-span-2-->
281306

282307
```python
283308
torch.onnx.export(model, dummy data, xxxx.proto)
284309
# 导出 ONNX 格式
285310
# 使用经过训练的模型模型,dummy
286311
# 数据和所需的文件名
312+
```
313+
<!--rehype:className=wrap-text-->
287314

315+
加载 ONNX 模型
316+
317+
```python
288318
model = onnx.load("alexnet.proto")
289-
# 加载 ONNX 模型
319+
```
320+
321+
检查模型,IT 是否结构良好
322+
323+
```python
290324
onnx.checker.check_model(model)
291-
# 检查模型,IT 是否结构良好
325+
```
326+
327+
打印一个人类可读的,图的表示
292328

329+
```python
293330
onnx.helper.printable_graph(model.graph)
294-
# 打印一个人类可读的,图的表示
295331
```
296-
<!--rehype:className=wrap-text-->
332+
333+
### Torchscript 和 JIT
334+
335+
```python
336+
torch.jit.trace()
337+
```
338+
339+
使用你的模块或函数和一个例子,数据输入,并追溯计算步骤,数据在模型中前进时遇到的情况
340+
341+
```python
342+
@script
343+
```
344+
345+
装饰器用于指示被跟踪代码中的数据相关控制流
297346

298347
### Vision
348+
<!--rehype:wrap-class=col-span-2-->
299349

300350
```python
301351
# 视觉数据集,架构 & 变换

0 commit comments

Comments
 (0)