Skip to content

Commit 7cf9bf8

Browse files
committed
Pytorch CIFAR10图像分类
1 parent 5091a9c commit 7cf9bf8

8 files changed

+6392
-0
lines changed

CIFAR-10/Pytorch CIFAR10图像分类 AlexNet篇.md

Lines changed: 557 additions & 0 deletions
Large diffs are not rendered by default.

CIFAR-10/Pytorch CIFAR10图像分类 DenseNet篇.md

Lines changed: 1685 additions & 0 deletions
Large diffs are not rendered by default.

CIFAR-10/Pytorch CIFAR10图像分类 GoogLeNet篇.md

Lines changed: 1461 additions & 0 deletions
Large diffs are not rendered by default.

CIFAR-10/Pytorch CIFAR10图像分类 LeNet篇.md

Lines changed: 491 additions & 0 deletions
Large diffs are not rendered by default.

CIFAR-10/Pytorch CIFAR10图像分类 ResNet篇.md

Lines changed: 879 additions & 0 deletions
Large diffs are not rendered by default.

CIFAR-10/Pytorch CIFAR10图像分类 VGG篇.md

Lines changed: 622 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
# Pytorch CIFAR10图像分类 数据加载与可视化篇
2+
3+
[toc]
4+
5+
这里贴一下汇总篇:[汇总篇](https://blog.csdn.net/weixin_45508265/article/details/119285255)
6+
7+
**Pytorch一般有以下几个流程**
8+
9+
1. 数据读取
10+
2. 数据处理
11+
3. 搭建网络
12+
4. 模型训练
13+
5. 模型上线
14+
15+
这里会先讲一下关于CIFAR10的数据加载和图片可视化,之后的模型篇会对网络进行介绍和实现。
16+
17+
### 1.数据读取
18+
19+
CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含 10 个类别的 RGB 彩色图 片:飞机( arplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。图片的尺寸为 32×32 ,数据集中一共有 50000 张训练圄片和 10000 张测试图片。
20+
21+
与 MNIST 数据集中目比, CIFAR-10 具有以下不同点:
22+
23+
- CIFAR-10 是 3 通道的彩色 RGB 图像,而 MNIST 是灰度图像。
24+
- CIFAR-10 的图片尺寸为 32×32, 而 MNIST 的图片尺寸为 28×28,比 MNIST 稍大。
25+
- 相比于手写字符, CIFAR-10 含有的是现实世界中真实的物体,不仅噪声很大,而且物体的比例、 特征都不尽相同,这为识别带来很大困难。
26+
27+
![在这里插入图片描述](https://img-blog.csdnimg.cn/16f85f24a70e452e8659a1874616420f.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NTUwODI2NQ==,size_16,color_FFFFFF,t_70#pic_center)
28+
29+
30+
31+
首先使用`torchvision`加载和归一化我们的训练数据和测试数据。
32+
33+
a、`torchvision`这个东西,实现了常用的一些深度学习的相关的图像数据的加载功能,比如cifar10、Imagenet、Mnist等等的,保存在`torchvision.datasets`模块中。
34+
35+
b、同时,也封装了一些处理数据的方法。保存在`torchvision.transforms`模块中
36+
37+
c、还封装了一些模型和工具封装在相应模型中,比如`torchvision.models`当中就包含了AlexNet,VGG,ResNet,SqueezeNet等模型。
38+
39+
40+
41+
**由于torchvision的datasets的输出是[0,1]的PILImage,所以我们先先归一化为[-1,1]的Tensor**
42+
43+
首先定义了一个变换transform,利用的是上面提到的transforms模块中的Compose( )把多个变换组合在一起,可以看到这里面组合了ToTensor和Normalize这两个变换
44+
45+
`transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`前面的(0.5,0.5,0.5) 是 R G B 三个通道上的均值, 后面(0.5, 0.5, 0.5)是三个通道的标准差,注意通道顺序是 R G B ,用过opencv的同学应该知道openCV读出来的图像是 BRG顺序。这两个tuple数据是用来对RGB 图像做归一化的,如其名称 Normalize 所示这里都取0.5只是一个近似的操作,实际上其均值和方差并不是这么多,但是就这个示例而言 影响可不计。精确值是通过分别计算R,G,B三个通道的数据算出来的。
46+
47+
```python
48+
transform = transforms.Compose([
49+
# transforms.CenterCrop(224),
50+
transforms.RandomCrop(32,padding=4), # 数据增广
51+
transforms.RandomHorizontalFlip(), # 数据增广
52+
transforms.ToTensor(),
53+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
54+
])
55+
```
56+
57+
`trainloader`其实是一个比较重要的东西,我们后面就是通过`trainloader`把数据传入网络,当然这里的`trainloader`其实是个变量名,可以随便取,重点是他是由后面的`torch.utils.data.DataLoader()`定义的,这个东西来源于`torch.utils.data`模块
58+
59+
```python
60+
Batch_Size = 256
61+
```
62+
63+
```python
64+
trainset = datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
65+
testset = datasets.CIFAR10(root='./data',train=False,download=True,transform=transform)
66+
trainloader = torch.utils.data.DataLoader(trainset, batch_size=Batch_Size,shuffle=True, num_workers=2)
67+
testloader = torch.utils.data.DataLoader(testset, batch_size=Batch_Size,shuffle=True, num_workers=2)
68+
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
69+
```
70+
71+
> ```python
72+
> Files already downloaded and verified
73+
> Files already downloaded and verified
74+
> ```
75+
76+
### 2. 查看数据(格式,大小,形状)
77+
78+
首先可以查看类别
79+
80+
```python
81+
classes = trainset.classes
82+
classes
83+
```
84+
85+
> ```python
86+
> ['airplane',
87+
> 'automobile',
88+
> 'bird',
89+
> 'cat',
90+
> 'deer',
91+
> 'dog',
92+
> 'frog',
93+
> 'horse',
94+
> 'ship',
95+
> 'truck']
96+
> ```
97+
98+
```python
99+
trainset.class_to_idx
100+
```
101+
102+
> ```python
103+
> {'airplane': 0,
104+
> 'automobile': 1,
105+
> 'bird': 2,
106+
> 'cat': 3,
107+
> 'deer': 4,
108+
> 'dog': 5,
109+
> 'frog': 6,
110+
> 'horse': 7,
111+
> 'ship': 8,
112+
> 'truck': 9}
113+
> ```
114+
115+
也可以查看一下训练集的数据
116+
117+
```python
118+
trainset.data.shape #50000是图片数量,32x32是图片大小,3是通道数量RGB
119+
```
120+
121+
> ```python
122+
> (50000, 32, 32, 3)
123+
> ```
124+
125+
查看数据类型
126+
127+
```python
128+
#查看数据类型
129+
print(type(trainset.data))
130+
print(type(trainset))
131+
```
132+
133+
> ```python
134+
> <class 'numpy.ndarray'>
135+
> <class 'torchvision.datasets.cifar.CIFAR10'>
136+
> ```
137+
>
138+
>
139+
140+
**总结:**
141+
142+
`trainset.data.shape`是标准的numpy.ndarray类型,其中50000是图片数量,32x32是图片大小,3是通道数量RGB
143+
`trainset`是标准的??类型,其中50000为图片数量,0表示取前面的数据,2表示3通道数RGB32*32表示图片大小
144+
145+
### 3. 查看图片
146+
147+
接下来我们对图片进行可视化
148+
149+
```python
150+
import numpy as np
151+
import matplotlib.pyplot as plt
152+
plt.imshow(trainset.data[0])
153+
im,label = iter(trainloader).next()
154+
```
155+
156+
![在这里插入图片描述](https://img-blog.csdnimg.cn/d2733fe9714446caa0f6ff0d8501adcd.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NTUwODI2NQ==,size_16,color_FFFFFF,t_70)
157+
158+
#### np.ndarray转为torch.Tensor
159+
160+
在深度学习中,原始图像需要转换为深度学习框架自定义的数据格式,在pytorch中,需要转为`torch.Tensor`
161+
pytorch提供了`torch.Tensor``numpy.ndarray`转换为接口:
162+
163+
| 方法名 | 作用 |
164+
| ----------------------- | ------------------------------- |
165+
| `torch.from_numpy(xxx)` | `numpy.ndarray`转为torch.Tensor |
166+
| `tensor1.numpy()` | 获取tensor1对象的numpy格式数据 |
167+
168+
`torch.Tensor` 高维矩阵的表示: N x C x H x W
169+
170+
`numpy.ndarray` 高维矩阵的表示:N x H x W x C
171+
172+
因此在两者转换的时候需要使用`numpy.transpose( )` 方法 。
173+
174+
```python
175+
def imshow(img):
176+
img = img / 2 + 0.5
177+
img = np.transpose(img.numpy(),(1,2,0))
178+
plt.imshow(img)
179+
```
180+
181+
```python
182+
imshow(im[0])
183+
```
184+
185+
![在这里插入图片描述](https://img-blog.csdnimg.cn/bc61a5f1af4b45f696e2a6d8bfc7b223.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NTUwODI2NQ==,size_16,color_FFFFFF,t_70)
186+
187+
我们也可以批量可视化图片,不过这里需要用到`make_grid`
188+
189+
```python
190+
plt.figure(figsize=(8,12))
191+
imshow(torchvision.utils.make_grid(im[:32]))
192+
```
193+
194+
![在这里插入图片描述](https://img-blog.csdnimg.cn/b9beb6f0fe4b4fcb9ef26f300bbb242b.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NTUwODI2NQ==,size_16,color_FFFFFF,t_70)

0 commit comments

Comments
 (0)