Skip to content

Commit b3eb2b4

Browse files
committed
Call loss consistently
When unrolling models, we were not unpacking the tuple returned by model.loss like we were in other methods. This keeps things consisent.
1 parent 8fec536 commit b3eb2b4

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

common/darts/architecture.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,10 @@ def comp_unrolled_model(self, data, target, eta, optimizer):
4444
model_unrolled
4545
"""
4646
# forward to get loss
47-
loss = self.model.loss(data, target)
47+
_, loss = self.model.loss(data, target, reduce='mean')
4848
# flatten current weights
4949
theta = F.flatten(self.model.parameters()).detach()
5050
# theta: torch.Size([1930618])
51-
# print('theta:', theta.shape)
5251
try:
5352
# fetch momentum data from theta optimizer
5453
moment = F.flatten(optimizer.state[v]['momentum_buffer'] for v in self.model.parameters())
@@ -57,6 +56,7 @@ def comp_unrolled_model(self, data, target, eta, optimizer):
5756
moment = torch.zeros_like(theta)
5857

5958
# flatten all gradients
59+
gradient= autograd.grad(loss, self.model.parameters(), allow_unused=True)
6060
dtheta = F.flatten(autograd.grad(loss, self.model.parameters())).data
6161
# indeed, here we implement a simple SGD with momentum and weight decay
6262
# theta = theta - eta * (moment + weight decay + dtheta)

examples/darts/uno/default_model.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[Global_Params]
22
model_name = 'darts_uno'
3-
unrolled = False
3+
unrolled = True
44
data_url = 'http://ftp.mcs.anl.gov/pub/candle/public/benchmarks/Pilot1/uno/'
55
savepath = '.'
66
log_interval = 10

0 commit comments

Comments
 (0)