Skip to content

Commit a7a2f31

Browse files
authored
fix file path when render model (#1857)
1 parent f5aca91 commit a7a2f31

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

numpyro/infer/inspect.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -629,9 +629,14 @@ def render_model(
629629

630630
if filename is not None:
631631
filename = Path(filename)
632+
# remove leading period from suffix
633+
filename_without_suffix = filename.with_suffix("")
632634
graph.render(
633-
filename.stem, view=False, cleanup=True, format=filename.suffix[1:]
634-
) # remove leading period from suffix
635+
filename_without_suffix,
636+
view=False,
637+
cleanup=True,
638+
format=filename.suffix[1:],
639+
)
635640

636641
return graph
637642

test/test_model_rendering.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
# Copyright Contributors to the Pyro project.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import os
5+
46
import numpy as np
57
import pytest
68

79
import jax.numpy as jnp
810

911
import numpyro
1012
import numpyro.distributions as dist
11-
from numpyro.infer.inspect import generate_graph_specification, get_model_relations
13+
from numpyro.infer.inspect import (
14+
generate_graph_specification,
15+
get_model_relations,
16+
render_model,
17+
)
1218

1319

1420
def simple(data):
@@ -129,3 +135,12 @@ def test_model_transformation(test_model, model_kwargs, expected_graph_spec):
129135
graph_spec = generate_graph_specification(relations)
130136

131137
assert graph_spec == expected_graph_spec
138+
139+
140+
def test_render_model_filename():
141+
def model():
142+
numpyro.sample("x", dist.Normal(0, 1))
143+
144+
render_model(model, filename="graph.png")
145+
assert os.path.exists("graph.png")
146+
os.remove("graph.png")

0 commit comments

Comments
 (0)