Skip to content

Commit c1c18b4

Browse files
authored
Update main_finetune.py
1 parent cac39d1 commit c1c18b4

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

SDT_V3/Classification/Model_Large/main_finetune.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from util.misc import NativeScalerWithGradNormCount as NativeScaler
3939
from util.kd_loss import DistillationLoss
4040

41-
import models
41+
import spikformer
4242
from engine_finetune import train_one_epoch, evaluate
4343
from timm.data import create_loader
4444

@@ -63,7 +63,7 @@ def get_args_parser():
6363
)
6464
parser.add_argument("--finetune", default="", help="finetune from checkpoint")
6565
parser.add_argument(
66-
"--data_path", default="/raid/ligq/imagenet1-k/", type=str, help="dataset path"
66+
"--data_path", default="", type=str, help="dataset path"
6767
)
6868

6969
# Model parameters
@@ -236,12 +236,12 @@ def get_args_parser():
236236

237237
parser.add_argument(
238238
"--output_dir",
239-
default="/raid/ligq/htx/spikemae/output_dir",
239+
default="",
240240
help="path where to save, empty for no saving",
241241
)
242242
parser.add_argument(
243243
"--log_dir",
244-
default="/raid/ligq/htx/spikemae/output_dir",
244+
default="",
245245
help="path where to tensorboard log",
246246
)
247247
parser.add_argument(
@@ -399,7 +399,7 @@ def main(args):
399399
)
400400

401401

402-
model = models.__dict__[args.model](kd=args.kd)
402+
model = spikformer.__dict__[args.model](kd=args.kd)
403403
model.T = args.time_steps
404404
model_ema = None
405405
if args.finetune:

0 commit comments

Comments
 (0)