Skip to content

Commit 3850cf6

Browse files
authored
update SOC code
1 parent bf6c3ea commit 3850cf6

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,12 +277,12 @@ def soc_adaptation_iter(
277277

278278
# NOTE: using the formulas in our paper to calculate the following losses has similar results
279279
# sub-objectives consistency between `pred_detail` and `pred_backup_detail` (on boundaries only)
280-
backup_detail_loss = boundaries * F.l1_loss(pred_detail, pred_backup_detail)
280+
backup_detail_loss = boundaries * F.l1_loss(pred_detail, pred_backup_detail, reduction='none')
281281
backup_detail_loss = torch.sum(backup_detail_loss, dim=(1,2,3)) / torch.sum(boundaries, dim=(1,2,3))
282282
backup_detail_loss = torch.mean(backup_detail_loss)
283283

284284
# sub-objectives consistency between pred_matte` and `pred_backup_matte` (on boundaries only)
285-
backup_matte_loss = boundaries * F.l1_loss(pred_matte, pred_backup_matte)
285+
backup_matte_loss = boundaries * F.l1_loss(pred_matte, pred_backup_matte, reduction='none')
286286
backup_matte_loss = torch.sum(backup_matte_loss, dim=(1,2,3)) / torch.sum(boundaries, dim=(1,2,3))
287287
backup_matte_loss = torch.mean(backup_matte_loss)
288288

0 commit comments

Comments
 (0)