Skip to content

Commit e437432

Browse files
committed
Upload PSNR check script
1 parent a9460ed commit e437432

File tree

2 files changed

+166
-0
lines changed

2 files changed

+166
-0
lines changed

tests/README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,26 @@ Sound System 11 | IAMF | L3, R3, C, LFE, Ltf3, Rtf3,
9292
Sound System 12 | IAMF | C
9393
Sound System 13 | IAMF | FL, FR, FC, LFE, BL, BR, FLc, FRc, SiL, SiR, TpFL, TpFR, TpBL, TpBR, TpSiL, TpSiR
9494
Binaural Layout | IAMF | L2, R2
95+
96+
# Verification
97+
98+
For test cases using Opus or AAC codecs, the average PSNR value must exceed 30, and for the other codecs, an average PSNR value exceeding 80 is considered PASS.
99+
You can use `psnr_calc.py` file to calculate PSNR between reference and generated output.
100+
101+
- How to use `psnr_calc.py` script:
102+
```
103+
python psnr_calc.py
104+
--dir <directory path containing the target file and reference file>
105+
--target <target wav file name>
106+
--ref <reference wav file name>
107+
```
108+
109+
- Calculate PSNR values of multiple wav files
110+
111+
Multiple files can be entered as `::`
112+
113+
```
114+
Example:
115+
116+
python psnr_calc.py --dir . --target target1.wav::target2.wav --ref ref1.wav::ref2.wav
117+
```

tests/psnr_calc.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
import argparse
2+
import wave
3+
import os
4+
import scipy.io.wavfile as wavfile
5+
import numpy as np
6+
import math
7+
8+
parser = argparse.ArgumentParser(description="PSNR verification script")
9+
parser.add_argument(
10+
"--dir",
11+
type=str,
12+
required=True,
13+
help="decoder verification wav output directory",
14+
)
15+
parser.add_argument(
16+
"--target",
17+
type=str,
18+
required=True,
19+
help="decoder verification wav output file. Multiple files can be entered as ::. (ex - test1.wav::test2.wav)",
20+
)
21+
parser.add_argument(
22+
"--ref",
23+
type=str,
24+
required=True,
25+
help="decoder verification PSNR evaluation reference file. Multiple files can be entered as ::. (ex - test1.wav::test2.wav)",
26+
)
27+
args = parser.parse_args()
28+
29+
30+
def get_sampwdith(path):
31+
with wave.open(path, "rb") as wf:
32+
sampwidth_bytes = wf.getsampwidth()
33+
return sampwidth_bytes
34+
35+
36+
def calc_psnr(ref_signal, signal, sampwidth_bytes):
37+
assert (
38+
sampwidth_bytes > 1
39+
), "Supports sample format: [pcm_s16le, pcm_s24le, pcm_s32le]"
40+
max_value = pow(2, sampwidth_bytes * 8) - 1
41+
42+
# To prevent overflow
43+
ref_signal = ref_signal.astype("int64")
44+
signal = signal.astype("int64")
45+
46+
mse = np.mean((ref_signal - signal) ** 2, axis=0, dtype="float64")
47+
48+
psnr_list = list()
49+
50+
# To support mono signal
51+
num_channels = 1 if ref_signal.shape[1:] == () else ref_signal.shape[1]
52+
for i in range(num_channels):
53+
mse_value = mse[i] if num_channels > 1 else mse
54+
if mse_value == 0:
55+
print(f"ch#{i} PSNR: inf")
56+
else:
57+
psnr_value = 10 * math.log10(max_value**2 / mse_value)
58+
psnr_list.append(psnr_value)
59+
print(f"ch#{i} PSNR: {psnr_value} dB")
60+
61+
return -1 if len(psnr_list) == 0 else sum(psnr_list) / len(psnr_list)
62+
63+
64+
target_file_list = args.target.split("::")
65+
ref_file_list = args.ref.split("::")
66+
67+
tc_number_list = []
68+
psnr_list = []
69+
for file_idx in range(len(target_file_list)):
70+
target_file = target_file_list[file_idx]
71+
ref_file = ref_file_list[file_idx]
72+
print(
73+
"[%d] PSNR evaluation: compare %s with %s"
74+
% (file_idx, target_file, ref_file)
75+
)
76+
tc_number_list.append(file_idx)
77+
try:
78+
ref_filepath = os.path.join(
79+
os.path.dirname(os.path.abspath(__file__)), args.dir, ref_file
80+
)
81+
target_filepath = os.path.join(
82+
os.path.dirname(os.path.abspath(__file__)), args.dir, target_file
83+
)
84+
85+
ref_samplerate, ref_data = wavfile.read(ref_filepath)
86+
target_samplerate, target_data = wavfile.read(target_filepath)
87+
88+
ref_sampwdith_bytes = get_sampwdith(ref_filepath)
89+
target_sampwidth_bytes = get_sampwdith(target_filepath)
90+
91+
# Check sampling rate
92+
if not (ref_samplerate == target_samplerate):
93+
print(ref_file, " / ", target_file)
94+
raise Exception(
95+
"Sampling rate of reference file and comparison file are different."
96+
)
97+
98+
# Check number of channels
99+
if not (ref_data.shape[1:] == target_data.shape[1:]):
100+
raise Exception(
101+
"Number of channels of reference file and comparison file are different."
102+
)
103+
104+
# Check number of samples
105+
if not (ref_data.shape[0] == target_data.shape[0]):
106+
print(ref_file, " / ", target_file)
107+
raise Exception(
108+
"Number of samples of reference file and comparison file are different."
109+
)
110+
111+
# Check bit depth
112+
if not (ref_sampwdith_bytes == target_sampwidth_bytes):
113+
print(ref_file, " / ", target_file)
114+
raise Exception(
115+
"Bit depth of reference file and comparison file are different."
116+
)
117+
118+
average_psnr = calc_psnr(ref_data, target_data, ref_sampwdith_bytes)
119+
if average_psnr != -1:
120+
print("average PSNR: %.15f" % (average_psnr))
121+
psnr_list.append(average_psnr)
122+
else:
123+
print("average PSNR: %.15f" % (100))
124+
psnr_list.append(100)
125+
except Exception as err:
126+
print(str(err))
127+
psnr_list.append(0)
128+
print("")
129+
130+
# print result
131+
print(
132+
"\n\n\n[Result] - (If the OPUS or AAC codec has a over avgPSNR 30, it is considered PASS. Other codecs must be over avgPSNR 80.)"
133+
)
134+
for i in range(len(tc_number_list)):
135+
print(
136+
"TC#%d : %.3f (compare %s with %s)"
137+
% (
138+
tc_number_list[i],
139+
round(psnr_list[i], 3),
140+
target_file_list[i],
141+
ref_file_list[i],
142+
)
143+
)

0 commit comments

Comments
 (0)