11import pprint
22import warnings
3+ from dataclasses import dataclass
34from enum import Enum , auto
4- from typing import Optional , Union
5+ from typing import Optional
56
67import numpy as np
7- from dataclasses import dataclass
8+ from typing_extensions import Annotated
9+
10+ from .encoders .converters import numpy_array_short_validator
811
912
1013class TransformOpsOrder (Enum ):
@@ -13,16 +16,16 @@ class TransformOpsOrder(Enum):
1316
1417
1518class GlobalAnisotropy (Enum ):
16- CUBE = auto () # * Transform data to be as close as possible to a cube
17- NONE = auto () # * Do not transform data
18- MANUAL = auto () # * Use the user defined transform
19-
19+ CUBE = auto () # * Transform data to be as close as possible to a cube
20+ NONE = auto () # * Do not transform data
21+ MANUAL = auto () # * Use the user defined transform
22+
2023
2124@dataclass
2225class Transform :
23- position : np .ndarray
24- rotation : np .ndarray
25- scale : np .ndarray
26+ position : Annotated [ np .ndarray , numpy_array_short_validator ]
27+ rotation : Annotated [ np .ndarray , numpy_array_short_validator ]
28+ scale : Annotated [ np .ndarray , numpy_array_short_validator ]
2629
2730 _is_default_transform : bool = False
2831 _cached_pivot : Optional [np .ndarray ] = None
@@ -68,11 +71,10 @@ def from_matrix(cls, matrix: np.ndarray):
6871 ])
6972 return cls (position , rotation_degrees , scale )
7073
71-
7274 @property
7375 def cached_pivot (self ):
7476 return self ._cached_pivot
75-
77+
7678 @cached_pivot .setter
7779 def cached_pivot (self , pivot : np .ndarray ):
7880 self ._cached_pivot = pivot
@@ -96,7 +98,7 @@ def from_input_points(cls, surface_points: 'gempy.data.SurfacePointsTable', orie
9698
9799 # The scaling factor for each dimension is the inverse of its range
98100 scaling_factors = 1 / range_coord
99-
101+
100102 # ! Be careful with toy models
101103 center : np .ndarray = (max_coord + min_coord ) / 2
102104 return cls (
@@ -127,14 +129,14 @@ def apply_anisotropy(self, anisotropy_type: GlobalAnisotropy, anisotropy_limit:
127129 )
128130 else :
129131 raise NotImplementedError
130-
132+
131133 @staticmethod
132134 def _adjust_scale_to_limit_ratio (s , anisotropic_limit = np .array ([10 , 10 , 10 ])):
133135 # Calculate the ratios
134136 ratios = [
135- s [0 ] / s [1 ], s [0 ] / s [2 ],
136- s [1 ] / s [0 ], s [1 ] / s [2 ],
137- s [2 ] / s [0 ], s [2 ] / s [1 ]
137+ s [0 ] / s [1 ], s [0 ] / s [2 ],
138+ s [1 ] / s [0 ], s [1 ] / s [2 ],
139+ s [2 ] / s [0 ], s [2 ] / s [1 ]
138140 ]
139141
140142 # Adjust the scales based on the index of the max ratio
@@ -158,9 +160,9 @@ def _adjust_scale_to_limit_ratio(s, anisotropic_limit=np.array([10, 10, 10])):
158160 @staticmethod
159161 def _max_scale_ratio (s ):
160162 ratios = [
161- s [0 ] / s [1 ], s [0 ] / s [2 ],
162- s [1 ] / s [0 ], s [1 ] / s [2 ],
163- s [2 ] / s [0 ], s [2 ] / s [1 ]
163+ s [0 ] / s [1 ], s [0 ] / s [2 ],
164+ s [1 ] / s [0 ], s [1 ] / s [2 ],
165+ s [2 ] / s [0 ], s [2 ] / s [1 ]
164166 ]
165167 return max (ratios )
166168
@@ -223,7 +225,7 @@ def apply(self, points: np.ndarray, transform_op_order: TransformOpsOrder = Tran
223225
224226 def scale_points (self , points : np .ndarray ):
225227 return points * self .scale
226-
228+
227229 def apply_inverse (self , points : np .ndarray , transform_op_order : TransformOpsOrder = TransformOpsOrder .SRT ):
228230 # * NOTE: to compare with legacy we would have to add 0.5 to the coords
229231 assert points .shape [1 ] == 3
@@ -233,12 +235,11 @@ def apply_inverse(self, points: np.ndarray, transform_op_order: TransformOpsOrde
233235 transformed_points = (inv @ homogeneous_points .T ).T
234236 return transformed_points [:, :3 ]
235237
236-
237238 def apply_with_cached_pivot (self , points : np .ndarray , transform_op_order : TransformOpsOrder = TransformOpsOrder .SRT ):
238239 if self ._cached_pivot is None :
239240 raise ValueError ("A pivot must be set before calling this method" )
240241 return self .apply_with_pivot (points , self ._cached_pivot , transform_op_order )
241-
242+
242243 def apply_inverse_with_cached_pivot (self , points : np .ndarray , transform_op_order : TransformOpsOrder = TransformOpsOrder .SRT ):
243244 if self ._cached_pivot is None :
244245 raise ValueError ("A pivot must be set before calling this method" )
@@ -269,7 +270,7 @@ def apply_with_pivot(self, points: np.ndarray, pivot: np.ndarray,
269270 def apply_inverse_with_pivot (self , points : np .ndarray , pivot : np .ndarray ,
270271 transform_op_order : TransformOpsOrder = TransformOpsOrder .SRT ):
271272 assert points .shape [1 ] == 3
272-
273+
273274 # Translation matrices to and from the pivot
274275 T_to_origin = self ._translation_matrix (- pivot [0 ], - pivot [1 ], - pivot [2 ])
275276 T_back = self ._translation_matrix (* pivot )
@@ -284,10 +285,10 @@ def apply_inverse_with_pivot(self, points: np.ndarray, pivot: np.ndarray,
284285 @staticmethod
285286 def _translation_matrix (tx , ty , tz ):
286287 return np .array ([
287- [1 , 0 , 0 , tx ],
288- [0 , 1 , 0 , ty ],
289- [0 , 0 , 1 , tz ],
290- [0 , 0 , 0 , 1 ]
288+ [1 , 0 , 0 , tx ],
289+ [0 , 1 , 0 , ty ],
290+ [0 , 0 , 1 , tz ],
291+ [0 , 0 , 0 , 1 ]
291292 ])
292293
293294 def transform_gradient (self , gradients : np .ndarray , transform_op_order : TransformOpsOrder = TransformOpsOrder .SRT ,
0 commit comments