1616import json
1717import csv
1818import os
19- from typing import Literal , List , Type , Optional , Dict , Any
19+ from typing import Literal , List , Type , Optional , Dict , Any , Callable , Tuple
2020import pathlib
2121import snntorch
2222from torch import Tensor
23-
24- if snntorch .__version__ >= "0.9.0" :
25- from snntorch import export_to_nir
26-
2723import torch
28- import nir
2924
3025
3126class Benchmark :
@@ -39,18 +34,25 @@ def __init__(
3934 self ,
4035 model : NeuroBenchModel ,
4136 dataloader : Optional [DataLoader ],
42- preprocessors : Optional [List [NeuroBenchPreProcessor ]],
43- postprocessors : Optional [List [NeuroBenchPostProcessor ]],
37+ preprocessors : Optional [
38+ List [
39+ NeuroBenchPreProcessor
40+ | Callable [[Tuple [Tensor , Tensor ]], Tuple [Tensor , Tensor ]]
41+ ]
42+ ],
43+ postprocessors : Optional [
44+ List [NeuroBenchPostProcessor | Callable [[Tensor ], Tensor ]]
45+ ],
4446 metric_list : List [List [Type [StaticMetric | WorkloadMetric ]]],
4547 ):
4648 """
4749 Args:
4850 model: A NeuroBenchModel.
4951 dataloader: A PyTorch DataLoader.
50- preprocessors: A list of NeuroBenchPreProcessors.
51- postprocessors: A list of NeuroBenchPostProcessors.
52- metric_list: A list of lists of strings of metrics to run.
53- First item is static metrics , second item is data metrics .
52+ preprocessors: A list of NeuroBenchPreProcessors or callable functions (e.g. lambda) with matching interfaces .
53+ postprocessors: A list of NeuroBenchPostProcessors or callable functions (e.g. lambda) with matching interfaces .
54+ metric_list: A list of lists of StaticMetric and WorkloadMetric classes of metrics to run.
55+ First item is StaticMetrics , second item is WorkloadMetrics .
5456 """
5557
5658 self .model = model
@@ -66,8 +68,13 @@ def run(
6668 quiet : bool = False ,
6769 verbose : bool = False ,
6870 dataloader : Optional [DataLoader ] = None ,
69- preprocessors : Optional [NeuroBenchPreProcessor ] = None ,
70- postprocessors : Optional [NeuroBenchPostProcessor ] = None ,
71+ preprocessors : Optional [
72+ NeuroBenchPreProcessor
73+ | Callable [[Tuple [Tensor , Tensor ]], Tuple [Tensor , Tensor ]]
74+ ] = None ,
75+ postprocessors : Optional [
76+ NeuroBenchPostProcessor | Callable [[Tensor ], Tensor ]
77+ ] = None ,
7178 device : Optional [str ] = None ,
7279 ) -> Dict [str , Any ]:
7380 """
@@ -117,10 +124,10 @@ def run(
117124 batch_size = data [0 ].size (0 )
118125
119126 # Preprocessing data
120- data = self .processor_manager .preprocess (data )
127+ input , target = self .processor_manager .preprocess (data )
121128
122129 # Run model on test data
123- preds = self .model (data [ 0 ] )
130+ preds = self .model (input )
124131
125132 # Postprocessing data
126133 preds = self .processor_manager .postprocess (preds )
@@ -220,8 +227,18 @@ def to_nir(self, dummy_input: Tensor, filename: str, **kwargs) -> None:
220227 If the installed version of `snntorch` is less than `0.9.0`.
221228
222229 """
230+ try :
231+ import nir
232+ except ImportError :
233+ raise ImportError (
234+ "Exporting to NIR requires the `nir` package. Install it using `pip install nir`."
235+ )
223236 if snntorch .__version__ < "0.9.0" :
224237 raise ValueError ("Exporting to NIR requires snntorch version >= 0.9.0" )
238+
239+ if snntorch .__version__ >= "0.9.0" :
240+ from snntorch .export_nir import export_to_nir
241+
225242 nir_graph = export_to_nir (self .model .__net__ (), dummy_input , ** kwargs )
226243 nir .write (filename , nir_graph )
227244 print (f"Model exported to { filename } " )
0 commit comments