Skip to content

Commit 49cc6e5

Browse files
committed
Remove unpacking
This is required when calling grad.
1 parent 533e281 commit 49cc6e5

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

common/darts/architecture.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@ def comp_unrolled_model(self, data, target, eta, optimizer):
3333
data : torch.tensor
3434
3535
target : torch.tensor
36-
3736
eta : float
38-
3937
optimizer : torch.optim.optimizer
4038
optimizer of theta, not optimizer of alpha
4139
@@ -44,10 +42,11 @@ def comp_unrolled_model(self, data, target, eta, optimizer):
4442
model_unrolled
4543
"""
4644
# forward to get loss
47-
_, loss = self.model.loss(data, target, reduce='mean')
45+
loss = self.model.loss(data, target)
4846
# flatten current weights
4947
theta = F.flatten(self.model.parameters()).detach()
5048
# theta: torch.Size([1930618])
49+
# print('theta:', theta.shape)
5150
try:
5251
# fetch momentum data from theta optimizer
5352
moment = F.flatten(optimizer.state[v]['momentum_buffer'] for v in self.model.parameters())

0 commit comments

Comments
 (0)