Skip to content

Commit 9fb075e

Browse files
committed
[CLN] Refactor grids into its own modules
1 parent 57b9984 commit 9fb075e

File tree

5 files changed

+153
-143
lines changed

5 files changed

+153
-143
lines changed
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
1-
from .grid_types import Sections, RegularGrid, CustomGrid
1+
from .regular_grid import RegularGrid
2+
from .custom_grid import CustomGrid
3+
from .sections_grid import Sections
24
from .topography import Topography
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import dataclasses
2+
import numpy as np
3+
from pydantic import Field
4+
5+
6+
@dataclasses.dataclass
7+
class CustomGrid:
8+
"""Object that contains arbitrary XYZ coordinates.
9+
10+
Args:
11+
xyx_coords (numpy.ndarray like): XYZ (in columns) of the desired coordinates
12+
13+
Attributes:
14+
values (np.ndarray): XYZ coordinates
15+
"""
16+
17+
values: np.ndarray = Field(
18+
exclude=True,
19+
default_factory=lambda: np.zeros((0, 3)),
20+
repr=False
21+
)
22+
23+
24+
def __post_init__(self):
25+
custom_grid = np.atleast_2d(self.values)
26+
assert type(custom_grid) is np.ndarray and custom_grid.shape[1] == 3, \
27+
'The shape of new grid must be (n,3) where n is the number of' \
28+
' points of the grid'
29+
30+
31+
@property
32+
def length(self):
33+
return self.values.shape[0]

gempy/core/data/grid_modules/grid_types.py renamed to gempy/core/data/grid_modules/regular_grid.py

Lines changed: 1 addition & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,9 @@
55

66
import numpy as np
77

8-
from ..core_utils import calculate_line_coordinates_2points
98
from ..encoders.converters import numpy_array_short_validator
109
from .... import optional_dependencies
11-
from ....optional_dependencies import require_pandas
12-
from gempy_engine.core.data.transforms import Transform, TransformOpsOrder
10+
from gempy_engine.core.data.transforms import Transform
1311

1412

1513
@dataclasses.dataclass
@@ -255,141 +253,3 @@ def plot_rotation(regular_grid, pivot, point_x_axis, point_y_axis):
255253
plt.show()
256254

257255

258-
@dataclasses.dataclass
259-
class Sections:
260-
"""
261-
Object that creates a grid of cross sections between two points.
262-
263-
Args:
264-
regular_grid: Model.grid.regular_grid
265-
section_dict: {'section name': ([p1_x, p1_y], [p2_x, p2_y], [xyres, zres])}
266-
"""
267-
268-
def __init__(self, regular_grid=None, z_ext=None, section_dict=None):
269-
pd = require_pandas()
270-
if regular_grid is not None:
271-
self.z_ext = regular_grid.extent[4:]
272-
else:
273-
self.z_ext = z_ext
274-
275-
self.section_dict = section_dict
276-
self.names = []
277-
self.points = []
278-
self.resolution = []
279-
self.length = [0]
280-
self.dist = []
281-
self.df = pd.DataFrame()
282-
self.df['dist'] = self.dist
283-
self.values = np.empty((0, 3))
284-
self.extent = None
285-
286-
if section_dict is not None:
287-
self.set_sections(section_dict)
288-
289-
def _repr_html_(self):
290-
return self.df.to_html()
291-
292-
def __repr__(self):
293-
return self.df.to_string()
294-
295-
def show(self):
296-
pass
297-
298-
def set_sections(self, section_dict, regular_grid=None, z_ext=None):
299-
pd = require_pandas()
300-
self.section_dict = section_dict
301-
if regular_grid is not None:
302-
self.z_ext = regular_grid.extent[4:]
303-
304-
self.names = np.array(list(self.section_dict.keys()))
305-
306-
self.get_section_params()
307-
self.calculate_all_distances()
308-
self.df = pd.DataFrame.from_dict(self.section_dict, orient='index', columns=['start', 'stop', 'resolution'])
309-
self.df['dist'] = self.dist
310-
311-
self.compute_section_coordinates()
312-
313-
def get_section_params(self):
314-
self.points = []
315-
self.resolution = []
316-
self.length = [0]
317-
318-
for i, section in enumerate(self.names):
319-
points = [self.section_dict[section][0], self.section_dict[section][1]]
320-
assert points[0] != points[
321-
1], 'The start and end points of the section must not be identical.'
322-
323-
self.points.append(points)
324-
self.resolution.append(self.section_dict[section][2])
325-
self.length = np.append(self.length, self.section_dict[section][2][0] *
326-
self.section_dict[section][2][1])
327-
self.length = np.array(self.length).cumsum()
328-
329-
def calculate_all_distances(self):
330-
self.coordinates = np.array(self.points).ravel().reshape(-1,
331-
4) # axis are x1,y1,x2,y2
332-
self.dist = np.sqrt(np.diff(self.coordinates[:, [0, 2]]) ** 2 + np.diff(
333-
self.coordinates[:, [1, 3]]) ** 2)
334-
335-
def compute_section_coordinates(self):
336-
for i in range(len(self.names)):
337-
xy = calculate_line_coordinates_2points(self.coordinates[i, :2],
338-
self.coordinates[i, 2:],
339-
self.resolution[i][0])
340-
zaxis = np.linspace(self.z_ext[0], self.z_ext[1], self.resolution[i][1],
341-
dtype="float64")
342-
X, Z = np.meshgrid(xy[:, 0], zaxis, indexing='ij')
343-
Y, _ = np.meshgrid(xy[:, 1], zaxis, indexing='ij')
344-
xyz = np.vstack((X.flatten(), Y.flatten(), Z.flatten())).T
345-
if i == 0:
346-
self.values = xyz
347-
else:
348-
self.values = np.vstack((self.values, xyz))
349-
350-
def generate_axis_coord(self):
351-
for i, name in enumerate(self.names):
352-
xy = calculate_line_coordinates_2points(
353-
self.coordinates[i, :2],
354-
self.coordinates[i, 2:],
355-
self.resolution[i][0]
356-
)
357-
yield name, xy
358-
359-
def get_section_args(self, section_name: str):
360-
where = np.where(self.names == section_name)[0][0]
361-
return self.length[where], self.length[where + 1]
362-
363-
def get_section_grid(self, section_name: str):
364-
l0, l1 = self.get_section_args(section_name)
365-
return self.values[l0:l1]
366-
367-
368-
@dataclasses.dataclass
369-
class CustomGrid:
370-
"""Object that contains arbitrary XYZ coordinates.
371-
372-
Args:
373-
xyx_coords (numpy.ndarray like): XYZ (in columns) of the desired coordinates
374-
375-
Attributes:
376-
values (np.ndarray): XYZ coordinates
377-
"""
378-
379-
values: np.ndarray = Field(
380-
exclude=True,
381-
default_factory=lambda: np.zeros((0, 3)),
382-
repr=False
383-
)
384-
385-
386-
def __post_init__(self):
387-
custom_grid = np.atleast_2d(self.values)
388-
assert type(custom_grid) is np.ndarray and custom_grid.shape[1] == 3, \
389-
'The shape of new grid must be (n,3) where n is the number of' \
390-
' points of the grid'
391-
392-
393-
@property
394-
def length(self):
395-
return self.values.shape[0]
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import dataclasses
2+
import numpy as np
3+
4+
from gempy.core.data.core_utils import calculate_line_coordinates_2points
5+
from gempy.optional_dependencies import require_pandas
6+
7+
8+
@dataclasses.dataclass
9+
class Sections:
10+
"""
11+
Object that creates a grid of cross sections between two points.
12+
13+
Args:
14+
regular_grid: Model.grid.regular_grid
15+
section_dict: {'section name': ([p1_x, p1_y], [p2_x, p2_y], [xyres, zres])}
16+
"""
17+
18+
def __init__(self, regular_grid=None, z_ext=None, section_dict=None):
19+
pd = require_pandas()
20+
if regular_grid is not None:
21+
self.z_ext = regular_grid.extent[4:]
22+
else:
23+
self.z_ext = z_ext
24+
25+
self.section_dict = section_dict
26+
self.names = []
27+
self.points = []
28+
self.resolution = []
29+
self.length = [0]
30+
self.dist = []
31+
self.df = pd.DataFrame()
32+
self.df['dist'] = self.dist
33+
self.values = np.empty((0, 3))
34+
self.extent = None
35+
36+
if section_dict is not None:
37+
self.set_sections(section_dict)
38+
39+
def _repr_html_(self):
40+
return self.df.to_html()
41+
42+
def __repr__(self):
43+
return self.df.to_string()
44+
45+
def show(self):
46+
pass
47+
48+
def set_sections(self, section_dict, regular_grid=None, z_ext=None):
49+
pd = require_pandas()
50+
self.section_dict = section_dict
51+
if regular_grid is not None:
52+
self.z_ext = regular_grid.extent[4:]
53+
54+
self.names = np.array(list(self.section_dict.keys()))
55+
56+
self.get_section_params()
57+
self.calculate_all_distances()
58+
self.df = pd.DataFrame.from_dict(self.section_dict, orient='index', columns=['start', 'stop', 'resolution'])
59+
self.df['dist'] = self.dist
60+
61+
self.compute_section_coordinates()
62+
63+
def get_section_params(self):
64+
self.points = []
65+
self.resolution = []
66+
self.length = [0]
67+
68+
for i, section in enumerate(self.names):
69+
points = [self.section_dict[section][0], self.section_dict[section][1]]
70+
assert points[0] != points[
71+
1], 'The start and end points of the section must not be identical.'
72+
73+
self.points.append(points)
74+
self.resolution.append(self.section_dict[section][2])
75+
self.length = np.append(self.length, self.section_dict[section][2][0] *
76+
self.section_dict[section][2][1])
77+
self.length = np.array(self.length).cumsum()
78+
79+
def calculate_all_distances(self):
80+
self.coordinates = np.array(self.points).ravel().reshape(-1,
81+
4) # axis are x1,y1,x2,y2
82+
self.dist = np.sqrt(np.diff(self.coordinates[:, [0, 2]]) ** 2 + np.diff(
83+
self.coordinates[:, [1, 3]]) ** 2)
84+
85+
def compute_section_coordinates(self):
86+
for i in range(len(self.names)):
87+
xy = calculate_line_coordinates_2points(self.coordinates[i, :2],
88+
self.coordinates[i, 2:],
89+
self.resolution[i][0])
90+
zaxis = np.linspace(self.z_ext[0], self.z_ext[1], self.resolution[i][1],
91+
dtype="float64")
92+
X, Z = np.meshgrid(xy[:, 0], zaxis, indexing='ij')
93+
Y, _ = np.meshgrid(xy[:, 1], zaxis, indexing='ij')
94+
xyz = np.vstack((X.flatten(), Y.flatten(), Z.flatten())).T
95+
if i == 0:
96+
self.values = xyz
97+
else:
98+
self.values = np.vstack((self.values, xyz))
99+
100+
def generate_axis_coord(self):
101+
for i, name in enumerate(self.names):
102+
xy = calculate_line_coordinates_2points(
103+
self.coordinates[i, :2],
104+
self.coordinates[i, 2:],
105+
self.resolution[i][0]
106+
)
107+
yield name, xy
108+
109+
def get_section_args(self, section_name: str):
110+
where = np.where(self.names == section_name)[0][0]
111+
return self.length[where], self.length[where + 1]
112+
113+
def get_section_grid(self, section_name: str):
114+
l0, l1 = self.get_section_args(section_name)
115+
return self.values[l0:l1]

gempy/core/data/grid_modules/topography.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import numpy as np
88

9-
from .grid_types import RegularGrid
9+
from .regular_grid import RegularGrid
1010
from ....modules.grids.create_topography import _LoadDEMArtificial
1111

1212
from ....optional_dependencies import require_skimage

0 commit comments

Comments
 (0)