diff --git a/neurobench/benchmarks/benchmark.py b/neurobench/benchmarks/benchmark.py index 1fc66db2..4a79fcdb 100644 --- a/neurobench/benchmarks/benchmark.py +++ b/neurobench/benchmarks/benchmark.py @@ -21,7 +21,6 @@ import snntorch from torch import Tensor import torch -import nir class Benchmark: @@ -228,6 +227,12 @@ def to_nir(self, dummy_input: Tensor, filename: str, **kwargs) -> None: If the installed version of `snntorch` is less than `0.9.0`. """ + try: + import nir + except ImportError: + raise ImportError( + "Exporting to NIR requires the `nir` package. Install it using `pip install nir`." + ) if snntorch.__version__ < "0.9.0": raise ValueError("Exporting to NIR requires snntorch version >= 0.9.0")