@@ -79,13 +79,15 @@ def get_config(self) -> dict:
7979
8080 return serialize (config )
8181
82- def forward (self , data : dict [str , any ], ** kwargs ) -> dict [str , np .ndarray ]:
82+ def forward (self , data : dict [str , any ], * , stage : str = "inference" , * *kwargs ) -> dict [str , np .ndarray ]:
8383 """Apply the transforms in the forward direction.
8484
8585 Parameters
8686 ----------
8787 data : dict
8888 The data to be transformed.
89+ stage : str, one of ["training", "validation", "inference"]
90+ The stage the function is called in.
8991 **kwargs : dict
9092 Additional keyword arguments passed to each transform.
9193
@@ -97,17 +99,19 @@ def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
9799 data = data .copy ()
98100
99101 for transform in self .transforms :
100- data = transform (data , ** kwargs )
102+ data = transform (data , stage = stage , ** kwargs )
101103
102104 return data
103105
104- def inverse (self , data : dict [str , np .ndarray ], ** kwargs ) -> dict [str , any ]:
106+ def inverse (self , data : dict [str , np .ndarray ], * , stage : str = "inference" , * *kwargs ) -> dict [str , any ]:
105107 """Apply the transforms in the inverse direction.
106108
107109 Parameters
108110 ----------
109111 data : dict
110112 The data to be transformed.
113+ stage : str, one of ["training", "validation", "inference"]
114+ The stage the function is called in.
111115 **kwargs : dict
112116 Additional keyword arguments passed to each transform.
113117
@@ -119,11 +123,13 @@ def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, any]:
119123 data = data .copy ()
120124
121125 for transform in reversed (self .transforms ):
122- data = transform (data , inverse = True , ** kwargs )
126+ data = transform (data , stage = stage , inverse = True , ** kwargs )
123127
124128 return data
125129
126- def __call__ (self , data : Mapping [str , any ], * , inverse : bool = False , ** kwargs ) -> dict [str , np .ndarray ]:
130+ def __call__ (
131+ self , data : Mapping [str , any ], * , inverse : bool = False , stage = "inference" , ** kwargs
132+ ) -> dict [str , np .ndarray ]:
127133 """Apply the transforms in the given direction.
128134
129135 Parameters
@@ -132,6 +138,8 @@ def __call__(self, data: Mapping[str, any], *, inverse: bool = False, **kwargs)
132138 The data to be transformed.
133139 inverse : bool, optional
134140 If False, apply the forward transform, else apply the inverse transform (default False).
141+ stage : str, one of ["training", "validation", "inference"]
142+ The stage the function is called in.
135143 **kwargs
136144 Additional keyword arguments passed to each transform.
137145
@@ -141,9 +149,9 @@ def __call__(self, data: Mapping[str, any], *, inverse: bool = False, **kwargs)
141149 The transformed data.
142150 """
143151 if inverse :
144- return self .inverse (data , ** kwargs )
152+ return self .inverse (data , stage = stage , ** kwargs )
145153
146- return self .forward (data , ** kwargs )
154+ return self .forward (data , stage = stage , ** kwargs )
147155
148156 def __repr__ (self ):
149157 result = ""
0 commit comments