1
+ from pydantic import Field , model_validator
2
+ from typing import Tuple , Dict , List , Optional
3
+
1
4
import dataclasses
2
5
import numpy as np
3
6
4
7
from gempy .core .data .core_utils import calculate_line_coordinates_2points
5
8
from gempy .optional_dependencies import require_pandas
6
9
10
+ try :
11
+ import pandas as pd
12
+ except ImportError :
13
+ pandas = None
14
+
15
+
16
+ @dataclasses .dataclass
17
+ class SectionDefinition :
18
+ """
19
+ A single cross‐section’s raw parameters.
20
+ """
21
+ start : Tuple [float , float ]
22
+ stop : Tuple [float , float ]
23
+ resolution : Tuple [int , int ]
24
+
7
25
8
26
@dataclasses .dataclass
9
27
class Sections :
@@ -15,26 +33,74 @@ class Sections:
15
33
section_dict: {'section name': ([p1_x, p1_y], [p2_x, p2_y], [xyres, zres])}
16
34
"""
17
35
18
- def __init__ (self , regular_grid = None , z_ext = None , section_dict = None ):
36
+ """
37
+ Pydantic v2 model of your original Sections class.
38
+ All computed fields are initialized with model_validator.
39
+ """
40
+
41
+ # user‐provided inputs
42
+
43
+ z_ext : Tuple [float , float ]
44
+ section_dict : Dict [str , tuple [list [int ]]]
45
+
46
+ # computed/internal (will be serialized too unless excluded)
47
+ names : List [str ] = Field (default_factory = list )
48
+ points : List [List [Tuple [float , float ]]] = Field (default_factory = list )
49
+ resolution : List [Tuple [int , int ]] = Field (default_factory = list )
50
+ length : np .ndarray = Field (default_factory = lambda : np .array ([0 ]), exclude = False )
51
+ dist : np .ndarray = Field (default_factory = lambda : np .array ([]), exclude = False )
52
+ df : Optional [pd .DataFrame ] = Field (default_factory = None , exclude = False )
53
+ values : np .ndarray = Field (default_factory = lambda : np .empty ((0 , 3 )), exclude = False )
54
+ extent : Optional [np .ndarray ] = None
55
+
56
+ # def __init__(self, regular_grid=None, z_ext=None, section_dict=None):
57
+ # pd = require_pandas()
58
+ # if regular_grid is not None:
59
+ # self.z_ext = regular_grid.extent[4:]
60
+ # else:
61
+ # self.z_ext = z_ext
62
+ #
63
+ # self.section_dict = section_dict
64
+ # self.names = []
65
+ # self.points = []
66
+ # self.resolution = []
67
+ # self.length = [0]
68
+ # self.dist = []
69
+ # self.df = pd.DataFrame()
70
+ # self.df['dist'] = self.dist
71
+ # self.values = np.empty((0, 3))
72
+ # self.extent = None
73
+ #
74
+ # if section_dict is not None:
75
+ # self.set_sections(section_dict)
76
+ def __post_init__ (self ):
77
+ self .initialize_computations ()
78
+
79
+ # @model_validator(mode="after")
80
+ # def init_class(self):
81
+ # self.initialize_computations()
82
+ # return self
83
+
84
+ def initialize_computations (self ):
85
+ # copy names
86
+ self .names = list (self .section_dict .keys ())
87
+
88
+ # build points/resolution/length
89
+ self ._get_section_params ()
90
+ # compute distances
91
+ self ._calculate_all_distances ()
92
+ # re-build DataFrame
19
93
pd = require_pandas ()
20
- if regular_grid is not None :
21
- self .z_ext = regular_grid .extent [4 :]
22
- else :
23
- self .z_ext = z_ext
94
+ df = pd .DataFrame .from_dict (
95
+ data = self .section_dict ,
96
+ orient = "index" ,
97
+ columns = ["start" , "stop" , "resolution" ],
98
+ )
99
+ df ["dist" ] = self .dist
100
+ self .df = df
24
101
25
- self .section_dict = section_dict
26
- self .names = []
27
- self .points = []
28
- self .resolution = []
29
- self .length = [0 ]
30
- self .dist = []
31
- self .df = pd .DataFrame ()
32
- self .df ['dist' ] = self .dist
33
- self .values = np .empty ((0 , 3 ))
34
- self .extent = None
35
-
36
- if section_dict is not None :
37
- self .set_sections (section_dict )
102
+ # compute the XYZ grid
103
+ self ._compute_section_coordinates ()
38
104
39
105
def _repr_html_ (self ):
40
106
return self .df .to_html ()
@@ -50,17 +116,10 @@ def set_sections(self, section_dict, regular_grid=None, z_ext=None):
50
116
self .section_dict = section_dict
51
117
if regular_grid is not None :
52
118
self .z_ext = regular_grid .extent [4 :]
119
+
120
+ self .initialize_computations ()
53
121
54
- self .names = np .array (list (self .section_dict .keys ()))
55
-
56
- self .get_section_params ()
57
- self .calculate_all_distances ()
58
- self .df = pd .DataFrame .from_dict (self .section_dict , orient = 'index' , columns = ['start' , 'stop' , 'resolution' ])
59
- self .df ['dist' ] = self .dist
60
-
61
- self .compute_section_coordinates ()
62
-
63
- def get_section_params (self ):
122
+ def _get_section_params (self ):
64
123
self .points = []
65
124
self .resolution = []
66
125
self .length = [0 ]
@@ -76,13 +135,13 @@ def get_section_params(self):
76
135
self .section_dict [section ][2 ][1 ])
77
136
self .length = np .array (self .length ).cumsum ()
78
137
79
- def calculate_all_distances (self ):
138
+ def _calculate_all_distances (self ):
80
139
self .coordinates = np .array (self .points ).ravel ().reshape (- 1 ,
81
140
4 ) # axis are x1,y1,x2,y2
82
141
self .dist = np .sqrt (np .diff (self .coordinates [:, [0 , 2 ]]) ** 2 + np .diff (
83
142
self .coordinates [:, [1 , 3 ]]) ** 2 )
84
143
85
- def compute_section_coordinates (self ):
144
+ def _compute_section_coordinates (self ):
86
145
for i in range (len (self .names )):
87
146
xy = calculate_line_coordinates_2points (self .coordinates [i , :2 ],
88
147
self .coordinates [i , 2 :],
0 commit comments