Skip to content

Commit 88930ca

Browse files
authored
Fix loading checkpoint after 1st round of training for DFine-X model (#4738)
1 parent 58f990f commit 88930ca

File tree

1 file changed

+14
-0
lines changed
  • lib/src/otx/backend/native/models/detection

1 file changed

+14
-0
lines changed

lib/src/otx/backend/native/models/detection/d_fine.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,17 @@ def _optimization_config(self) -> dict[str, Any]:
158158
},
159159
},
160160
}
161+
162+
def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None:
163+
"""Load state dictionary from checkpoint state dictionary.
164+
165+
If a RuntimeError occurs due to size mismatch, non-trainable anchors and valid_mask
166+
are removed from the checkpoint before loading.
167+
"""
168+
try:
169+
return super().load_state_dict(ckpt, *args, **kwargs)
170+
except RuntimeError:
171+
# Remove non-trainable anchors and valid_mask from the checkpoint to avoid size mismatch
172+
ckpt.pop("model.decoder.anchors")
173+
ckpt.pop("model.decoder.valid_mask")
174+
return super().load_state_dict(ckpt, *args, strict=False, **kwargs)

0 commit comments

Comments
 (0)