Skip to content

Commit 790924b

Browse files
committed
[WIP] Converting sections grid into a proper data class
1 parent 9fb075e commit 790924b

File tree

3 files changed

+100
-35
lines changed

3 files changed

+100
-35
lines changed

gempy/API/grid_API.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111

1212
def set_section_grid(grid: Grid, section_dict: dict):
1313
if grid.sections is None:
14-
grid.sections = Sections(regular_grid=grid.regular_grid, section_dict=section_dict)
14+
grid.sections = Sections(
15+
z_ext=grid.regular_grid.extent[4:],
16+
section_dict=section_dict
17+
)
1518
else:
1619
grid.sections.set_sections(section_dict,
1720
regular_grid=grid.regular_grid)
@@ -54,7 +57,7 @@ def set_topography_from_random(grid: Grid, fractal_dimension: float = 2.0, d_z:
5457
dz=d_z,
5558
fractal_dimension=fractal_dimension
5659
)
57-
60+
5861
grid.topography = Topography(
5962
regular_grid=grid.regular_grid,
6063
values_2d=random_topography
@@ -70,7 +73,7 @@ def set_topography_from_subsurface_structured_grid(grid: Grid, struct: "subsurfa
7073
return grid.topography
7174

7275

73-
def set_topography_from_arrays(grid: Grid, xyz_vertices: np.ndarray):
76+
def set_topography_from_arrays(grid: Grid, xyz_vertices: np.ndarray):
7477
grid.topography = Topography.from_unstructured_mesh(grid.regular_grid, xyz_vertices)
7578
set_active_grid(grid, [Grid.GridTypes.TOPOGRAPHY])
7679
return grid.topography
@@ -88,7 +91,7 @@ def set_topography_from_file(grid: Grid, filepath: str, crop_to_extent: Union[Se
8891
def set_custom_grid(grid: Grid, xyz_coord: np.ndarray):
8992
custom_grid = CustomGrid(values=xyz_coord)
9093
grid.custom_grid = custom_grid
91-
94+
9295
set_active_grid(grid, [Grid.GridTypes.CUSTOM])
9396
return grid.custom_grid
9497

gempy/core/data/geo_model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from .surface_points import SurfacePointsTable
2525
from ...modules.data_manipulation.engine_factory import interpolation_input_from_structural_frame
2626

27+
import pandas as pd
28+
2729
"""
2830
TODO:
2931
- [ ] StructuralFrame will all input points chunked on Elements. Here I will need a property to put all
@@ -296,7 +298,8 @@ def add_surface_points(self, X: Sequence[float], Y: Sequence[float], Z: Sequence
296298
arbitrary_types_allowed=True,
297299
use_enum_values=False,
298300
json_encoders={
299-
np.ndarray: encode_numpy_array
301+
np.ndarray: encode_numpy_array,
302+
pd.DataFrame: lambda df: df.to_dict(orient="list"),
300303
}
301304
)
302305

gempy/core/data/grid_modules/sections_grid.py

Lines changed: 89 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,27 @@
1+
from pydantic import Field, model_validator
2+
from typing import Tuple, Dict, List, Optional
3+
14
import dataclasses
25
import numpy as np
36

47
from gempy.core.data.core_utils import calculate_line_coordinates_2points
58
from gempy.optional_dependencies import require_pandas
69

10+
try:
11+
import pandas as pd
12+
except ImportError:
13+
pandas = None
14+
15+
16+
@dataclasses.dataclass
17+
class SectionDefinition:
18+
"""
19+
A single cross‐section’s raw parameters.
20+
"""
21+
start: Tuple[float, float]
22+
stop: Tuple[float, float]
23+
resolution: Tuple[int, int]
24+
725

826
@dataclasses.dataclass
927
class Sections:
@@ -15,26 +33,74 @@ class Sections:
1533
section_dict: {'section name': ([p1_x, p1_y], [p2_x, p2_y], [xyres, zres])}
1634
"""
1735

18-
def __init__(self, regular_grid=None, z_ext=None, section_dict=None):
36+
"""
37+
Pydantic v2 model of your original Sections class.
38+
All computed fields are initialized with model_validator.
39+
"""
40+
41+
# user‐provided inputs
42+
43+
z_ext: Tuple[float, float]
44+
section_dict: Dict[str, tuple[list[int]]]
45+
46+
# computed/internal (will be serialized too unless excluded)
47+
names: List[str] = Field(default_factory=list)
48+
points: List[List[Tuple[float, float]]] = Field(default_factory=list)
49+
resolution: List[Tuple[int, int]] = Field(default_factory=list)
50+
length: np.ndarray = Field(default_factory=lambda: np.array([0]), exclude=False)
51+
dist: np.ndarray = Field(default_factory=lambda: np.array([]), exclude=False)
52+
df: Optional[pd.DataFrame] = Field(default_factory=None, exclude=False)
53+
values: np.ndarray = Field(default_factory=lambda: np.empty((0, 3)), exclude=False)
54+
extent: Optional[np.ndarray] = None
55+
56+
# def __init__(self, regular_grid=None, z_ext=None, section_dict=None):
57+
# pd = require_pandas()
58+
# if regular_grid is not None:
59+
# self.z_ext = regular_grid.extent[4:]
60+
# else:
61+
# self.z_ext = z_ext
62+
#
63+
# self.section_dict = section_dict
64+
# self.names = []
65+
# self.points = []
66+
# self.resolution = []
67+
# self.length = [0]
68+
# self.dist = []
69+
# self.df = pd.DataFrame()
70+
# self.df['dist'] = self.dist
71+
# self.values = np.empty((0, 3))
72+
# self.extent = None
73+
#
74+
# if section_dict is not None:
75+
# self.set_sections(section_dict)
76+
def __post_init__(self):
77+
self.initialize_computations()
78+
79+
# @model_validator(mode="after")
80+
# def init_class(self):
81+
# self.initialize_computations()
82+
# return self
83+
84+
def initialize_computations(self):
85+
# copy names
86+
self.names = list(self.section_dict.keys())
87+
88+
# build points/resolution/length
89+
self._get_section_params()
90+
# compute distances
91+
self._calculate_all_distances()
92+
# re-build DataFrame
1993
pd = require_pandas()
20-
if regular_grid is not None:
21-
self.z_ext = regular_grid.extent[4:]
22-
else:
23-
self.z_ext = z_ext
94+
df = pd.DataFrame.from_dict(
95+
data=self.section_dict,
96+
orient="index",
97+
columns=["start", "stop", "resolution"],
98+
)
99+
df["dist"] = self.dist
100+
self.df = df
24101

25-
self.section_dict = section_dict
26-
self.names = []
27-
self.points = []
28-
self.resolution = []
29-
self.length = [0]
30-
self.dist = []
31-
self.df = pd.DataFrame()
32-
self.df['dist'] = self.dist
33-
self.values = np.empty((0, 3))
34-
self.extent = None
35-
36-
if section_dict is not None:
37-
self.set_sections(section_dict)
102+
# compute the XYZ grid
103+
self._compute_section_coordinates()
38104

39105
def _repr_html_(self):
40106
return self.df.to_html()
@@ -50,17 +116,10 @@ def set_sections(self, section_dict, regular_grid=None, z_ext=None):
50116
self.section_dict = section_dict
51117
if regular_grid is not None:
52118
self.z_ext = regular_grid.extent[4:]
119+
120+
self.initialize_computations()
53121

54-
self.names = np.array(list(self.section_dict.keys()))
55-
56-
self.get_section_params()
57-
self.calculate_all_distances()
58-
self.df = pd.DataFrame.from_dict(self.section_dict, orient='index', columns=['start', 'stop', 'resolution'])
59-
self.df['dist'] = self.dist
60-
61-
self.compute_section_coordinates()
62-
63-
def get_section_params(self):
122+
def _get_section_params(self):
64123
self.points = []
65124
self.resolution = []
66125
self.length = [0]
@@ -76,13 +135,13 @@ def get_section_params(self):
76135
self.section_dict[section][2][1])
77136
self.length = np.array(self.length).cumsum()
78137

79-
def calculate_all_distances(self):
138+
def _calculate_all_distances(self):
80139
self.coordinates = np.array(self.points).ravel().reshape(-1,
81140
4) # axis are x1,y1,x2,y2
82141
self.dist = np.sqrt(np.diff(self.coordinates[:, [0, 2]]) ** 2 + np.diff(
83142
self.coordinates[:, [1, 3]]) ** 2)
84143

85-
def compute_section_coordinates(self):
144+
def _compute_section_coordinates(self):
86145
for i in range(len(self.names)):
87146
xy = calculate_line_coordinates_2points(self.coordinates[i, :2],
88147
self.coordinates[i, 2:],

0 commit comments

Comments
 (0)