@@ -80,7 +80,9 @@ def get_config(self) -> dict:
8080
8181 return serialize (config )
8282
83- def forward (self , data : dict [str , any ], * , stage : str = "inference" , ** kwargs ) -> dict [str , np .ndarray ]:
83+ def forward (
84+ self , data : dict [str , any ], * , stage : str = "inference" , log_det_jac : bool = False , ** kwargs
85+ ) -> dict [str , np .ndarray ] | tuple [dict [str , np .ndarray ], dict [str , np .ndarray ]]:
8486 """Apply the transforms in the forward direction.
8587
8688 Parameters
@@ -89,22 +91,33 @@ def forward(self, data: dict[str, any], *, stage: str = "inference", **kwargs) -
8991 The data to be transformed.
9092 stage : str, one of ["training", "validation", "inference"]
9193 The stage the function is called in.
94+ log_det_jac: bool, optional
95+ Whether to return the log determinant of the Jacobian of the transforms.
9296 **kwargs : dict
9397 Additional keyword arguments passed to each transform.
9498
9599 Returns
96100 -------
97- dict
98- The transformed data.
101+ dict | tuple[dict, dict]
102+ The transformed data or tuple of transformed data and log determinant of the Jacobian .
99103 """
100104 data = data .copy ()
105+ if not log_det_jac :
106+ for transform in self .transforms :
107+ data = transform (data , stage = stage , ** kwargs )
108+ return data
101109
110+ log_det_jac = {}
102111 for transform in self .transforms :
103- data = transform (data , stage = stage , ** kwargs )
112+ transformed_data = transform (data , stage = stage , ** kwargs )
113+ log_det_jac = transform .log_det_jac (data , log_det_jac , ** kwargs )
114+ data = transformed_data
104115
105- return data
116+ return data , log_det_jac
106117
107- def inverse (self , data : dict [str , np .ndarray ], * , stage : str = "inference" , ** kwargs ) -> dict [str , any ]:
118+ def inverse (
119+ self , data : dict [str , np .ndarray ], * , stage : str = "inference" , log_det_jac : bool = False , ** kwargs
120+ ) -> dict [str , np .ndarray ] | tuple [dict [str , np .ndarray ], dict [str , np .ndarray ]]:
108121 """Apply the transforms in the inverse direction.
109122
110123 Parameters
@@ -113,24 +126,32 @@ def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", **kw
113126 The data to be transformed.
114127 stage : str, one of ["training", "validation", "inference"]
115128 The stage the function is called in.
129+ log_det_jac: bool, optional
130+ Whether to return the log determinant of the Jacobian of the transforms.
116131 **kwargs : dict
117132 Additional keyword arguments passed to each transform.
118133
119134 Returns
120135 -------
121- dict
122- The transformed data.
136+ dict | tuple[dict, dict]
137+ The transformed data or tuple of transformed data and log determinant of the Jacobian .
123138 """
124139 data = data .copy ()
140+ if not log_det_jac :
141+ for transform in reversed (self .transforms ):
142+ data = transform (data , stage = stage , inverse = True , ** kwargs )
143+ return data
125144
145+ log_det_jac = {}
126146 for transform in reversed (self .transforms ):
127147 data = transform (data , stage = stage , inverse = True , ** kwargs )
148+ log_det_jac = transform .log_det_jac (data , log_det_jac , inverse = True , ** kwargs )
128149
129- return data
150+ return data , log_det_jac
130151
131152 def __call__ (
132153 self , data : Mapping [str , any ], * , inverse : bool = False , stage = "inference" , ** kwargs
133- ) -> dict [str , np .ndarray ]:
154+ ) -> dict [str , np .ndarray ] | tuple [ dict [ str , np . ndarray ], dict [ str , np . ndarray ]] :
134155 """Apply the transforms in the given direction.
135156
136157 Parameters
@@ -146,8 +167,8 @@ def __call__(
146167
147168 Returns
148169 -------
149- dict
150- The transformed data.
170+ dict | tuple[dict, dict]
171+ The transformed data or tuple of transformed data and log determinant of the Jacobian .
151172 """
152173 if inverse :
153174 return self .inverse (data , stage = stage , ** kwargs )
0 commit comments