33 NeuroBenchPreProcessor ,
44 NeuroBenchPostProcessor ,
55)
6- from typing import List
6+ from typing import List , Callable , Tuple
77from torch import Tensor
88
99
@@ -26,8 +26,11 @@ class ProcessorManager(ABC):
2626
2727 def __init__ (
2828 self ,
29- preprocessors : List [NeuroBenchPreProcessor ],
30- postprocessors : List [NeuroBenchPostProcessor ],
29+ preprocessors : List [
30+ NeuroBenchPreProcessor
31+ | Callable [[Tuple [Tensor , Tensor ]], Tuple [Tensor , Tensor ]]
32+ ],
33+ postprocessors : List [NeuroBenchPostProcessor | Callable [[Tensor ], Tensor ]],
3134 ):
3235 """
3336 Initialize the ProcessorManager with the given preprocessors and postprocessors.
@@ -45,19 +48,31 @@ def __init__(
4548
4649 """
4750
48- if any (not isinstance (p , NeuroBenchPreProcessor ) for p in preprocessors ):
51+ if any (
52+ not (isinstance (p , NeuroBenchPreProcessor ) or callable (p ))
53+ for p in preprocessors
54+ ):
4955 raise TypeError (
50- "All preprocessors must be instances of NeuroBenchPreProcessor"
56+ "All preprocessors must be instances of NeuroBenchPreProcessor or callable functions "
5157 )
52- if any (not isinstance (p , NeuroBenchPostProcessor ) for p in postprocessors ):
58+ if any (
59+ not (isinstance (p , NeuroBenchPostProcessor ) or callable (p ))
60+ for p in postprocessors
61+ ):
5362 raise TypeError (
54- "All postprocessors must be instances of NeuroBenchPostProcessor"
63+ "All postprocessors must be instances of NeuroBenchPostProcessor or callable functions "
5564 )
5665
5766 self .preprocessors = preprocessors
5867 self .postprocessors = postprocessors
5968
60- def replace_preprocessors (self , preprocessors : List [NeuroBenchPreProcessor ]):
69+ def replace_preprocessors (
70+ self ,
71+ preprocessors : List [
72+ NeuroBenchPreProcessor
73+ | Callable [[Tuple [Tensor , Tensor ]], Tuple [Tensor , Tensor ]]
74+ ],
75+ ):
6176 """
6277 Replace the current list of preprocessors with the provided list.
6378
@@ -70,13 +85,18 @@ def replace_preprocessors(self, preprocessors: List[NeuroBenchPreProcessor]):
7085 NeuroBenchPreProcessor.
7186
7287 """
73- if any (not isinstance (p , NeuroBenchPreProcessor ) for p in preprocessors ):
88+ if any (
89+ not (isinstance (p , NeuroBenchPreProcessor ) or callable (p ))
90+ for p in preprocessors
91+ ):
7492 raise TypeError (
75- "All preprocessors must be instances of NeuroBenchPreProcessor"
93+ "All preprocessors must be instances of NeuroBenchPreProcessor or callable functions "
7694 )
7795 self .preprocessors = preprocessors
7896
79- def replace_postprocessors (self , postprocessors : List [NeuroBenchPostProcessor ]):
97+ def replace_postprocessors (
98+ self , postprocessors : List [NeuroBenchPostProcessor | Callable [[Tensor ], Tensor ]]
99+ ):
80100 """
81101 Replace the current list of postprocessors with the provided list.
82102
@@ -89,13 +109,16 @@ def replace_postprocessors(self, postprocessors: List[NeuroBenchPostProcessor]):
89109 NeuroBenchPostProcessor.
90110
91111 """
92- if any (not isinstance (p , NeuroBenchPostProcessor ) for p in postprocessors ):
112+ if any (
113+ not (isinstance (p , NeuroBenchPostProcessor ) or callable (p ))
114+ for p in postprocessors
115+ ):
93116 raise TypeError (
94- "All postprocessors must be instances of NeuroBenchPostProcessor"
117+ "All postprocessors must be instances of NeuroBenchPostProcessor or callable functions "
95118 )
96119 self .postprocessors = postprocessors
97120
98- def preprocess (self , data : tuple [Tensor , Tensor ]) -> tuple [Tensor , Tensor ]:
121+ def preprocess (self , data : Tuple [Tensor , Tensor ]) -> tuple [Tensor , Tensor ]:
99122 """
100123 Apply preprocessing steps to the input data.
101124
0 commit comments