
Commit 3cb654a

committed: test
1 parent 5c592ba commit 3cb654a


41 files changed: +7942 −0 lines
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
### Results on ImageNet-1K

Trained weights of the 5.1M model: [here](https://drive.google.com/file/d/1LMkOTPehDNpQE79bvB7jFTf6UzDjpAHQ/view?usp=drive_link).

Trained weights of the 10M model: [here](https://drive.google.com/file/d/1pHrampLjyE1kLr-4DS1WgSdnCVPzL6Tq/view?usp=sharing).

Trained weights of the 19M model: [here](https://drive.google.com/file/d/1pSGCOzrZNgHDxQXAp-Uelx61snIbQC1H/view?usp=drive_link).

Other weights are coming soon.
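
A minimal sketch of loading one of these checkpoints for evaluation. The import path, the checkpoint file name, and the `"model"` key inside the checkpoint are assumptions based on common MAE/timm-style fine-tuning code; they are not confirmed by this commit.

```python
import torch

# Hypothetical example: build the 5.1M model and load the downloaded weights.
# The module name and checkpoint file name below are illustrative only.
from models import Efficient_Spiking_Transformer_s  # assumed module name

model = Efficient_Spiking_Transformer_s()
checkpoint = torch.load("Efficient_Spiking_Transformer_5M.pth", map_location="cpu")
state_dict = checkpoint.get("model", checkpoint)  # handle a wrapped or raw state dict
msg = model.load_state_dict(state_dict, strict=False)
print(msg)
model.eval()
```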
### Train

Train on ImageNet-1K (8 GPUs, single node):

```shell
torchrun --standalone --nproc_per_node=8 \
    main_finetune.py \
    --batch_size 256 \
    --blr 6e-4 \
    --warmup_epochs 5 \
    --epochs 200 \
    --model Efficient_Spiking_Transformer_s \
    --data_path /your/data/path \
    --output_dir outputs/T1 \
    --log_dir outputs/T1 \
    --model_mode ms \
    --dist_eval
```
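
For reference, the global batch size implied by this command is 256 × 8 = 2048. If `main_finetune.py` follows the usual MAE-style recipe that `util.lr_sched` and the `--blr` flag suggest (an assumption, not shown in this commit), the absolute learning rate would be derived from the base rate as sketched below.

```python
# Hedged sketch of MAE-style lr scaling; whether main_finetune.py does exactly this is an assumption.
batch_size_per_gpu = 256
num_gpus = 8
accum_iter = 1                 # assumed default when --accum_iter is not passed

eff_batch_size = batch_size_per_gpu * num_gpus * accum_iter   # 2048
blr = 6e-4                     # base learning rate from the command above
lr = blr * eff_batch_size / 256                                # 4.8e-3
print(eff_batch_size, lr)
```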
### Data Preparation

Prepare ImageNet-1K with the following folder structure; you can extract ImageNet with this [script](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4).

```shell
│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
```
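
This layout is what `torchvision.datasets.ImageFolder` expects. A minimal sketch of reading it is below; the actual training pipeline in `main_finetune.py` uses timm's augmentation recipe, so the transforms here are illustrative only.

```python
# Illustrative only: minimal ImageFolder loading of the directory layout above.
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

train_set = datasets.ImageFolder("/your/data/path/train", transform=transform)
val_set = datasets.ImageFolder("/your/data/path/val", transform=transform)
print(len(train_set.classes))  # 1000 classes for ImageNet-1K
```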
Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------

import math
import sys
from typing import Iterable, Optional

import torch

from timm.data import Mixup
from timm.utils import accuracy

import util.misc as misc
import util.lr_sched as lr_sched

def train_one_epoch(
    model,
    criterion,
    data_loader,
    optimizer,
    device,
    epoch,
    loss_scaler,
    max_norm,
    mixup_fn,
    log_writer,
    args,
):
    model.train()
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)
    print_freq = 100

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print("log_dir: {}".format(log_writer.log_dir))

    for data_iter_step, (samples, targets) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):
        # we use a per-iteration (instead of per-epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(
                optimizer, data_iter_step / len(data_loader) + epoch, args
            )

        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        # keep the un-mixed targets so top-1/top-5 accuracy is measured against hard labels
        targets_nomix = targets
        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        with torch.cuda.amp.autocast():
            outputs = model(samples)
            if args.kd:
                # the knowledge-distillation criterion also consumes the inputs;
                # the model returns a tuple whose first element is the logits used for accuracy
                loss = criterion(samples, outputs, targets)
                outputs_acc, _ = outputs
            else:
                loss = criterion(outputs, targets)
                outputs_acc = outputs
        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # gradient accumulation: scale the loss and only step the optimizer every accum_iter iterations
        loss = loss / accum_iter
        loss_scaler(
            loss,
            optimizer,
            clip_grad=max_norm,
            parameters=model.parameters(),
            create_graph=False,
            update_grad=(data_iter_step + 1) % accum_iter == 0,
        )
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()
        torch.cuda.synchronize()
        batch_size = samples.shape[0]
        acc1, acc5 = accuracy(outputs_acc, targets_nomix, topk=(1, 5))

        metric_logger.update(loss=loss_value)
        min_lr = 10.0
        max_lr = 0.0
        for group in optimizer.param_groups:
            min_lr = min(min_lr, group["lr"])
            max_lr = max(max_lr, group["lr"])

        metric_logger.update(lr=max_lr)
        loss_value_reduce = misc.all_reduce_mean(loss_value)
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            """We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar("loss", loss_value_reduce, epoch_1000x)
            log_writer.add_scalar("lr", max_lr, epoch_1000x)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    print(
        "* Train_Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}".format(
            top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss
        )
    )
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def cal_acc(metric_logger, output, target):
    acc1, acc5 = accuracy(output, target, topk=(1, 5))
    batch_size = target.shape[0]  # number of samples in this batch
    metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
    metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
    return metric_logger.acc1, metric_logger.acc5


@torch.no_grad()
def evaluate(data_loader, model, device):
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = misc.MetricLogger(delimiter=" ")
    header = "Test:"

    # switch to evaluation mode
    model.eval()

    for batch in metric_logger.log_every(data_loader, 500, header):
        images = batch[0]
        target = batch[-1]
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast():
            output = model(images)
            loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print(
        "* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}".format(
            top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss
        )
    )

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
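
To make the input/output contract of `evaluate()` concrete, here is a small smoke-test sketch. It assumes the module above is importable together with `util.misc` (expected to behave like its MAE counterpart and skip synchronization when not running distributed), and it uses a throwaway linear model and random data, so it is illustrative only.

```python
# Illustrative smoke test for evaluate(); the tiny model and random data are placeholders.
import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda")  # evaluate() uses torch.cuda.amp.autocast, so a GPU is assumed

# evaluate() reads batch[0] as images and batch[-1] as targets.
images = torch.randn(16, 3, 224, 224)
targets = torch.randint(0, 10, (16,))
loader = DataLoader(TensorDataset(images, targets), batch_size=8)

# Throwaway classifier with 10 classes (>= 5, so top-5 accuracy is defined).
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 224 * 224, 10)).to(device)

stats = evaluate(loader, model, device)  # returns {"loss": ..., "acc1": ..., "acc5": ...}
print(stats)
```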
