Skip to content

Commit 3fcd3c5

Browse files
committed
fix: pytest for evaluating model performance dont read stderr anymore
1 parent bfeca31 commit 3fcd3c5

File tree

3 files changed

+26
-10
lines changed

3 files changed

+26
-10
lines changed

compressai/utils/video/eval_model/__main__.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,13 @@ def create_parser() -> argparse.ArgumentParser:
467467
default="mse",
468468
help="metric trained against (default: %(default)s)",
469469
)
470+
parent_parser.add_argument(
471+
"-d",
472+
"--output_directory",
473+
type=str,
474+
default="",
475+
help="path of output directory. Optional, required for output json file, results per video.",
476+
)
470477
parent_parser.add_argument(
471478
"-o",
472479
"--output-file",
@@ -525,8 +532,8 @@ def main(args: Any = None) -> None:
525532
raise SystemExit(1)
526533

527534
# create output directory
528-
outputdir = args.output
529-
Path(outputdir).mkdir(parents=True, exist_ok=True)
535+
if args.output_directory:
536+
Path(args.output_directory).mkdir(parents=True, exist_ok=True)
530537

531538
if args.source == "pretrained":
532539
args.qualities = [int(q) for q in args.qualities.split(",") if q]
@@ -561,7 +568,7 @@ def main(args: Any = None) -> None:
561568
filepaths,
562569
args.dataset,
563570
model,
564-
outputdir,
571+
args.output_directory,
565572
trained_net=trained_net,
566573
description=description,
567574
**args_dict,
@@ -581,7 +588,7 @@ def main(args: Any = None) -> None:
581588
else:
582589
output_file = args.output_file
583590

584-
with (Path(f"{outputdir}/{output_file}").with_suffix(".json")).open("wb") as f:
591+
with (Path(f"{args.output_directory}/{output_file}").with_suffix(".json")).open("wb") as f:
585592
f.write(json.dumps(output, indent=2).encode())
586593
print(json.dumps(output, indent=2))
587594

tests/test_eval_model.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def test_eval_model():
7878
@pytest.mark.parametrize("metric", ("mse", "ms-ssim"))
7979
@pytest.mark.parametrize("entropy_estimation", (False, True))
8080
def test_eval_model_pretrained(
81-
capsys, model, quality, metric, entropy_estimation, tmpdir
81+
model, quality, metric, entropy_estimation, tmpdir
8282
):
8383
here = os.path.dirname(__file__)
8484
dirpath = os.path.join(here, "assets/dataset/image")
@@ -92,13 +92,17 @@ def test_eval_model_pretrained(
9292
metric,
9393
"-q",
9494
quality,
95+
str(tmpdir),
96+
"-o",
97+
f"{model}-{metric}-{quality}",
9598
]
9699
if entropy_estimation:
97100
cmd += ["--entropy-estimation"]
98101
eval_model.main(cmd)
99102

100-
output = capsys.readouterr().out
101-
output = json.loads(output)
103+
with open(f"{tmpdir}/{model}-{metric}-{quality}.json") as f:
104+
output = json.load(f)
105+
102106
expected = os.path.join(
103107
here,
104108
"expected",

tests/test_eval_model_video.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def test_eval_model_video():
8080
@pytest.mark.parametrize("metric", ("mse",))
8181
@pytest.mark.parametrize("entropy_estimation", (True, False))
8282
def test_eval_model_pretrained(
83-
capsys, model, quality, metric, entropy_estimation, tmpdir
83+
model, quality, metric, entropy_estimation, tmpdir
8484
):
8585
here = os.path.dirname(__file__)
8686
dirpath = os.path.join(here, "assets/dataset/video")
@@ -95,13 +95,18 @@ def test_eval_model_pretrained(
9595
metric,
9696
"-q",
9797
quality,
98+
"-d",
99+
str(tmpdir),
100+
"-o",
101+
f"{model}-{metric}-{quality}",
98102
]
99103
if entropy_estimation:
100104
cmd += ["--entropy-estimation"]
101105
eval_model.main(cmd)
102106

103-
output = capsys.readouterr().out
104-
output = json.loads(output)
107+
with open(f"{tmpdir}/{model}-{metric}-{quality}.json") as f:
108+
output = json.load(f)
109+
105110
expected = os.path.join(
106111
here,
107112
"expected",

0 commit comments

Comments
 (0)