Commit 00fcd1f

Merge pull request #4 from YmShan/main
Directly runnable MAE pre-training.
2 parents 61e2c5f + 952d692 commit 00fcd1f

File tree: 2 files changed (+3, -3 lines)


SDT_V3/Classification/Model_Large/MAE_SDT.py

Lines changed: 1 addition & 1 deletion
@@ -280,7 +280,7 @@ def __init__(
         self.layer_scale2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
 
     def forward(self, x):
-        # T, B, C, N = x.shape
+        T, B, C, N = x.shape
         if self.model=="base":
             x= x + self.rep_conv(self.lif(x).flatten(0, 1)).reshape(T, B, C, N)
             # TODO: need channel-wise layer scale, init as 1e-6
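The uncommented line matters because the unchanged code just below reshapes the convolution output back to (T, B, C, N), so those four names must be bound before the block runs. A minimal sketch of the pattern, with assumed shapes and stand-in modules (the Conv1d/Identity below are placeholders, not the repository's rep_conv/lif):

# Minimal sketch; illustrative shapes and stand-in modules, not the repo's exact classes
import torch
import torch.nn as nn

T, B, C, N = 4, 2, 8, 16                 # time steps, batch, channels, tokens (assumed)
x = torch.randn(T, B, C, N)

rep_conv = nn.Conv1d(C, C, kernel_size=3, padding=1)  # stand-in for self.rep_conv
lif = nn.Identity()                                    # stand-in for the spiking neuron

# flatten (T, B) into one batch dim so the conv sees (batch, channels, length),
# then restore the (T, B, C, N) layout; this is why T, B, C, N must be unpacked first
out = x + rep_conv(lif(x).flatten(0, 1)).reshape(T, B, C, N)
print(out.shape)  # torch.Size([4, 2, 8, 16])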

SDT_V3/Classification/Model_Large/engine_pretrain.py

Lines changed: 2 additions & 2 deletions
@@ -55,7 +55,7 @@ def train_one_epoch(model: torch.nn.Module,
         optimizer.zero_grad()
 
         torch.cuda.synchronize()
-        functional.reset_net(model)
+        # functional.reset_net(model)
         metric_logger.update(loss=loss_value)
 
         lr = optimizer.param_groups[0]["lr"]
@@ -74,4 +74,4 @@ def train_one_epoch(model: torch.nn.Module,
     # gather the stats from all processes
     metric_logger.synchronize_between_processes()
     print("Averaged stats:", metric_logger)
-    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
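functional.reset_net is SpikingJelly's utility for clearing the stateful membrane potentials of spiking neurons between iterations; this commit leaves that call commented out in the pre-training step. A hedged sketch of where such a reset normally sits in a training step (the function below and its arguments are placeholders, not the repository's code, and the import path can differ across SpikingJelly versions):

import torch
from spikingjelly.activation_based import functional  # import path varies by SpikingJelly version

def train_one_step(model: torch.nn.Module,
                   samples: torch.Tensor,
                   optimizer: torch.optim.Optimizer) -> float:
    # placeholder step: assume the model returns a scalar reconstruction loss, as MAE-style models do
    loss = model(samples)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # functional.reset_net(model)  # would zero the neurons' internal state between batches;
    #                              # this commit disables the call during MAE pre-training
    return loss.item()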
