Skip to content

Commit 7d707b8

Browse files
committed
update: test_pc_grad_optimizers
1 parent 88b4370 commit 7d707b8

File tree

1 file changed

+26
-4
lines changed

1 file changed

+26
-4
lines changed

tests/test_optimizers.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
4040
return x
4141

4242

43+
class MultiHeadLogisticRegression(nn.Module):
    """Tiny two-head logistic-regression model for multi-task optimizer tests.

    A shared 2->2 linear layer feeds two independent 2->1 heads, so two
    losses (one per head) can produce the conflicting gradients that
    PCGrad-style optimizers are designed to reconcile.
    """

    def __init__(self):
        super().__init__()
        # Shared trunk: both heads consume the same 2-d feature vector.
        self.fc1 = nn.Linear(2, 2)
        self.head1 = nn.Linear(2, 1)
        self.head2 = nn.Linear(2, 1)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the raw (pre-sigmoid) logits of both heads for input `x`.

        :param x: tensor of shape (batch, 2).
        :return: pair of tensors, each of shape (batch, 1).
        """
        x = self.fc1(x)
        x = F.relu(x)
        return self.head1(x), self.head2(x)
54+
55+
4356
def make_dataset(num_samples: int = 100, dims: int = 2, seed: int = 42) -> Tuple[torch.Tensor, torch.Tensor]:
4457
rng = np.random.RandomState(seed)
4558

@@ -181,6 +194,9 @@ def test_sam_optimizers(optimizer_config):
181194
loss_fn(y_data, model(x_data)).backward()
182195
optimizer.second_step(zero_grad=True)
183196

197+
if init_loss == np.inf:
198+
init_loss = loss
199+
184200
assert init_loss > 2.0 * loss
185201

186202

@@ -190,9 +206,9 @@ def test_pc_grad_optimizers(optimizer_config):
190206

191207
x_data, y_data = make_dataset()
192208

193-
model: nn.Module = LogisticRegression()
209+
model: nn.Module = MultiHeadLogisticRegression()
194210
loss_fn_1: nn.Module = nn.BCEWithLogitsLoss()
195-
loss_fn_2: nn.Module = nn.BCEWithLogitsLoss()
211+
loss_fn_2: nn.Module = nn.L1Loss()
196212

197213
optimizer_class, config, iterations = optimizer_config
198214
optimizer = PCGrad(optimizer_class(model.parameters(), **config))
@@ -201,8 +217,14 @@ def test_pc_grad_optimizers(optimizer_config):
201217
init_loss: float = np.inf
202218
for _ in range(iterations):
203219
optimizer.zero_grad()
204-
y_pred = model(x_data)
205-
loss1, loss2 = loss_fn_1(y_pred, y_data), loss_fn_2(y_pred, y_data)
220+
y_pred_1, y_pred_2 = model(x_data)
221+
loss1, loss2 = loss_fn_1(y_pred_1, y_data), loss_fn_2(y_pred_2, y_data)
222+
223+
loss = (loss1 + loss2) / 2.0
224+
if init_loss == np.inf:
225+
init_loss = loss
226+
206227
optimizer.pc_backward([loss1, loss2])
228+
optimizer.step()
207229

208230
assert init_loss > 2.0 * loss

0 commit comments

Comments
 (0)