@@ -80,7 +80,7 @@ def get_config(self) -> dict:
8080 return serialize (config )
8181
8282 def forward (
83- self , data : dict [str , any ], * , stage : str = "inference" , jacobian : bool = False , ** kwargs
83+ self , data : dict [str , any ], * , stage : str = "inference" , log_det_jac : bool = False , ** kwargs
8484 ) -> dict [str , np .ndarray ] | tuple [dict [str , np .ndarray ], dict [str , np .ndarray ]]:
8585 """Apply the transforms in the forward direction.
8686
@@ -90,18 +90,18 @@ def forward(
9090 The data to be transformed.
9191 stage : str, one of ["training", "validation", "inference"]
9292 The stage the function is called in.
93- jacobian : bool, optional
93+ log_det_jac : bool, optional
9494 Whether to return the log determinant jacobians of the transforms.
9595 **kwargs : dict
9696 Additional keyword arguments passed to each transform.
9797
9898 Returns
9999 -------
100100 dict | tuple[dict, dict]
101- The transformed data or tuple of transformed data and jacobians.
101+ The transformed data or tuple of transformed data and log determinant jacobians.
102102 """
103103 data = data .copy ()
104- if not jacobian :
104+ if not log_det_jac :
105105 for transform in self .transforms :
106106 data = transform (data , stage = stage , ** kwargs )
107107 return data
@@ -114,7 +114,7 @@ def forward(
114114 return data , log_det_jac
115115
116116 def inverse (
117- self , data : dict [str , np .ndarray ], * , stage : str = "inference" , jacobian : bool = False , ** kwargs
117+ self , data : dict [str , np .ndarray ], * , stage : str = "inference" , log_det_jac : bool = False , ** kwargs
118118 ) -> dict [str , np .ndarray ] | tuple [dict [str , np .ndarray ], dict [str , np .ndarray ]]:
119119 """Apply the transforms in the inverse direction.
120120
@@ -124,18 +124,18 @@ def inverse(
124124 The data to be transformed.
125125 stage : str, one of ["training", "validation", "inference"]
126126 The stage the function is called in.
127- jacobian : bool, optional
127+ log_det_jac : bool, optional
128128 Whether to return the log determinant jacobians of the transforms.
129129 **kwargs : dict
130130 Additional keyword arguments passed to each transform.
131131
132132 Returns
133133 -------
134134 dict | tuple[dict, dict]
135- The transformed data or tuple of transformed data and jacobians.
135+ The transformed data or tuple of transformed data and log determinant jacobians.
136136 """
137137 data = data .copy ()
138- if not jacobian :
138+ if not log_det_jac :
139139 for transform in reversed (self .transforms ):
140140 data = transform (data , stage = stage , inverse = True , ** kwargs )
141141 return data
@@ -166,7 +166,7 @@ def __call__(
166166 Returns
167167 -------
168168 dict | tuple[dict, dict]
169- The transformed data or tuple of transformed data and jacobians.
169+ The transformed data or tuple of transformed data and log determinant jacobians.
170170 """
171171 if inverse :
172172 return self .inverse (data , stage = stage , ** kwargs )
0 commit comments