Skip to content

Commit b2cb975

Browse files
authored
AC: fix psnr in case color reference and gray prediction (#3194)
1 parent a8721ad commit b2cb975

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

tools/accuracy_checker/openvino/tools/accuracy_checker/metrics/image_quality_assessment.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ def configure(self):
111111
self.color_scale = 255 if not self.normalized_images else 1
112112

113113
def _psnr_differ(self, annotation_image, prediction_image):
114-
prediction = np.squeeze(np.asarray(prediction_image)).astype(np.float)
115-
ground_truth = np.squeeze(np.asarray(annotation_image)).astype(np.float)
114+
prediction = np.squeeze(np.asarray(prediction_image)).astype(float)
115+
ground_truth = np.squeeze(np.asarray(annotation_image)).astype(float)
116116

117117
height, width = prediction.shape[:2]
118118
prediction = prediction[
@@ -123,6 +123,11 @@ def _psnr_differ(self, annotation_image, prediction_image):
123123
self.scale_border:height - self.scale_border,
124124
self.scale_border:width - self.scale_border
125125
]
126+
if np.ndim(ground_truth) == 3 and prediction.ndim == 2:
127+
ground_truth = cv2.cvtColor(
128+
ground_truth.astype(np.float32),
129+
cv2.COLOR_BGR2GRAY if self.color_order == 'BGR' else cv2.COLOR_RGB2GRAY
130+
).astype(float)
126131
image_difference = (prediction - ground_truth) / self.color_scale
127132
if len(ground_truth.shape) == 3 and ground_truth.shape[2] == 3:
128133
r_channel_diff = image_difference[:, :, self.channel_order[0]]

0 commit comments

Comments
 (0)