Skip to content

Commit e4ae491

Browse files
committed
[fix] bug - two different devices for nn part 1 and nn part 2 of yolo network
1 parent 6814006 commit e4ae491

File tree

1 file changed

+6
-0
lines changed
  • compressai_vision/model_wrappers

1 file changed

+6
-0
lines changed

compressai_vision/model_wrappers/jde.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,12 @@ def features_to_output(self, x: Dict, device: str):
126126
self.darknet = self.darknet.to(device).eval()
127127
self.darknet.device = device # Please refer to Darknet
128128

129+
for module_def, module in zip(
130+
self.darknet.module_defs[::-1], self.darknet.module_list[::-1]
131+
):
132+
if module_def["type"] == "yolo":
133+
module[0].device = device
134+
129135
return self._feature_pyramid_to_output(
130136
x["data"], x["org_input_size"], x["input_size"]
131137
)

0 commit comments

Comments
 (0)