@@ -40,6 +40,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
+class MultiHeadLogisticRegression(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = nn.Linear(2, 2)
+        self.head1 = nn.Linear(2, 1)
+        self.head2 = nn.Linear(2, 1)
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        x = self.fc1(x)
+        x = F.relu(x)
+        return self.head1(x), self.head2(x)
+
+
 def make_dataset(num_samples: int = 100, dims: int = 2, seed: int = 42) -> Tuple[torch.Tensor, torch.Tensor]:
     rng = np.random.RandomState(seed)
 
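
Note: the new two-head model shares fc1 between both heads, so the two task
gradients can conflict at the shared parameters, which is exactly the case
PCGrad is meant to handle. A minimal shape check, as a sketch (assuming the
test module's existing imports: torch, nn, F, and Tuple):

    model = MultiHeadLogisticRegression()
    out1, out2 = model(torch.randn(8, 2))  # batch of 8 two-dim inputs
    assert out1.shape == (8, 1)            # one logit per head
    assert out2.shape == (8, 1)
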
@@ -181,6 +194,9 @@ def test_sam_optimizers(optimizer_config):
         loss_fn(y_data, model(x_data)).backward()
         optimizer.second_step(zero_grad=True)
 
+        if init_loss == np.inf:
+            init_loss = loss
+
     assert init_loss > 2.0 * loss
 
 
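Note: the added guard snapshots the loss from the first iteration, so the
assertion below checks that the final loss fell below half the initial one.
For context, the surrounding SAM loop presumably reads as follows; this is a
sketch reconstructed from the visible lines, with first_step assumed from the
usual two-step SAM interface:

    for _ in range(iterations):
        loss = loss_fn(y_data, model(x_data))
        loss.backward()
        optimizer.first_step(zero_grad=True)    # step to the perturbed weights
        loss_fn(y_data, model(x_data)).backward()
        optimizer.second_step(zero_grad=True)   # update the original weights
        if init_loss == np.inf:
            init_loss = loss
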
@@ -190,9 +206,9 @@ def test_pc_grad_optimizers(optimizer_config):
 
     x_data, y_data = make_dataset()
 
-    model: nn.Module = LogisticRegression()
+    model: nn.Module = MultiHeadLogisticRegression()
     loss_fn_1: nn.Module = nn.BCEWithLogitsLoss()
-    loss_fn_2: nn.Module = nn.BCEWithLogitsLoss()
+    loss_fn_2: nn.Module = nn.L1Loss()
 
     optimizer_class, config, iterations = optimizer_config
     optimizer = PCGrad(optimizer_class(model.parameters(), **config))
@@ -201,8 +217,14 @@ def test_pc_grad_optimizers(optimizer_config):
     init_loss: float = np.inf
     for _ in range(iterations):
         optimizer.zero_grad()
-        y_pred = model(x_data)
-        loss1, loss2 = loss_fn_1(y_pred, y_data), loss_fn_2(y_pred, y_data)
+        y_pred_1, y_pred_2 = model(x_data)
+        loss1, loss2 = loss_fn_1(y_pred_1, y_data), loss_fn_2(y_pred_2, y_data)
+
+        loss = (loss1 + loss2) / 2.0
+        if init_loss == np.inf:
+            init_loss = loss
+
         optimizer.pc_backward([loss1, loss2])
+        optimizer.step()
 
     assert init_loss > 2.0 * loss
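
Note: pc_backward implements the gradient surgery of PCGrad (Yu et al., 2020):
when two task gradients conflict (negative dot product), each is projected
onto the normal plane of the other before the gradients are combined, and
optimizer.step() then applies the result through the wrapped optimizer. The
core projection rule, as a standalone sketch on two flattened gradient
tensors:

    import torch

    def project_conflict(g_i: torch.Tensor, g_j: torch.Tensor) -> torch.Tensor:
        # Remove from g_i its component along g_j when the two conflict.
        dot = torch.dot(g_i, g_j)
        if dot < 0:
            g_i = g_i - (dot / g_j.norm() ** 2) * g_j
        return g_i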