3838from util .misc import NativeScalerWithGradNormCount as NativeScaler
3939from util .kd_loss import DistillationLoss
4040
41- import models
41+ import spikformer
4242from engine_finetune import train_one_epoch , evaluate
4343from timm .data import create_loader
4444
@@ -63,7 +63,7 @@ def get_args_parser():
6363 )
6464 parser .add_argument ("--finetune" , default = "" , help = "finetune from checkpoint" )
6565 parser .add_argument (
66- "--data_path" , default = "/raid/ligq/imagenet1-k/ " , type = str , help = "dataset path"
66+ "--data_path" , default = "" , type = str , help = "dataset path"
6767 )
6868
6969 # Model parameters
@@ -236,12 +236,12 @@ def get_args_parser():
236236
237237 parser .add_argument (
238238 "--output_dir" ,
239- default = "/raid/ligq/htx/spikemae/output_dir " ,
239+ default = "" ,
240240 help = "path where to save, empty for no saving" ,
241241 )
242242 parser .add_argument (
243243 "--log_dir" ,
244- default = "/raid/ligq/htx/spikemae/output_dir " ,
244+ default = "" ,
245245 help = "path where to tensorboard log" ,
246246 )
247247 parser .add_argument (
@@ -399,7 +399,7 @@ def main(args):
399399 )
400400
401401
402- model = models .__dict__ [args .model ](kd = args .kd )
402+ model = spikformer .__dict__ [args .model ](kd = args .kd )
403403 model .T = args .time_steps
404404 model_ema = None
405405 if args .finetune :
0 commit comments