Skip to content

Commit 220e88d

Browse files
authored
Add option to pass kwargs for savefig (#1363)
1 parent cdbda05 commit 220e88d

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

src/torchio/visualization.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import warnings
44
from pathlib import Path
55
from typing import TYPE_CHECKING
6+
from typing import Any
67

78
import numpy as np
89
import torch
@@ -24,6 +25,7 @@
2425
if TYPE_CHECKING:
2526
from matplotlib.colors import BoundaryNorm
2627
from matplotlib.colors import ListedColormap
28+
from matplotlib.figure import Figure
2729

2830

2931
def import_mpl_plt():
@@ -80,10 +82,11 @@ def plot_volume(
8082
reorient=True,
8183
indices=None,
8284
rgb=True,
85+
savefig_kwargs: dict[str, Any] | None = None,
8386
**imshow_kwargs,
84-
):
87+
) -> Figure | None:
8588
_, plt = import_mpl_plt()
86-
fig = None
89+
fig: Figure | None = None
8790
if axes is None:
8891
fig, axes = plt.subplots(1, 3, figsize=figsize)
8992

@@ -182,7 +185,9 @@ def plot_volume(
182185
plt.suptitle(title)
183186

184187
if output_path is not None and fig is not None:
185-
fig.savefig(output_path)
188+
if savefig_kwargs is None:
189+
savefig_kwargs = {}
190+
fig.savefig(output_path, **savefig_kwargs)
186191
if show:
187192
plt.show()
188193
return fig

0 commit comments

Comments
 (0)