Skip to content

Commit 0571c4e

Browse files
authored
Update main.py
1 parent a4339a9 commit 0571c4e

File tree

1 file changed

+1
-8
lines changed

1 file changed

+1
-8
lines changed

AROS/main.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
def main():
1313
parser = argparse.ArgumentParser(description="Hyperparameters for the script")
1414

15-
# Define the hyperparameters controlled via CLI 'Ding2020MMA'
1615
parser.add_argument('--fast', type=bool, default=True, help='Toggle between fast and full fake data generation modes')
1716
parser.add_argument('--epoch1', type=int, default=2, help='Number of epochs for stage 1')
1817
parser.add_argument('--epoch2', type=int, default=1, help='Number of epochs for stage 2')
@@ -96,19 +95,13 @@ def main():
9695
fake_data = np.vstack(fake_data)
9796
fake_data = torch.tensor(fake_data).float()
9897
fake_data = F.normalize(fake_data, p=2, dim=1)
99-
10098
fake_labels = torch.full((fake_data.shape[0],), 10)
10199
fake_loader = DataLoader(TensorDataset(fake_data, fake_labels), batch_size=128, shuffle=True)
102100

103101
if args.fast==True:
104-
105-
106-
noise_std = 0.1 # standard deviation of noise
107-
noisy_embeddings = torch.tensor(embeddings) + noise_std * torch.randn_like(torch.tensor(embeddings))
108-
102+
noisy_embeddings = torch.tensor(embeddings) + args.noise_std * torch.randn_like(torch.tensor(embeddings))
109103
# Normalize Noisy Embeddings
110104
noisy_embeddings = F.normalize(noisy_embeddings, p=2, dim=1)[:len(trainloader.dataset)//num_classes]
111-
112105
# Convert to DataLoader if needed
113106
fake_labels = torch.full((noisy_embeddings.shape[0],), num_classes)[:len(trainloader.dataset)//num_classes]
114107
fake_loader = DataLoader(TensorDataset(noisy_embeddings, fake_labels), batch_size=128, shuffle=True)

0 commit comments

Comments
 (0)