Skip to content

Commit f857ccc

Browse files
authored
Update utils.py
1 parent 3e8ac20 commit f857ccc

File tree

1 file changed

+122
-69
lines changed

1 file changed

+122
-69
lines changed

AROS/utils.py

Lines changed: 122 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
2222

2323

24+
25+
2426
weight_diag = 10
2527
weight_offdiag = 0
2628
weight_f = 0.1
@@ -30,15 +32,17 @@
3032

3133
exponent = 1.0
3234
exponent_off = 0.1
33-
exponent_f = 50
35+
exponent_f = 20
3436
time_df = 1
3537
trans = 1.0
3638
transoffdig = 1.0
3739
numm = 16
3840

39-
batches_per_epoch = 128
4041

41-
ODE_FC_odebatch = 64
42+
43+
44+
45+
ODE_FC_odebatch = 32
4246

4347
class Identity(nn.Module):
4448
def __init__(self):
@@ -56,7 +60,7 @@ def forward(self, t, x):
5660
class ODEfunc_mlp(nn.Module):
5761
def __init__(self, dim):
5862
super(ODEfunc_mlp, self).__init__()
59-
self.fc1 = ConcatFC(64, 64)
63+
self.fc1 = ConcatFC(128, 128)
6064
self.act1 = torch.sin
6165
self.nfe = 0
6266
def forward(self, t, x):
@@ -65,21 +69,8 @@ def forward(self, t, x):
6569
out = self.act1(out)
6670
return out
6771

68-
class ODEBlock(nn.Module):
69-
def __init__(self, odefunc):
70-
super(ODEBlock, self).__init__()
71-
self.odefunc = odefunc
72-
self.integration_time = torch.tensor([0, 5]).float()
73-
def forward(self, x):
74-
self.integration_time = self.integration_time.type_as(x)
75-
out = odeint(self.odefunc, x, self.integration_time, rtol=1e-3, atol=1e-3)
76-
return out[1]
77-
@property
78-
def nfe(self):
79-
return self.odefunc.nfe
80-
@nfe.setter
81-
def nfe(self, value):
82-
self.odefunc.nfe = value
72+
73+
8374

8475
class ODEBlocktemp(nn.Module):
8576
def __init__(self, odefunc):
@@ -100,7 +91,7 @@ class MLP_OUT_ORTH1024(nn.Module):
10091
def __init__(self,layer_dim_):
10192
super(MLP_OUT_ORTH1024, self).__init__()
10293
self.layer_dim_ = layer_dim_
103-
self.fc0 = ORTHFC(self.layer_dim_, 64, False)
94+
self.fc0 = ORTHFC(self.layer_dim_, 128, False)
10495
def forward(self, input_):
10596
h1 = self.fc0(input_)
10697
return h1
@@ -149,7 +140,7 @@ class MLP_OUT_LINEAR(nn.Module):
149140
def __init__(self,class_numbers):
150141
self.class_numbers = class_numbers
151142
super(MLP_OUT_LINEAR, self).__init__()
152-
self.fc0 = nn.Linear(64, class_numbers)
143+
self.fc0 = nn.Linear(128, class_numbers)
153144
def forward(self, input_):
154145
h1 = self.fc0(input_)
155146
return h1
@@ -158,8 +149,8 @@ class MLP_OUT_BALL(nn.Module):
158149
def __init__(self,class_numbers):
159150
super(MLP_OUT_BALL, self).__init__()
160151
self.class_numbers = class_numbers
161-
self.fc0 = nn.Linear(64, class_numbers, bias=False)
162-
self.fc0.weight.data = torch.randn([class_numbers,64])
152+
self.fc0 = nn.Linear(128, class_numbers, bias=False)
153+
self.fc0.weight.data = torch.randn([class_numbers,128])
163154
def forward(self, input_):
164155
h1 = self.fc0(input_)
165156
return h1
@@ -362,26 +353,24 @@ def train(net, epoch,trainloader,optimizer):
362353
def one_hot(x, K):
363354
return np.array(x[:, None] == np.arange(K)[None, :], dtype=int)
364355

365-
366-
class ODEBlock(nn.Module):
367356

357+
358+
class ODEBlock(nn.Module):
368359
def __init__(self, odefunc):
369360
super(ODEBlock, self).__init__()
370361
self.odefunc = odefunc
371362
self.integration_time = torch.tensor([0, 5]).float()
372-
373363
def forward(self, x):
374364
self.integration_time = self.integration_time.type_as(x)
375365
out = odeint(self.odefunc, x, self.integration_time, rtol=1e-3, atol=1e-3)
376366
return out[1]
377-
378367
@property
379368
def nfe(self):
380369
return self.odefunc.nfe
381-
382370
@nfe.setter
383371
def nfe(self, value):
384372
self.odefunc.nfe = value
373+
385374

386375

387376

@@ -501,24 +490,102 @@ def f_regularizer(odefunc, z):
501490

502491

503492

504-
def save_training_feature(model, dataset_loader, fake_embeddings_loader=None ):
493+
494+
495+
def save_testing_feature(model, dataset_loader):
505496
x_save = []
506497
y_save = []
507498
modulelist = list(model)
508-
499+
layernum = 0
509500
for x, y in dataset_loader:
510501
x = x.to(device)
511-
y_ = y.numpy() # No need to use np.array here
502+
y_ = np.array(y.numpy())
512503

513-
# Forward pass through the model up to the desired layer
514504
for l in modulelist[0:2]:
515-
x = l(x)
505+
x = l(x)
516506
xo = x
517-
518507
x_ = xo.cpu().detach().numpy()
519508
x_save.append(x_)
520509
y_save.append(y_)
521510

511+
x_save = np.concatenate(x_save)
512+
y_save = np.concatenate(y_save)
513+
514+
np.savez(test_savepath, x_save=x_save, y_save=y_save)
515+
516+
517+
518+
519+
class DensemnistDatasetTrain(Dataset):
520+
def __init__(self):
521+
"""
522+
"""
523+
npzfile = np.load(train_savepath)
524+
525+
self.x = npzfile['x_save']
526+
self.y = npzfile['y_save']
527+
def __len__(self):
528+
return len(self.x)
529+
530+
def __getitem__(self, idx):
531+
x = self.x[idx,...]
532+
y = self.y[idx]
533+
534+
return x,y
535+
class DensemnistDatasetTest(Dataset):
536+
def __init__(self):
537+
"""
538+
"""
539+
npzfile = np.load(test_savepath)
540+
541+
self.x = npzfile['x_save']
542+
self.y = npzfile['y_save']
543+
def __len__(self):
544+
return len(self.x)
545+
546+
def __getitem__(self, idx):
547+
x = self.x[idx,...]
548+
y = self.y[idx]
549+
550+
return x,y
551+
552+
553+
554+
555+
class ODEBlocktemp(nn.Module):
556+
def __init__(self, odefunc):
557+
super(ODEBlocktemp, self).__init__()
558+
self.odefunc = odefunc
559+
self.integration_time = torch.tensor([0, 5]).float()
560+
def forward(self, x):
561+
out = self.odefunc(0, x)
562+
return out
563+
@property
564+
def nfe(self):
565+
return self.odefunc.nfe
566+
@nfe.setter
567+
def nfe(self, value):
568+
self.odefunc.nfe = value
569+
570+
571+
def accuracy(model, dataset_loader):
572+
total_correct = 0
573+
for x, y in dataset_loader:
574+
x = x.to(device)
575+
y = one_hot(np.array(y.numpy()), 10)
576+
577+
578+
target_class = np.argmax(y, axis=1)
579+
predicted_class = np.argmax(model(x).cpu().detach().numpy(), axis=1)
580+
total_correct += np.sum(predicted_class == target_class)
581+
return total_correct / len(dataset_loader.dataset)
582+
583+
584+
585+
def save_training_feature(model, dataset_loader, fake_embeddings_loader=None ):
586+
x_save = []
587+
y_save = []
588+
modulelist = list(model)
522589
# Processing fake embeddings if provided
523590
if fake_embeddings_loader is not None:
524591
for x, y in fake_embeddings_loader:
@@ -534,12 +601,30 @@ def save_training_feature(model, dataset_loader, fake_embeddings_loader=None ):
534601
x_save.append(x_)
535602
y_save.append(y_)
536603

537-
# Concatenate all collected data before saving
604+
605+
x_save = []
606+
y_save = []
607+
for x, y in dataset_loader:
608+
x = x.to(device)
609+
y_ = y.numpy() # No need to use np.array here
610+
611+
# Forward pass through the model up to the desired layer
612+
for l in modulelist[0:2]:
613+
x = l(x)
614+
xo = x
615+
616+
x_ = xo.cpu().detach().numpy()
617+
618+
x_save.append(x_)
619+
620+
y_save.append(y_)
621+
622+
538623
x_save = np.concatenate(x_save)
624+
539625
y_save = np.concatenate(y_save)
540626

541-
542-
# Save the concatenated arrays to a file
627+
543628
np.savez(train_savepath, x_save=x_save, y_save=y_save)
544629

545630

@@ -568,35 +653,3 @@ def save_testing_feature(model, dataset_loader):
568653

569654

570655

571-
class DensemnistDatasetTrain(Dataset):
572-
def __init__(self):
573-
"""
574-
"""
575-
npzfile = np.load(train_savepath)
576-
577-
self.x = npzfile['x_save']
578-
self.y = npzfile['y_save']
579-
def __len__(self):
580-
return len(self.x)
581-
582-
def __getitem__(self, idx):
583-
x = self.x[idx,...]
584-
y = self.y[idx]
585-
586-
return x,y
587-
class DensemnistDatasetTest(Dataset):
588-
def __init__(self):
589-
"""
590-
"""
591-
npzfile = np.load(test_savepath)
592-
593-
self.x = npzfile['x_save']
594-
self.y = npzfile['y_save']
595-
def __len__(self):
596-
return len(self.x)
597-
598-
def __getitem__(self, idx):
599-
x = self.x[idx,...]
600-
y = self.y[idx]
601-
602-
return x,y

0 commit comments

Comments
 (0)