Skip to content

Commit dd5ee79

Browse files
committed
Make export_to_nir import correct and move it into .to_nir() method to make sure that people can also run benchmarks without having nirtorch installed
1 parent bb23c6b commit dd5ee79

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

neurobench/benchmarks/benchmark.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@
2121
import snntorch
2222
from torch import Tensor
2323

24-
if snntorch.__version__ >= "0.9.0":
25-
from snntorch import export_to_nir
26-
2724
import torch
2825
import nir
2926

@@ -222,6 +219,10 @@ def to_nir(self, dummy_input: Tensor, filename: str, **kwargs) -> None:
222219
"""
223220
if snntorch.__version__ < "0.9.0":
224221
raise ValueError("Exporting to NIR requires snntorch version >= 0.9.0")
222+
223+
if snntorch.__version__ >= "0.9.0":
224+
from snntorch.export_nir import export_to_nir
225+
225226
nir_graph = export_to_nir(self.model.__net__(), dummy_input, **kwargs)
226227
nir.write(filename, nir_graph)
227228
print(f"Model exported to {filename}")

0 commit comments

Comments
 (0)