
Commit b9a92ce

fixed inplace init
1 parent 2e04b5f commit b9a92ce

File tree: 1 file changed (+4, -8)


tests/normalized_optimizer_convergence_test.py

Lines changed: 4 additions & 8 deletions
@@ -36,13 +36,9 @@ def __init__(self, input_size=784, hidden_size=128, num_classes=10, dim=0):
     def _initialize_oblique_weights(self, dim):
         """Initialize weights to be normalized for oblique optimization."""
         with torch.no_grad():
-            # Normalize of oblique layers
-            self.fc1.weight.data = self.fc1.weight.data / self.fc1.weight.data.norm(dim=dim, keepdim=True).clamp(
-                min=1e-8
-            )
-            self.fc2.weight.data = self.fc2.weight.data / self.fc2.weight.data.norm(dim=dim, keepdim=True).clamp(
-                min=1e-8
-            )
+            # Normalize in-place for oblique layers
+            self.fc1.weight.data /= self.fc1.weight.data.norm(dim=dim, keepdim=True).clamp(min=1e-8)
+            self.fc2.weight.data /= self.fc2.weight.data.norm(dim=dim, keepdim=True).clamp(min=1e-8)

     def forward(self, x):
         x = x.view(x.size(0), -1)  # Flatten

@@ -196,7 +192,7 @@ def test_optimizer_modes_convergence(self, optimizer_class: torch.optim.Optimize
         # Re-initialize for row normalization
         with torch.no_grad():
             for param in model.get_oblique_parameters():
-                param.data = param.data / param.data.norm(dim=optimizer_kwargs["dim"], keepdim=True).clamp(min=1e-8)
+                param.data /= param.data.norm(dim=optimizer_kwargs["dim"], keepdim=True).clamp(min=1e-8)

         # Train model
         initial_loss, final_loss, final_accuracy = self._train_model(
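Why the in-place form: dividing .data in place rewrites the existing weight storage, whereas the old pattern rebinds .data to a freshly allocated tensor holding the normalized values. Below is a minimal, self-contained sketch, not taken from the repository; the layer shape and dim=0 are arbitrary choices for illustration. It checks both properties after the same kind of in-place normalization.

import torch

# Sketch only: a standalone parameter, not one of the test model's layers.
torch.manual_seed(0)
weight = torch.nn.Parameter(torch.randn(4, 3))

storage_before = weight.data.data_ptr()
with torch.no_grad():
    # Same pattern as the commit: in-place divide by the per-column norm,
    # clamped to avoid division by zero.
    weight.data /= weight.data.norm(dim=0, keepdim=True).clamp(min=1e-8)
storage_after = weight.data.data_ptr()

print(torch.allclose(weight.data.norm(dim=0), torch.ones(3)))  # True: columns are unit-norm
print(storage_before == storage_after)                          # True: same underlying storage

Both forms leave the weights unit-norm along the chosen dim; the in-place version simply avoids allocating a second tensor for the result.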

0 commit comments
