Skip to content

Commit 6cfbdbc

Browse files
authored
Merge pull request #250 from NeuroBench/feature/processors
Rafactor processors to accept callables
2 parents 704282d + cc9b368 commit 6cfbdbc

File tree

2 files changed

+60
-25
lines changed

2 files changed

+60
-25
lines changed

neurobench/benchmarks/benchmark.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import json
1717
import csv
1818
import os
19-
from typing import Literal, List, Type, Optional, Dict, Any
19+
from typing import Literal, List, Type, Optional, Dict, Any, Callable, Tuple
2020
import pathlib
2121
import snntorch
2222
from torch import Tensor
@@ -36,18 +36,25 @@ def __init__(
3636
self,
3737
model: NeuroBenchModel,
3838
dataloader: Optional[DataLoader],
39-
preprocessors: Optional[List[NeuroBenchPreProcessor]],
40-
postprocessors: Optional[List[NeuroBenchPostProcessor]],
39+
preprocessors: Optional[
40+
List[
41+
NeuroBenchPreProcessor
42+
| Callable[[Tuple[Tensor, Tensor]], Tuple[Tensor, Tensor]]
43+
]
44+
],
45+
postprocessors: Optional[
46+
List[NeuroBenchPostProcessor | Callable[[Tensor], Tensor]]
47+
],
4148
metric_list: List[List[Type[StaticMetric | WorkloadMetric]]],
4249
):
4350
"""
4451
Args:
4552
model: A NeuroBenchModel.
4653
dataloader: A PyTorch DataLoader.
47-
preprocessors: A list of NeuroBenchPreProcessors.
48-
postprocessors: A list of NeuroBenchPostProcessors.
49-
metric_list: A list of lists of strings of metrics to run.
50-
First item is static metrics, second item is data metrics.
54+
preprocessors: A list of NeuroBenchPreProcessors or callable functions (e.g. lambda) with matching interfaces.
55+
postprocessors: A list of NeuroBenchPostProcessors or callable functions (e.g. lambda) with matching interfaces.
56+
metric_list: A list of lists of StaticMetric and WorkloadMetric classes of metrics to run.
57+
First item is StaticMetrics, second item is WorkloadMetrics.
5158
"""
5259

5360
self.model = model
@@ -63,8 +70,13 @@ def run(
6370
quiet: bool = False,
6471
verbose: bool = False,
6572
dataloader: Optional[DataLoader] = None,
66-
preprocessors: Optional[NeuroBenchPreProcessor] = None,
67-
postprocessors: Optional[NeuroBenchPostProcessor] = None,
73+
preprocessors: Optional[
74+
NeuroBenchPreProcessor
75+
| Callable[[Tuple[Tensor, Tensor]], Tuple[Tensor, Tensor]]
76+
] = None,
77+
postprocessors: Optional[
78+
NeuroBenchPostProcessor | Callable[[Tensor], Tensor]
79+
] = None,
6880
device: Optional[str] = None,
6981
) -> Dict[str, Any]:
7082
"""
@@ -114,10 +126,10 @@ def run(
114126
batch_size = data[0].size(0)
115127

116128
# Preprocessing data
117-
data = self.processor_manager.preprocess(data)
129+
input, target = self.processor_manager.preprocess(data)
118130

119131
# Run model on test data
120-
preds = self.model(data[0])
132+
preds = self.model(input)
121133

122134
# Postprocessing data
123135
preds = self.processor_manager.postprocess(preds)

neurobench/processors/manager/proc_manager.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
NeuroBenchPreProcessor,
44
NeuroBenchPostProcessor,
55
)
6-
from typing import List
6+
from typing import List, Callable, Tuple
77
from torch import Tensor
88

99

@@ -26,8 +26,11 @@ class ProcessorManager(ABC):
2626

2727
def __init__(
2828
self,
29-
preprocessors: List[NeuroBenchPreProcessor],
30-
postprocessors: List[NeuroBenchPostProcessor],
29+
preprocessors: List[
30+
NeuroBenchPreProcessor
31+
| Callable[[Tuple[Tensor, Tensor]], Tuple[Tensor, Tensor]]
32+
],
33+
postprocessors: List[NeuroBenchPostProcessor | Callable[[Tensor], Tensor]],
3134
):
3235
"""
3336
Initialize the ProcessorManager with the given preprocessors and postprocessors.
@@ -45,19 +48,31 @@ def __init__(
4548
4649
"""
4750

48-
if any(not isinstance(p, NeuroBenchPreProcessor) for p in preprocessors):
51+
if any(
52+
not (isinstance(p, NeuroBenchPreProcessor) or callable(p))
53+
for p in preprocessors
54+
):
4955
raise TypeError(
50-
"All preprocessors must be instances of NeuroBenchPreProcessor"
56+
"All preprocessors must be instances of NeuroBenchPreProcessor or callable functions"
5157
)
52-
if any(not isinstance(p, NeuroBenchPostProcessor) for p in postprocessors):
58+
if any(
59+
not (isinstance(p, NeuroBenchPostProcessor) or callable(p))
60+
for p in postprocessors
61+
):
5362
raise TypeError(
54-
"All postprocessors must be instances of NeuroBenchPostProcessor"
63+
"All postprocessors must be instances of NeuroBenchPostProcessor or callable functions"
5564
)
5665

5766
self.preprocessors = preprocessors
5867
self.postprocessors = postprocessors
5968

60-
def replace_preprocessors(self, preprocessors: List[NeuroBenchPreProcessor]):
69+
def replace_preprocessors(
70+
self,
71+
preprocessors: List[
72+
NeuroBenchPreProcessor
73+
| Callable[[Tuple[Tensor, Tensor]], Tuple[Tensor, Tensor]]
74+
],
75+
):
6176
"""
6277
Replace the current list of preprocessors with the provided list.
6378
@@ -70,13 +85,18 @@ def replace_preprocessors(self, preprocessors: List[NeuroBenchPreProcessor]):
7085
NeuroBenchPreProcessor.
7186
7287
"""
73-
if any(not isinstance(p, NeuroBenchPreProcessor) for p in preprocessors):
88+
if any(
89+
not (isinstance(p, NeuroBenchPreProcessor) or callable(p))
90+
for p in preprocessors
91+
):
7492
raise TypeError(
75-
"All preprocessors must be instances of NeuroBenchPreProcessor"
93+
"All preprocessors must be instances of NeuroBenchPreProcessor or callable functions"
7694
)
7795
self.preprocessors = preprocessors
7896

79-
def replace_postprocessors(self, postprocessors: List[NeuroBenchPostProcessor]):
97+
def replace_postprocessors(
98+
self, postprocessors: List[NeuroBenchPostProcessor | Callable[[Tensor], Tensor]]
99+
):
80100
"""
81101
Replace the current list of postprocessors with the provided list.
82102
@@ -89,13 +109,16 @@ def replace_postprocessors(self, postprocessors: List[NeuroBenchPostProcessor]):
89109
NeuroBenchPostProcessor.
90110
91111
"""
92-
if any(not isinstance(p, NeuroBenchPostProcessor) for p in postprocessors):
112+
if any(
113+
not (isinstance(p, NeuroBenchPostProcessor) or callable(p))
114+
for p in postprocessors
115+
):
93116
raise TypeError(
94-
"All postprocessors must be instances of NeuroBenchPostProcessor"
117+
"All postprocessors must be instances of NeuroBenchPostProcessor or callable functions"
95118
)
96119
self.postprocessors = postprocessors
97120

98-
def preprocess(self, data: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]:
121+
def preprocess(self, data: Tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]:
99122
"""
100123
Apply preprocessing steps to the input data.
101124

0 commit comments

Comments
 (0)