@@ -90,13 +90,15 @@ def forward(
9090 The data to be transformed.
9191 stage : str, one of ["training", "validation", "inference"]
9292 The stage the function is called in.
93+ jacobian: bool, optional
94+ Whether to return the log determinant jacobians of the transforms.
9395 **kwargs : dict
9496 Additional keyword arguments passed to each transform.
9597
9698 Returns
9799 -------
98- dict
99- The transformed data.
100+ dict | tuple[dict, dict]
101+ The transformed data or tuple of transformed data and jacobians .
100102 """
101103 data = data .copy ()
102104 if not jacobian :
@@ -122,13 +124,15 @@ def inverse(
122124 The data to be transformed.
123125 stage : str, one of ["training", "validation", "inference"]
124126 The stage the function is called in.
127+ jacobian: bool, optional
128+ Whether to return the log determinant jacobians of the transforms.
125129 **kwargs : dict
126130 Additional keyword arguments passed to each transform.
127131
128132 Returns
129133 -------
130- dict
131- The transformed data.
134+ dict | tuple[dict, dict]
135+ The transformed data or tuple of transformed data and jacobians .
132136 """
133137 data = data .copy ()
134138 if not jacobian :
@@ -145,7 +149,7 @@ def inverse(
145149
146150 def __call__ (
147151 self , data : Mapping [str , any ], * , inverse : bool = False , stage = "inference" , ** kwargs
148- ) -> dict [str , np .ndarray ]:
152+ ) -> dict [str , np .ndarray ] | tuple [ dict [ str , np . ndarray ], dict [ str , np . ndarray ]] :
149153 """Apply the transforms in the given direction.
150154
151155 Parameters
@@ -161,8 +165,8 @@ def __call__(
161165
162166 Returns
163167 -------
164- dict
165- The transformed data.
168+ dict | tuple[dict, dict]
169+ The transformed data or tuple of transformed data and jacobians .
166170 """
167171 if inverse :
168172 return self .inverse (data , stage = stage , ** kwargs )
0 commit comments