# Select GPU 0 when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Weights for the regularization terms (diagonal / off-diagonal / function value).
weight_diag = 10
weight_offdiag = 0
weight_f = 0.1

# Exponents scaling each regularization term; presumably used by the
# f_regularizer defined later in this file — TODO confirm.
exponent = 1.0
exponent_off = 0.1
exponent_f = 20

time_df = 1
trans = 1.0
transoffdig = 1.0
numm = 16

# Batch size for the ODE fully-connected training stage.
ODE_FC_odebatch = 32
4347class Identity (nn .Module ):
4448 def __init__ (self ):
@@ -56,7 +60,7 @@ def forward(self, t, x):
5660class ODEfunc_mlp (nn .Module ):
5761 def __init__ (self , dim ):
5862 super (ODEfunc_mlp , self ).__init__ ()
59- self .fc1 = ConcatFC (64 , 64 )
63+ self .fc1 = ConcatFC (128 , 128 )
6064 self .act1 = torch .sin
6165 self .nfe = 0
6266 def forward (self , t , x ):
@@ -65,21 +69,8 @@ def forward(self, t, x):
6569 out = self .act1 (out )
6670 return out
6771
68- class ODEBlock (nn .Module ):
69- def __init__ (self , odefunc ):
70- super (ODEBlock , self ).__init__ ()
71- self .odefunc = odefunc
72- self .integration_time = torch .tensor ([0 , 5 ]).float ()
73- def forward (self , x ):
74- self .integration_time = self .integration_time .type_as (x )
75- out = odeint (self .odefunc , x , self .integration_time , rtol = 1e-3 , atol = 1e-3 )
76- return out [1 ]
77- @property
78- def nfe (self ):
79- return self .odefunc .nfe
80- @nfe .setter
81- def nfe (self , value ):
82- self .odefunc .nfe = value
72+
73+
8374
8475class ODEBlocktemp (nn .Module ):
8576 def __init__ (self , odefunc ):
@@ -100,7 +91,7 @@ class MLP_OUT_ORTH1024(nn.Module):
class MLP_OUT_ORTH1024(nn.Module):
    """Projection head mapping `layer_dim_`-dimensional input to 128 features
    through a single bias-free ORTHFC layer."""

    def __init__(self, layer_dim_):
        super(MLP_OUT_ORTH1024, self).__init__()
        self.layer_dim_ = layer_dim_
        # Third argument False presumably disables the bias — TODO confirm
        # against the ORTHFC definition.
        self.fc0 = ORTHFC(self.layer_dim_, 128, False)

    def forward(self, input_):
        return self.fc0(input_)
@@ -149,7 +140,7 @@ class MLP_OUT_LINEAR(nn.Module):
class MLP_OUT_LINEAR(nn.Module):
    """Linear classification head: 128-d features -> `class_numbers` logits."""

    def __init__(self, class_numbers):
        # Call Module.__init__ first so attribute assignment goes through a
        # fully initialized nn.Module (the original assigned before super()).
        super(MLP_OUT_LINEAR, self).__init__()
        self.class_numbers = class_numbers
        self.fc0 = nn.Linear(128, class_numbers)

    def forward(self, input_):
        """Return logits of shape (batch, class_numbers)."""
        return self.fc0(input_)
@@ -158,8 +149,8 @@ class MLP_OUT_BALL(nn.Module):
class MLP_OUT_BALL(nn.Module):
    """Bias-free linear head whose weight matrix is re-drawn from a standard
    normal distribution at construction time."""

    def __init__(self, class_numbers):
        super(MLP_OUT_BALL, self).__init__()
        self.class_numbers = class_numbers
        self.fc0 = nn.Linear(128, class_numbers, bias=False)
        # Overwrite the default initialization with N(0, 1) entries.
        self.fc0.weight.data = torch.randn([class_numbers, 128])

    def forward(self, input_):
        return self.fc0(input_)
@@ -362,26 +353,24 @@ def train(net, epoch,trainloader,optimizer):
def one_hot(x, K):
    """Return a (len(x), K) 0/1 int array with a 1 at each label position."""
    return (x[:, None] == np.arange(K)[None, :]).astype(int)
364355
365-
366- class ODEBlock (nn .Module ):
367356
357+
class ODEBlock(nn.Module):
    """Wrap an ODE function: forward integrates it over t in [0, 5] and
    returns the state at the final time."""

    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 5]).float()

    def forward(self, x):
        # Keep the time grid on the same dtype/device as the input batch.
        self.integration_time = self.integration_time.type_as(x)
        states = odeint(self.odefunc, x, self.integration_time,
                        rtol=1e-3, atol=1e-3)
        # odeint returns states at every requested time; keep the endpoint.
        return states[1]

    @property
    def nfe(self):
        # Function-evaluation counter, delegated to the wrapped odefunc.
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value
373+
385374
386375
387376
@@ -501,24 +490,102 @@ def f_regularizer(odefunc, z):
501490
502491
503492
504- def save_training_feature (model , dataset_loader , fake_embeddings_loader = None ):
493+
494+
def save_testing_feature(model, dataset_loader):
    """Extract features for every batch in `dataset_loader` and cache them.

    Runs each batch through the first two modules of `model` and saves the
    stacked features/labels to `test_savepath` as an npz archive with keys
    `x_save` and `y_save`.
    """
    x_save = []
    y_save = []
    modulelist = list(model)
    for x, y in dataset_loader:
        x = x.to(device)
        y_ = y.numpy()  # already an ndarray; no np.array() copy needed
        # Forward pass through the first two layers only.
        for l in modulelist[0:2]:
            x = l(x)
        x_save.append(x.cpu().detach().numpy())
        y_save.append(y_)
    x_save = np.concatenate(x_save)
    y_save = np.concatenate(y_save)
    np.savez(test_savepath, x_save=x_save, y_save=y_save)
515+
516+
517+
518+
class DensemnistDatasetTrain(Dataset):
    """Dataset over features cached at `train_savepath` (npz: x_save/y_save)."""

    def __init__(self):
        cached = np.load(train_savepath)
        self.x = cached['x_save']  # feature array, indexed along axis 0
        self.y = cached['y_save']  # matching labels

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx, ...], self.y[idx]
class DensemnistDatasetTest(Dataset):
    """Dataset over features cached at `test_savepath` (npz: x_save/y_save)."""

    def __init__(self):
        cached = np.load(test_savepath)
        self.x = cached['x_save']  # feature array, indexed along axis 0
        self.y = cached['y_save']  # matching labels

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx, ...], self.y[idx]
551+
552+
553+
554+
class ODEBlocktemp(nn.Module):
    """ODEBlock stand-in that evaluates the ODE function once at t=0 instead
    of integrating it (keeps the same interface as ODEBlock)."""

    def __init__(self, odefunc):
        super(ODEBlocktemp, self).__init__()
        self.odefunc = odefunc
        # Kept for interface parity with ODEBlock; not used by forward.
        self.integration_time = torch.tensor([0, 5]).float()

    def forward(self, x):
        return self.odefunc(0, x)

    @property
    def nfe(self):
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value
569+
570+
def accuracy(model, dataset_loader):
    """Return the fraction of samples in `dataset_loader` that `model`
    classifies correctly (argmax over the model's output logits)."""
    total_correct = 0
    for x, y in dataset_loader:
        x = x.to(device)
        # The original one-hot-encoded y and took argmax, which just
        # reproduces the integer labels; use them directly.
        target_class = y.numpy()
        predicted_class = np.argmax(model(x).cpu().detach().numpy(), axis=1)
        total_correct += np.sum(predicted_class == target_class)
    return total_correct / len(dataset_loader.dataset)
582+
583+
584+
585+ def save_training_feature (model , dataset_loader , fake_embeddings_loader = None ):
586+ x_save = []
587+ y_save = []
588+ modulelist = list (model )
522589 # Processing fake embeddings if provided
523590 if fake_embeddings_loader is not None :
524591 for x , y in fake_embeddings_loader :
@@ -534,12 +601,30 @@ def save_training_feature(model, dataset_loader, fake_embeddings_loader=None ):
534601 x_save .append (x_ )
535602 y_save .append (y_ )
536603
537- # Concatenate all collected data before saving
604+
605+ x_save = []
606+ y_save = []
607+ for x , y in dataset_loader :
608+ x = x .to (device )
609+ y_ = y .numpy () # No need to use np.array here
610+
611+ # Forward pass through the model up to the desired layer
612+ for l in modulelist [0 :2 ]:
613+ x = l (x )
614+ xo = x
615+
616+ x_ = xo .cpu ().detach ().numpy ()
617+
618+ x_save .append (x_ )
619+
620+ y_save .append (y_ )
621+
622+
538623 x_save = np .concatenate (x_save )
624+
539625 y_save = np .concatenate (y_save )
540626
541-
542- # Save the concatenated arrays to a file
627+
543628 np .savez (train_savepath , x_save = x_save , y_save = y_save )
544629
545630
@@ -568,35 +653,3 @@ def save_testing_feature(model, dataset_loader):
568653
569654
570655
571- class DensemnistDatasetTrain (Dataset ):
572- def __init__ (self ):
573- """
574- """
575- npzfile = np .load (train_savepath )
576-
577- self .x = npzfile ['x_save' ]
578- self .y = npzfile ['y_save' ]
579- def __len__ (self ):
580- return len (self .x )
581-
582- def __getitem__ (self , idx ):
583- x = self .x [idx ,...]
584- y = self .y [idx ]
585-
586- return x ,y
587- class DensemnistDatasetTest (Dataset ):
588- def __init__ (self ):
589- """
590- """
591- npzfile = np .load (test_savepath )
592-
593- self .x = npzfile ['x_save' ]
594- self .y = npzfile ['y_save' ]
595- def __len__ (self ):
596- return len (self .x )
597-
598- def __getitem__ (self , idx ):
599- x = self .x [idx ,...]
600- y = self .y [idx ]
601-
602- return x ,y
0 commit comments