PyTorch Python torch nn ResNet ResNet18 pre-trained
针对 CIFAR-10 分类问题,搭建神经网络:AlexNet、GoogLeNet、ResNet、ResNet18。最后选择预训练后的 ResNet18 进行该问题的训练、验证和测试。包含自定义数据集 Dataset 类、自定义训练、验证和测试函数、自定义结果表格函数等。
kaggle: private score = 0.76750, ranked about 58. (training 10 epochs with GPU)
kaggle: private score = 0.68100, ranked about 71. (just training once as using CPU)
device: cpu
项目目录
├── README.md
├── TempData # 取少量图片模拟 CIFAR10 数据集
│ └── competitions
│ └── cifar-10
├── checkpoints # 存放训练完成后的模型参数 model.save() 存放处
├── config.py # 默认的配置文件
├── data # 自定义的数据集 Dataset 类
│ ├── __init__.py
│ └── dataset.py
├── logfile # 记录模型表现 csv 文件的目录
├── logs # 存放 tensorboard 文件
├── main.py # 主程序: 包含 train, test 等主要函数
├── models # 搭建的各种神经网络 `AlexNet`、`GoogLeNet`、`ResNet`、`ResNet18`
│ ├── __init__.py
│ ├── alexnet.py
│ ├── basic.py
│ ├── googlenet.py
│ ├── resnet.py
│ └── resnet18.py
├── requirements.txt # 安装依赖
├── result_example.csv # 少量数据模拟数据集得到的测试结果
└── sampleSubmission.csv # 最后生成的可提交 kaggle 的最终测试结果在终端运行
git clone https://github.com/isKage/cifar10-classification.git在项目根目录下终端输入
pip install -r requirements.txt教程见 从 kaggle 下载数据集 (mac & win)。
在 config.py 中配置相关参数。例如数据集路径。相关配置均已配置好,但需要自己配置数据集的位置。
在 _parse() 方法中,需修改 cifar 数据集的路径。例如我的配置:cifar-10 文件夹放在用户目录下的 AllData/competitions/ 下。
if config.real_or_try == "real":
# 如果数据放在用户目录的 'AllData' 下则
config.root = os.path.join(config.user_root, 'AllData', 'competitions', 'cifar-10') # 【本地设置: 数据目录】
config.res_path = os.path.join(config.working_root, 'sampleSubmission.csv')
else:
# 样本数据尝试
config.root = os.path.join(config.working_root, 'TempData', 'competitions', 'cifar-10')
config.res_path = os.path.join(config.working_root, 'result_example.csv')注意,默认的数据集为模拟数据集,故如果想在完整数据集训练,在指定路径后还需传入参数
--real_or_try=real,或者直接在config.py中 修改默认
在第 3 步设置完成数据集下载的路径后,终端输入
python main.py unzip即可解压数据集。
使用 fire 库方便的在终端中进行训练、测试过程。可以在 config.py 中输入默认参数。例如:model
为选择模型,默认使用 "ResNet18" 模型,
会自动进行下载,下载的预训练模型参数保存在 checkpoints/ 文件夹里。
在终端运行
python main.py train可以使用 --<参数名>=参数值 在终端覆盖默认参数
python main.py train model=AlexNet # 指定 AlexNet 为模型
python main.py train rea_or_try=real # 使用完整 CIFAR10 数据集,而不是模拟数据集 终端运行
tensorboard --logdir=./logs # http://localhost:6006/打开浏览器观察训练过程可视化:
终端运行
python main.py test即可得到测试后的结果表格 result_example.csv 或 sampleSubmission.csv (取决与使用的是模拟数据集还是完整的数据集)。
注意,测试完成后终端输入一下指令,对结果表格按照 id 进行排序。
python main.py sort_csv最后可以将 sampleSubmission.csv 上传到 kaggle CIFAR-10 competition 。
- 关注我的知乎账号 Zhuhu 不错过我的笔记更新。
- 我会在个人博客 isKage`Blog 更新相关项目和学习资料。


