forked from AVC2-UESTC/DAR-TR-PEFT
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_debug.py
More file actions
148 lines (93 loc) · 4.54 KB
/
test_debug.py
File metadata and controls
148 lines (93 loc) · 4.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os
import warnings
import time
import datetime
from typing import Union, List, Dict
# from collections import OrderedDict
import torch
from torch.utils import data
# from train_utils import get_parameter_groups
from train_utils.train_and_eval import evaluate
# from src.Models.builder import build_metric
from mmengine.config import Config as MMConfig
from src.builder import build_dataset, build_model, build_scheduler
def main(scheduler_cfg, dataset_cfg, model_cfg, runtime: Dict):
    """Run one evaluation pass of the configured model on the validation set.

    Builds the validation ``DataLoader``, moves the model to the configured
    device, runs ``evaluate`` once (as epoch 0), and appends the resulting
    metrics to a timestamped ``test_results*.txt`` file.

    Args:
        scheduler_cfg: Scheduler config instance; this function reads its
            ``seed``, ``device``, ``num_workers`` and ``get_metric_dict``.
        dataset_cfg: Dataset config instance; provides ``dataset_val``.
        model_cfg: Model config instance; ``model`` yields a
            ``(model, save_weights_keys)`` pair, and ``pretrained_weights``
            signals whether weights were loaded.
        runtime: Extra runtime options dict. Currently unused here; kept so
            the signature matches the training entry points.

    Raises:
        ValueError: If an unsupported logger name is configured.
    """
    logger_name = 'default'

    if scheduler_cfg.seed:
        torch.manual_seed(scheduler_cfg.seed)

    device = torch.device(scheduler_cfg.device if torch.cuda.is_available() else "cpu")
    print(f'Using {device} for testing')

    # Timestamped so repeated runs never clobber earlier results.
    results_file = "test_results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

    # --------------------------- dataset ------------------------------
    val_dataset = dataset_cfg.dataset_val
    num_workers = scheduler_cfg.num_workers
    # A dataset may supply its own collate_fn; fall back to the default.
    co_fn_val = getattr(val_dataset, 'collate_fn', None)

    val_data_loader = data.DataLoader(val_dataset,
                                      batch_size=1,  # must be 1 for evaluation
                                      num_workers=num_workers,
                                      pin_memory=False,
                                      collate_fn=co_fn_val
                                      )

    # ===================== model define =================================
    # save_weights_keys is part of the config contract but unused when
    # only evaluating.
    model, save_weights_keys = model_cfg.model
    if model_cfg.pretrained_weights is not None:
        print('Weights loaded')
    else:
        warnings.warn('No pretrained weights are loaded.')
    model.to(device)

    metric_dict = scheduler_cfg.get_metric_dict

    if logger_name == 'default':
        print('Using default logger.')
        logger = None
    else:
        # BUG FIX: the original referenced the then-undefined name ``logger``
        # here, so this branch raised NameError instead of the intended
        # ValueError. Report the offending configuration value instead.
        raise ValueError(f"Unsupported logger: {logger_name}")

    # ======================= Begin evaluation ===========================
    start_time = time.time()
    metric_info_dict = evaluate(model, val_data_loader,
                                device=device,
                                epoch=0,
                                metrics=metric_dict,
                                logger_name=logger_name,
                                logger=logger)

    # Append metrics so successive runs accumulate in the same file.
    with open(results_file, "a") as f:
        write_info = f"[epoch: {0}] Val_Metrics: {metric_info_dict} \n"
        f.write(write_info)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Test time usage: {}".format(total_time_str))
if __name__ == "__main__":
config = './configs/dinov2/config_dinov2_b_dar_fgseg.py'
# config = './configs/dinov2/config_dinov2_b_dar_distill_fgseg_test.py'
assert os.path.exists(config), f"No such file: {config}"
config = MMConfig.fromfile(config)
# args = cfg.cfg_segformer_sod() # load config file
Scheduler_cfg = config.get("Scheduler_cfg")
Dataset_cfg = config.get("Dataset_cfg")
Model_cfg = config.get("Model_cfg")
runtime = config.get("runtime")
# dict
scheduler_cfg_inst = build_scheduler(scheduler_cfg_name=Scheduler_cfg['scheduler_cfg_name'],
scheduler_cfg_args=Scheduler_cfg['scheduler_cfg_args'])
dataset_cfg_inst = build_dataset(dataset_cfg_name=Dataset_cfg['dataset_cfg_name'],
dataset_cfg_args=Dataset_cfg['dataset_cfg_args'])
model_cfg_inst = build_model(model_cfg_name=Model_cfg['model_cfg_name'],
model_cfg_args=Model_cfg['model_cfg_args'])
main(scheduler_cfg_inst, dataset_cfg_inst, model_cfg_inst, runtime)