Skip to content

Commit fa11c9a

Browse files
moienrvfdev-5
andauthored
fixed output_list | now it has only 5 elements instead of 5*batchsize (#2914)
* fixed the output_list | now it has only 5 elements instead of 5*batch_size * Update ignite/metrics/ssim.py batch_size variable instead of using y_pred.size(0) directly Co-authored-by: vfdev <[email protected]> --------- Co-authored-by: vfdev <[email protected]>
1 parent fe06ba0 commit fa11c9a

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

ignite/metrics/ssim.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,10 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
159159
y_pred = F.pad(y_pred, [self.pad_w, self.pad_w, self.pad_h, self.pad_h], mode="reflect")
160160
y = F.pad(y, [self.pad_w, self.pad_w, self.pad_h, self.pad_h], mode="reflect")
161161

162-
input_list = torch.cat([y_pred, y, y_pred * y_pred, y * y, y_pred * y])
163-
outputs = F.conv2d(input_list, self._kernel, groups=channel)
164-
165-
output_list = [outputs[x * y_pred.size(0) : (x + 1) * y_pred.size(0)] for x in range(len(outputs))]
162+
input_list = [y_pred, y, y_pred * y_pred, y * y, y_pred * y]
163+
outputs = F.conv2d(torch.cat(input_list), self._kernel, groups=channel)
164+
batch_size = y_pred.size(0)
165+
output_list = [outputs[x * batch_size : (x + 1) * batch_size] for x in range(len(input_list))]
166166

167167
mu_pred_sq = output_list[0].pow(2)
168168
mu_target_sq = output_list[1].pow(2)

0 commit comments

Comments
 (0)