-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathValidBagging.py
More file actions
40 lines (32 loc) · 1.44 KB
/
ValidBagging.py
File metadata and controls
40 lines (32 loc) · 1.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import numpy as np
import torch
import torch.nn as nn
from bagging.MergeResults import BaggingResult
from Utils import DataUtils
# Validation data and model configuration.
csv_path = "q1_data/train2.csv"
dataset_path = "q1_data/train.npy"
val_refer_list = "bagging/val.npy"
BATCH_SIZE = 20
CUDA_DEVICE = 2
CLASS_NUM = 100
UP_SIZE = (224, 224)
bag_pkl_paths = [
    "./pklmodels/Class100_A_epoch_40.pkl",
    "./pklmodels/Class100_B_epoch_40.pkl",
    "./pklmodels/Class100_C_epoch_40.pkl",
]


def count_correct(labels, preds):
    """Return how many entries of *preds* equal *labels* elementwise.

    Args:
        labels: batch of ground-truth labels (a tensor; compared with ``==``).
        preds: predictions comparable to *labels* with ``==``.

    Returns:
        The number of matching positions as a Python ``int``.
    """
    return int((labels == preds).numpy().sum())


def accuracy(correct, total):
    """Return ``correct / total``, or NaN when no samples were seen."""
    return correct / total if total else float("nan")


def main():
    """Evaluate the bagged ensemble and each individual bag on the validation split.

    Prints a progress line per batch, then the merged (ensemble) accuracy and
    the accuracy of every bag. Accuracies are computed over all samples, not as
    a mean of per-batch means, so a short final batch is weighted correctly.
    """
    val_dataset = DataUtils.DatasetLoader(
        csv_path,
        dataset_path,
        refer_list=np.load(val_refer_list),
        mode="Valid",
        up_size=UP_SIZE,
    )
    # Validation order does not affect accuracy; keep it deterministic.
    valid_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=BATCH_SIZE, num_workers=2, shuffle=False
    )
    results = BaggingResult(CUDA_DEVICE, bag_pkl_paths=bag_pkl_paths, class_num=CLASS_NUM)

    num_bags = len(bag_pkl_paths)
    merge_correct = 0
    split_correct = [0] * num_bags
    seen = 0

    # Inference only: avoid building autograd graphs.
    with torch.no_grad():
        for _, val_x, val_label in valid_loader:
            merge_res, split_res = results.pred(val_x)
            batch_len = len(val_label)
            merge_correct += count_correct(val_label, merge_res)
            print(seen, " - ", seen + batch_len)
            seen += batch_len
            # split_res holds one prediction row per bag (indexed like bag_pkl_paths).
            for j, res in enumerate(split_res):
                split_correct[j] += count_correct(val_label, res)

    print("Merge Accuracy: {:.4f}".format(accuracy(merge_correct, seen)))
    # Use the configured bag count rather than the last batch's split_res,
    # which is undefined when the loader yields no batches.
    for j in range(num_bags):
        print("Bag {:d} Accuracy: {:.4f}".format(j, accuracy(split_correct[j], seen)))


# Guard required: DataLoader workers (num_workers=2) re-import this module on
# spawn-based platforms, which must not re-run the evaluation.
if __name__ == "__main__":
    main()