-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvisualize.py
More file actions
77 lines (62 loc) · 2.95 KB
/
visualize.py
File metadata and controls
77 lines (62 loc) · 2.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from simclr_model import ResNet18Encoder # 导入你的Encoder定义
def _extract_features(encoder, loader, device):
    """Run ``encoder`` over every batch of ``loader`` and gather the results.

    Args:
        encoder: Callable module mapping an image batch to a feature tensor
            whose first axis is the batch axis.
        loader: Iterable yielding ``(images, labels)`` tensor pairs.
        device: Device the image batches are moved to before encoding.

    Returns:
        A ``(features, labels)`` pair of NumPy arrays, each concatenated
        along the first axis over all batches. ``features`` is always 2-D
        (one row per sample).
    """
    feature_chunks = []
    label_chunks = []
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            # Flatten everything except the batch axis. A bare ``squeeze()``
            # would also drop the batch axis whenever a batch contains a
            # single sample, turning that batch into a 1-D array and
            # breaking the concatenation below.
            features = torch.flatten(encoder(images), start_dim=1)
            feature_chunks.append(features.cpu().numpy())
            label_chunks.append(labels.cpu().numpy())
    return np.concatenate(feature_chunks), np.concatenate(label_chunks)


def visualize(encoder_path="./checkpoints/encoder.pth",
              output_path="tsne_visualization.png"):
    """Plot a 2-D t-SNE embedding of encoder features for the CIFAR-10 test set.

    Loads the encoder weights from ``encoder_path``, encodes the CIFAR-10
    test split (downloaded into ``./data`` if missing), reduces the features
    to two dimensions with t-SNE, and saves a per-class scatter plot.

    Args:
        encoder_path: Path to the saved ``state_dict`` of ``ResNet18Encoder``.
        output_path: File the scatter plot is written to.

    Returns:
        None. The figure is saved to ``output_path``; progress messages are
        printed to stdout.
    """
    print("开始进行t-SNE特征可视化...")

    # ===== 1. Select the device and load the model =====
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    encoder = ResNet18Encoder().to(device)
    encoder.load_state_dict(torch.load(encoder_path, map_location=device))
    encoder.eval()  # switch to evaluation mode

    # ===== 2. Prepare the test dataset =====
    # NOTE(review): the mean looks like the usual CIFAR-10 channel mean while
    # the std values look like the ImageNet ones -- confirm this matches the
    # transform the encoder was trained with before changing it.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.229, 0.224, 0.225]),
    ])
    # train=False selects the test split.
    test_dataset = datasets.CIFAR10(
        root="./data", train=False, transform=transform, download=True
    )
    # shuffle=False keeps features and labels in dataset order.
    test_loader = DataLoader(
        test_dataset, batch_size=512, shuffle=False, num_workers=4
    )
    # CIFAR-10 class names, used for the legend; position == label index.
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    # ===== 3. Extract features and labels for the whole test set =====
    all_features, all_labels = _extract_features(encoder, test_loader, device)
    print(f"特征提取完成! 特征形状: {all_features.shape}, 标签形状: {all_labels.shape}")

    # ===== 4. Reduce the features to 2-D with t-SNE =====
    print("开始t-SNE降维,这可能需要几分钟...")
    # random_state is fixed so repeated runs give the same embedding.
    tsne = TSNE(n_components=2, random_state=42, max_iter=300, perplexity=30)
    features_2d = tsne.fit_transform(all_features)
    print("t-SNE降维完成!")

    # ===== 5. Draw the scatter plot =====
    plt.figure(figsize=(12, 10))
    # One scatter call per class so each class gets its own colour and
    # legend entry.
    for class_idx, class_name in enumerate(classes):
        # Boolean mask selects the rows belonging to this class and keeps
        # the selected coordinates 1-D.
        mask = all_labels == class_idx
        plt.scatter(
            features_2d[mask, 0], features_2d[mask, 1],
            label=class_name, s=10,  # s is the marker size
        )
    plt.legend()
    plt.title("SimCLR Encoder - t-SNE Visualization of CIFAR-10 Test Set")
    plt.xlabel("t-SNE feature 1")
    plt.ylabel("t-SNE feature 2")

    # Save the figure to disk.
    plt.savefig(output_path, dpi=300)
    print(f"✅ 可视化图像已保存到 {output_path}")
    # To also display the figure interactively, call plt.show() here.
if __name__ == "__main__":
    # Only run when executed as a script, not on import. The DataLoader in
    # visualize() uses worker processes (num_workers=4); on platforms that
    # start workers by re-importing this module, the guard keeps them from
    # re-running the visualization.
    visualize()