本项目基于PyTorch实现了一个简洁的SimCLR(A Simple Framework for Contrastive Learning of Visual Representations)框架,并在CIFAR-10数据集上进行了预训练和线性评估。
本项目旨在通过代码实践,深入理解自监督学习的核心思想。SimCLR是一种先进的自监督学习方法,它通过最大化同一图片不同增强视图之间的一致性,来学习图像的通用、高质量的特征表示,整个过程无需任何人工标签。
本项目完整流程包括:
- 阶段一:自监督预训练:使用无标签的CIFAR-10训练集,通过对比学习训练一个ResNet-18编码器。
- 阶段二:线性探测评估:冻结预训练好的编码器,在其后接一个线性分类器,并使用带标签的CIFAR-10训练集来训练这个分类器,最终在测试集上评估编码器所学特征的质量。
.
├── simclr_model.py # 定义SimCLR模型架构 (Encoder + Projector)
├── losses.py # 实现NT-Xent对比损失函数
├── train.py # 执行SimCLR预训练的主脚本
├── linear_eval.py # 执行线性探测评估的主脚本
├── checkpoints/ # 存放训练好的模型权重(运行train.py后自动生成)
└── data/ # 存放下载的数据集(运行train.py后自动生成)
# 使用 conda
conda create -n simclr python=3.10
conda activate simclr首先,请确保你已经安装了PyTorch。然后,通过以下命令安装其他依赖:
pip install -r requirements.txt提示: 你可以通过在你的终端运行
pip freeze > requirements.txt命令来自动生成你当前环境的依赖文件。
该脚本将使用CIFAR-10的训练集进行SimCLR预训练,并将训练好的Encoder权重保存在./checkpoints/encoder.pth。
# 使用默认参数进行训练 (100个epoch, batch_size 256, lr 3e-4)
python train.py
# 或者,自定义参数进行训练
python train.py --epochs 200 --batch_size 512 --learning_rate 0.001该脚本将加载预训练好的Encoder,训练一个线性分类头,并在CIFAR-10测试集上评估分类准确率。
python linear_eval.py在我自己的机器上,经过 100 个epoch的预训练后,在线性评估阶段达到的最终测试集准确率为:
- Top-1 Accuracy: Acc: 67.69%
(在这里,你还可以添加更多实验,比如不同超参数下的结果对比表格)
为了直观地展示Encoder学习到的特征表示的质量,我使用了t-SNE对CIFAR-10测试集的特征向量进行降维可视化。
A. 表现优秀的方面 (优点)
交通工具类别区分度高:
左下角的橙色 (car) 点云和底部的青色 (truck) 点云,各自形成了非常密集且独立的“大陆”,并且与其他大部分类别都分得比较开。这说明Encoder对于识别“汽车”和“卡车”这类人造物体的特征非常在行。
一个有趣的观察是,car和truck这两个簇虽然分离,但相对位置比较近,这也符合我们的直觉(它们都属于“四轮车辆”)。
物品类别聚类良好: 右下角的草绿色 (ship) 点云也形成了一个相对紧凑的聚类。 左上角的蓝色 (plane) 点云同样有很好的聚集趋势。
但它们各自的分布比较松散,没有形成像car那样紧凑的簇。
可能的原因: 这两个类别的背景变化非常大。飞机可能在天空中、在跑道上;轮船可能在广阔的海面、也可能停靠在复杂的港口。多变的背景给特征提取带来了干扰。
B. 存在挑战的方面 (弱点与可改进之处)
核心区域的动物类别混淆:
观察图中央和右侧的大片区域,发现红色 (cat)、紫色 (deer)、灰色 (horse) 以及棕色 (dog) 的点云严重地混合、重叠在了一起。
可能的原因: 对于模型来说,这些动物(猫、鹿、马)都属于“四条腿的哺乳动物”,它们常常出现在相似的背景下(草地、树林),姿态也千变万化。区分它们的难度本身就比区分“汽车”和“飞机”要大得多。模型在这个“动物混合区”犯的错误,也许是导致线性评估准确率停留在~68%的主要原因。
