|
| 1 | +# Pytorch CIFAR10图像分类 数据加载与可视化篇 |
| 2 | + |
| 3 | +[toc] |
| 4 | + |
| 5 | +这里贴一下汇总篇:[汇总篇](https://blog.csdn.net/weixin_45508265/article/details/119285255) |
| 6 | + |
| 7 | +**Pytorch一般有以下几个流程** |
| 8 | + |
| 9 | +1. 数据读取 |
| 10 | +2. 数据处理 |
| 11 | +3. 搭建网络 |
| 12 | +4. 模型训练 |
| 13 | +5. 模型上线 |
| 14 | + |
| 15 | +这里会先讲一下关于CIFAR10的数据加载和图片可视化,之后的模型篇会对网络进行介绍和实现。 |
| 16 | + |
| 17 | +### 1.数据读取 |
| 18 | + |
| 19 | +CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含 10 个类别的 RGB 彩色图 片:飞机( arplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。图片的尺寸为 32×32 ,数据集中一共有 50000 张训练圄片和 10000 张测试图片。 |
| 20 | + |
| 21 | +与 MNIST 数据集中目比, CIFAR-10 具有以下不同点: |
| 22 | + |
| 23 | +- CIFAR-10 是 3 通道的彩色 RGB 图像,而 MNIST 是灰度图像。 |
| 24 | +- CIFAR-10 的图片尺寸为 32×32, 而 MNIST 的图片尺寸为 28×28,比 MNIST 稍大。 |
| 25 | +- 相比于手写字符, CIFAR-10 含有的是现实世界中真实的物体,不仅噪声很大,而且物体的比例、 特征都不尽相同,这为识别带来很大困难。 |
| 26 | + |
| 27 | + |
| 28 | + |
| 29 | + |
| 30 | + |
| 31 | +首先使用`torchvision`加载和归一化我们的训练数据和测试数据。 |
| 32 | + |
| 33 | +a、`torchvision`这个东西,实现了常用的一些深度学习的相关的图像数据的加载功能,比如cifar10、Imagenet、Mnist等等的,保存在`torchvision.datasets`模块中。 |
| 34 | + |
| 35 | +b、同时,也封装了一些处理数据的方法。保存在`torchvision.transforms`模块中 |
| 36 | + |
| 37 | +c、还封装了一些模型和工具封装在相应模型中,比如`torchvision.models`当中就包含了AlexNet,VGG,ResNet,SqueezeNet等模型。 |
| 38 | + |
| 39 | + |
| 40 | + |
| 41 | +**由于torchvision的datasets的输出是[0,1]的PILImage,所以我们先先归一化为[-1,1]的Tensor** |
| 42 | + |
| 43 | +首先定义了一个变换transform,利用的是上面提到的transforms模块中的Compose( )把多个变换组合在一起,可以看到这里面组合了ToTensor和Normalize这两个变换 |
| 44 | + |
| 45 | +`transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`前面的(0.5,0.5,0.5) 是 R G B 三个通道上的均值, 后面(0.5, 0.5, 0.5)是三个通道的标准差,注意通道顺序是 R G B ,用过opencv的同学应该知道openCV读出来的图像是 BRG顺序。这两个tuple数据是用来对RGB 图像做归一化的,如其名称 Normalize 所示这里都取0.5只是一个近似的操作,实际上其均值和方差并不是这么多,但是就这个示例而言 影响可不计。精确值是通过分别计算R,G,B三个通道的数据算出来的。 |
| 46 | + |
| 47 | +```python |
| 48 | +transform = transforms.Compose([ |
| 49 | +# transforms.CenterCrop(224), |
| 50 | + transforms.RandomCrop(32,padding=4), # 数据增广 |
| 51 | + transforms.RandomHorizontalFlip(), # 数据增广 |
| 52 | + transforms.ToTensor(), |
| 53 | + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) |
| 54 | +]) |
| 55 | +``` |
| 56 | + |
| 57 | + `trainloader`其实是一个比较重要的东西,我们后面就是通过`trainloader`把数据传入网络,当然这里的`trainloader`其实是个变量名,可以随便取,重点是他是由后面的`torch.utils.data.DataLoader()`定义的,这个东西来源于`torch.utils.data`模块 |
| 58 | + |
| 59 | +```python |
| 60 | +Batch_Size = 256 |
| 61 | +``` |
| 62 | + |
| 63 | +```python |
| 64 | +trainset = datasets.CIFAR10(root='./data', train=True,download=True, transform=transform) |
| 65 | +testset = datasets.CIFAR10(root='./data',train=False,download=True,transform=transform) |
| 66 | +trainloader = torch.utils.data.DataLoader(trainset, batch_size=Batch_Size,shuffle=True, num_workers=2) |
| 67 | +testloader = torch.utils.data.DataLoader(testset, batch_size=Batch_Size,shuffle=True, num_workers=2) |
| 68 | +classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck') |
| 69 | +``` |
| 70 | + |
| 71 | +> ```python |
| 72 | +> Files already downloaded and verified |
| 73 | +> Files already downloaded and verified |
| 74 | +> ``` |
| 75 | +
|
| 76 | +### 2. 查看数据(格式,大小,形状) |
| 77 | +
|
| 78 | +首先可以查看类别 |
| 79 | +
|
| 80 | +```python |
| 81 | +classes = trainset.classes |
| 82 | +classes |
| 83 | +``` |
| 84 | +
|
| 85 | +> ```python |
| 86 | +> ['airplane', |
| 87 | +> 'automobile', |
| 88 | +> 'bird', |
| 89 | +> 'cat', |
| 90 | +> 'deer', |
| 91 | +> 'dog', |
| 92 | +> 'frog', |
| 93 | +> 'horse', |
| 94 | +> 'ship', |
| 95 | +> 'truck'] |
| 96 | +> ``` |
| 97 | +
|
| 98 | +```python |
| 99 | +trainset.class_to_idx |
| 100 | +``` |
| 101 | +
|
| 102 | +> ```python |
| 103 | +> {'airplane': 0, |
| 104 | +> 'automobile': 1, |
| 105 | +> 'bird': 2, |
| 106 | +> 'cat': 3, |
| 107 | +> 'deer': 4, |
| 108 | +> 'dog': 5, |
| 109 | +> 'frog': 6, |
| 110 | +> 'horse': 7, |
| 111 | +> 'ship': 8, |
| 112 | +> 'truck': 9} |
| 113 | +> ``` |
| 114 | +
|
| 115 | +也可以查看一下训练集的数据 |
| 116 | +
|
| 117 | +```python |
| 118 | +trainset.data.shape #50000是图片数量,32x32是图片大小,3是通道数量RGB |
| 119 | +``` |
| 120 | +
|
| 121 | +> ```python |
| 122 | +> (50000, 32, 32, 3) |
| 123 | +> ``` |
| 124 | +
|
| 125 | +查看数据类型 |
| 126 | +
|
| 127 | +```python |
| 128 | +#查看数据类型 |
| 129 | +print(type(trainset.data)) |
| 130 | +print(type(trainset)) |
| 131 | +``` |
| 132 | +
|
| 133 | +> ```python |
| 134 | +> <class 'numpy.ndarray'> |
| 135 | +> <class 'torchvision.datasets.cifar.CIFAR10'> |
| 136 | +> ``` |
| 137 | +> |
| 138 | +> |
| 139 | +
|
| 140 | +**总结:** |
| 141 | +
|
| 142 | +`trainset.data.shape`是标准的numpy.ndarray类型,其中50000是图片数量,32x32是图片大小,3是通道数量RGB; |
| 143 | +`trainset`是标准的??类型,其中50000为图片数量,0表示取前面的数据,2表示3通道数RGB,32*32表示图片大小 |
| 144 | +
|
| 145 | +### 3. 查看图片 |
| 146 | +
|
| 147 | +接下来我们对图片进行可视化 |
| 148 | +
|
| 149 | +```python |
| 150 | +import numpy as np |
| 151 | +import matplotlib.pyplot as plt |
| 152 | +plt.imshow(trainset.data[0]) |
| 153 | +im,label = iter(trainloader).next() |
| 154 | +``` |
| 155 | +
|
| 156 | + |
| 157 | + |
| 158 | +#### np.ndarray转为torch.Tensor |
| 159 | + |
| 160 | +在深度学习中,原始图像需要转换为深度学习框架自定义的数据格式,在pytorch中,需要转为`torch.Tensor`。 |
| 161 | +pytorch提供了`torch.Tensor` 与`numpy.ndarray`转换为接口: |
| 162 | + |
| 163 | +| 方法名 | 作用 | |
| 164 | +| ----------------------- | ------------------------------- | |
| 165 | +| `torch.from_numpy(xxx)` | `numpy.ndarray`转为torch.Tensor | |
| 166 | +| `tensor1.numpy()` | 获取tensor1对象的numpy格式数据 | |
| 167 | + |
| 168 | +`torch.Tensor` 高维矩阵的表示: N x C x H x W |
| 169 | + |
| 170 | +`numpy.ndarray` 高维矩阵的表示:N x H x W x C |
| 171 | + |
| 172 | +因此在两者转换的时候需要使用`numpy.transpose( )` 方法 。 |
| 173 | + |
| 174 | +```python |
| 175 | +def imshow(img): |
| 176 | + img = img / 2 + 0.5 |
| 177 | + img = np.transpose(img.numpy(),(1,2,0)) |
| 178 | + plt.imshow(img) |
| 179 | +``` |
| 180 | + |
| 181 | +```python |
| 182 | +imshow(im[0]) |
| 183 | +``` |
| 184 | + |
| 185 | + |
| 186 | + |
| 187 | +我们也可以批量可视化图片,不过这里需要用到`make_grid` |
| 188 | + |
| 189 | +```python |
| 190 | +plt.figure(figsize=(8,12)) |
| 191 | +imshow(torchvision.utils.make_grid(im[:32])) |
| 192 | +``` |
| 193 | + |
| 194 | + |
0 commit comments