Skip to content

Commit c7ea098

Browse files
committed
[WIP] Getting the mapping for the scalar bar in pyvista correct
1 parent 3e7ec2a commit c7ea098

File tree

6 files changed

+41
-46
lines changed

6 files changed

+41
-46
lines changed

gempy_viewer/API/_plot_3d_API.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def plot_3d(
175175
scalar_data_type=ScalarDataType.LITHOLOGY,
176176
active_scalar_field="lith",
177177
solution=solutions_raw_arrays,
178-
cmap=get_geo_model_cmap(structural_frame.elements_colors_volumes),
178+
cmap=get_geo_model_cmap(structural_frame.elements_colors),
179179
**kwargs_plot_structured_grid
180180
)
181181

gempy_viewer/modules/plot_3d/drawer_input_3d.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from gempy_viewer.optional_dependencies import require_pyvista
99

1010

11-
def plot_data(gempy_vista: GemPyToVista,
11+
def plot_data(gempy_vista: GemPyToVista,
1212
model: GeoModel,
1313
arrows_factor: float,
1414
transformed_data: bool = False,
@@ -19,7 +19,7 @@ def plot_data(gempy_vista: GemPyToVista,
1919
else:
2020
surface_points_copy = model.surface_points_copy
2121
orientations_copy = model.orientations_copy
22-
22+
2323
plot_surface_points(
2424
gempy_vista=gempy_vista,
2525
surface_points=surface_points_copy,
@@ -40,27 +40,21 @@ def plot_surface_points(
4040
surface_points: SurfacePointsTable,
4141
elements_colors: list[str],
4242
render_points_as_spheres=True,
43-
point_size=10,
43+
point_size=10,
4444
**kwargs
4545
):
46-
ids = surface_points.ids
47-
if ids.shape[0] == 0:
48-
return
49-
unique_values, first_indices = np.unique(ids, return_index=True) # Find the unique elements and their first indices
50-
unique_values_order = unique_values[np.argsort(first_indices)] # Sort the unique values by their first appearance in `a`
51-
52-
mapping_dict = {value: i for i, value in enumerate(unique_values_order)} # Use a dictionary to map the original numbers to new values
53-
mapped_array = np.vectorize(mapping_dict.get)(ids) # Map the original array to the new values
54-
5546
# Selecting the surfaces to plot
5647
xyz = surface_points.xyz
5748
if transfromed_data := False: # TODO: Expose this to user
5849
xyz = surface_points.model_transform.apply(xyz)
5950

6051
pv = require_pyvista()
6152
poly = pv.PolyData(xyz)
62-
poly['id'] = mapped_array
6353

54+
ids = surface_points.ids
55+
if ids.shape[0] == 0:
56+
return
57+
poly['id'] = _vectorize_ids(ids)
6458

6559
gempy_vista.surface_points_mesh = poly
6660
gempy_vista.surface_points_actor = gempy_vista.p.add_mesh(
@@ -80,23 +74,16 @@ def plot_orientations(
8074
):
8175
orientations_xyz = orientations.xyz
8276
orientations_grads = orientations.grads
83-
77+
8478
if orientations_xyz.shape[0] == 0:
8579
return
8680

8781
pv = require_pyvista()
8882
poly = pv.PolyData(orientations_xyz)
89-
83+
9084
ids = orientations.ids
91-
if ids.shape[0] == 0:
92-
return
93-
unique_values, first_indices = np.unique(ids, return_index=True) # Find the unique elements and their first indices
94-
unique_values_order = unique_values[np.argsort(first_indices)] # Sort the unique values by their first appearance in `a`
9585

96-
mapping_dict = {value: i for i, value in enumerate(unique_values_order)} # Use a dictionary to map the original numbers to new values
97-
mapped_array = np.vectorize(mapping_dict.get)(ids) # Map the original array to the new values
98-
99-
poly['id'] = mapped_array
86+
poly['id'] = _vectorize_ids(ids)
10087
poly['vectors'] = orientations_grads
10188

10289
# TODO: I am still trying to figure out colors and ids in orientations and surface points
@@ -117,3 +104,13 @@ def plot_orientations(
117104
show_scalar_bar=False
118105
)
119106
gempy_vista.orientations_mesh = arrows
107+
108+
109+
def _vectorize_ids(ids):
110+
unique_values, first_indices = np.unique(ids, return_index=True) # Find the unique elements and their first indices
111+
unique_values_order = unique_values[np.argsort(first_indices)] # Sort the unique values by their first appearance in `a`
112+
# Flip order to please pyvista vertical scalarbar
113+
unique_values_order = unique_values_order[::-1]
114+
mapping_dict = {value: i + 1 for i, value in enumerate(unique_values_order)} # Use a dictionary to map the original numbers to new values
115+
mapped_array = np.vectorize(mapping_dict.get)(ids) # Map the original array to the new values
116+
return mapped_array

gempy_viewer/modules/plot_3d/drawer_structured_grid_3d.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from gempy_engine.core.data.raw_arrays_solution import RawArraysSolution
77
from gempy_viewer.core.scalar_data_type import ScalarDataType
8-
from gempy.core.data.grid_modules import RegularGrid, Topography
8+
from gempy.core.data.grid_modules import Topography
99
from gempy_viewer.modules.plot_3d.vista import GemPyToVista
1010
from gempy_viewer.optional_dependencies import require_pyvista
1111

@@ -30,7 +30,6 @@ def plot_structured_grid(
3030
structured_grid = set_scalar_data(
3131
structured_grid=structured_grid,
3232
data=solution,
33-
resolution=resolution,
3433
scalar_data_type=scalar_data_type
3534
)
3635

@@ -64,18 +63,19 @@ def add_regular_grid_mesh(
6463
**kwargs
6564
):
6665
if isinstance(cmap, mcolors.Colormap):
67-
_clim = (0, cmap.N - 1)
66+
_clim = (0, cmap.N)
6867
else:
6968
_clim = None
7069
gempy_vista.regular_grid_actor = gempy_vista.p.add_mesh(
7170
mesh=structured_grid,
72-
cmap=cmap,
71+
# cmap=cmap,
7372
# ? scalars=main_scalar, if we prepare the structured grid do we need this arg?
7473
show_scalar_bar=False,
75-
scalar_bar_args=gempy_vista.scalar_bar_arguments,
74+
# scalar_bar_args=gempy_vista.scalar_bar_arguments,
7675
interpolate_before_map=True,
7776
opacity=opacity,
78-
clim=_clim,
77+
# flip_scalars=True,
78+
# clim=(4,0),
7979
**kwargs
8080
)
8181

@@ -98,27 +98,25 @@ def _mask_topography(structured_grid: "pv.StructuredGrid", topography: Topograph
9898
def set_scalar_data(
9999
data: RawArraysSolution,
100100
structured_grid: "pv.StructuredGrid",
101-
resolution: np.ndarray,
102101
scalar_data_type: ScalarDataType,
103102
) -> "pv.StructuredGrid":
104-
def _convert_sol_array_to_fortran_order(array: np.ndarray) -> np.ndarray:
105-
# ? (Miguel Jun 24) Is this function deprecated?
106-
# return array.reshape(*resolution, order='C').ravel(order='F')
107-
return array
108-
103+
109104
# Substitute the madness of the previous if with match
110105
match scalar_data_type:
111106
case ScalarDataType.LITHOLOGY | ScalarDataType.ALL:
112-
structured_grid.cell_data['id'] = _convert_sol_array_to_fortran_order(data.lith_block - 1)
107+
max_lith = data.n_surfaces # (for basement)
108+
structured_grid.cell_data['id'] = max_lith - (data.lith_block - 1)
113109
case ScalarDataType.SCALAR_FIELD | ScalarDataType.ALL:
114110
scalar_field_ = 'sf_'
115111
for e in range(data.scalar_field_matrix.shape[0]):
116112
# TODO: Ideally we will have the group name instead the enumeration
117-
structured_grid[scalar_field_ + str(e)] = _convert_sol_array_to_fortran_order(data.scalar_field_matrix[e])
113+
array1 = data.scalar_field_matrix[e]
114+
structured_grid[scalar_field_ + str(e)] = array1
118115
case ScalarDataType.VALUES | ScalarDataType.ALL:
119116
scalar_field_ = 'values_'
120117
for e in range(data.values_matrix.shape[0]):
121-
structured_grid[scalar_field_ + str(e)] = _convert_sol_array_to_fortran_order(data.values_matrix[e])
118+
array2 = data.values_matrix[e]
119+
structured_grid[scalar_field_ + str(e)] = array2
122120
case _:
123121
raise ValueError(f'Unknown scalar data type: {scalar_data_type}')
124122

gempy_viewer/modules/plot_3d/plot_3d_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def set_scalar_bar(gempy_vista: GemPyToVista, elements_names: list[str],
4141

4242
# Create annotations mapping integers to element names
4343
annotations = {}
44-
for e, name in enumerate(elements_names):
44+
for e, name in enumerate(elements_names[::-1]):
4545
# Convert integer to string for the annotation key
4646
annotations[str(e)] = name
4747

@@ -58,8 +58,7 @@ def set_scalar_bar(gempy_vista: GemPyToVista, elements_names: list[str],
5858
if len(custom_colors) < n_colors:
5959
raise ValueError(f"Not enough custom colors provided. Got {len(custom_colors)}, need {n_colors}")
6060

61-
custom_cmap = ListedColormap(custom_colors).reversed()
62-
61+
custom_cmap = ListedColormap(custom_colors)
6362
# Apply the custom colormap to the lookup table
6463
lut.apply_cmap(cmap=custom_cmap, n_values=n_colors)
6564

gempy_viewer/modules/plot_3d/vista.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,14 @@ def scalar_bar_arguments(self):
126126
sargs = dict(
127127
title_font_size=20,
128128
label_font_size=16,
129-
shadow=True,
129+
shadow=False,
130130
italic=True,
131+
bold=True,
131132
font_family="arial",
132133
height=0.25,
133134
vertical=True,
134-
position_x=0.15,
135-
title="id",
135+
position_x=0.1,
136+
title="Elements",
136137
fmt="%.0f",
137138
)
138139
return sargs

tests/test_plotting/test_plot_3d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def test_plot_3d_solutions(self, one_fault_model_topo_solution):
2626
show_topography=False,
2727
show_scalar=False,
2828
show_lith=True,
29-
show_data=True,
29+
show_data=False,
3030
show_boundaries=True,
3131
image=True
3232
)

0 commit comments

Comments
 (0)