diff --git a/max_ssim.py b/max_ssim.py index 608fdc8..c3aeef5 100644 --- a/max_ssim.py +++ b/max_ssim.py @@ -20,7 +20,7 @@ # Functional: pytorch_ssim.ssim(img1, img2, window_size = 11, size_average = True) -ssim_value = pytorch_ssim.ssim(img1, img2).data[0] +ssim_value = pytorch_ssim.ssim(img1, img2).data.item() print("Initial ssim:", ssim_value) # Module: pytorch_ssim.SSIM(window_size = 11, size_average = True) @@ -31,7 +31,7 @@ while ssim_value < 0.95: optimizer.zero_grad() ssim_out = -ssim_loss(img1, img2) - ssim_value = - ssim_out.data[0] + ssim_value = - ssim_out.data.item() print(ssim_value) ssim_out.backward() optimizer.step()