Skip to content

Commit 25f5d09

Browse files
authored
Improve 3D visualization with consistent color handling and scalar bar display (#7)
# Improved 3D Visualization with Enhanced Color Mapping This PR refactors the 3D plotting utilities to provide better color mapping and visualization. Key improvements include: - Refactored scalar bar implementation with proper color mapping and annotations - Added support for custom colormaps in surface point visualization - Fixed lithology block visualization by properly handling element colors - Improved actor priority for scalar bar assignment - Enhanced scalar bar appearance with better formatting and positioning - Changed default scalar field colormap from 'viridis' to 'magma' - Fixed vectorization of IDs for consistent color mapping between surface points and orientations - Optimized structural frame access with local variable assignment - Improved scalar bar configuration with cleaner annotations and better visual styling - Added gempy_engine dependency to requirements.txt These changes provide more consistent and visually appealing 3D visualizations with proper color mapping between different elements in the model.
2 parents 39cae3d + d87015d commit 25f5d09

File tree

9 files changed

+172
-90
lines changed

9 files changed

+172
-90
lines changed

gempy_viewer/API/_plot_3d_API.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -114,20 +114,21 @@ def plot_3d(
114114
**kwargs_plotter
115115
)
116116

117+
structural_frame = model.structural_frame
117118
if data_to_show.show_topography[0] is True and model.grid.topography is not None:
118119
plot_topography_3d(
119120
gempy_vista=gempy_vista,
120121
topography=model.grid.topography,
121122
solution=solutions_raw_arrays,
122123
topography_scalar_type=topography_scalar_type,
123-
elements_colors=model.structural_frame.elements_colors[::-1],
124+
elements_colors=structural_frame.elements_colors[::-1],
124125
contours=kwargs_plot_topography.get('contours', True),
125126
**kwargs_plot_topography
126127
)
127128

128129
if data_to_show.show_boundaries[0] is True:
129130
# Check elements to plot .vertices are not empty
130-
elements_to_plot = model.structural_frame.structural_elements
131+
elements_to_plot = structural_frame.structural_elements
131132
for element in elements_to_plot:
132133
if element.vertices is None:
133134
elements_to_plot.remove(element)
@@ -174,7 +175,7 @@ def plot_3d(
174175
scalar_data_type=ScalarDataType.LITHOLOGY,
175176
active_scalar_field="lith",
176177
solution=solutions_raw_arrays,
177-
cmap=get_geo_model_cmap(model.structural_frame.elements_colors_volumes),
178+
cmap=get_geo_model_cmap(structural_frame.elements_colors),
178179
**kwargs_plot_structured_grid
179180
)
180181

@@ -186,15 +187,16 @@ def plot_3d(
186187
scalar_data_type=ScalarDataType.SCALAR_FIELD,
187188
active_scalar_field=active_scalar_field,
188189
solution=solutions_raw_arrays,
189-
cmap='viridis',
190+
cmap='magma',
190191
**kwargs_plot_structured_grid
191192
)
192193

193-
if True:
194+
if True:
194195
set_scalar_bar(
195196
gempy_vista=gempy_vista,
196-
elements_names = model.structural_frame.elements_names,
197-
surfaces_ids=model.structural_frame.elements_ids - 1
197+
elements_names=structural_frame.elements_names,
198+
surfaces_ids=structural_frame.elements_ids - 1,
199+
custom_colors=structural_frame.elements_colors_volumes
198200
)
199201

200202
if ve is not None:

gempy_viewer/modules/plot_3d/drawer_input_3d.py

Lines changed: 39 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
from gempy.core.data import GeoModel
44
from gempy.core.data.orientations import OrientationsTable
55
from gempy.core.data.surface_points import SurfacePointsTable
6+
from matplotlib.colors import ListedColormap
7+
68
from gempy_viewer.modules.plot_2d.plot_2d_utils import get_geo_model_cmap
79
from gempy_viewer.modules.plot_3d.vista import GemPyToVista
810
from gempy_viewer.optional_dependencies import require_pyvista
911

1012

11-
def plot_data(gempy_vista: GemPyToVista,
13+
def plot_data(gempy_vista: GemPyToVista,
1214
model: GeoModel,
1315
arrows_factor: float,
1416
transformed_data: bool = False,
@@ -19,96 +21,76 @@ def plot_data(gempy_vista: GemPyToVista,
1921
else:
2022
surface_points_copy = model.surface_points_copy
2123
orientations_copy = model.orientations_copy
22-
24+
2325
plot_surface_points(
2426
gempy_vista=gempy_vista,
2527
surface_points=surface_points_copy,
26-
elements_colors=model.structural_frame.elements_colors_contacts,
27-
**kwargs
28+
element_colors=model.structural_frame.elements_colors
2829
)
2930

3031
plot_orientations(
3132
gempy_vista=gempy_vista,
3233
orientations=orientations_copy,
33-
elements_colors=model.structural_frame.elements_colors_orientations,
34+
surface_points=surface_points_copy,
3435
arrows_factor=arrows_factor
3536
)
3637

3738

3839
def plot_surface_points(
3940
gempy_vista: GemPyToVista,
4041
surface_points: SurfacePointsTable,
41-
elements_colors: list[str],
4242
render_points_as_spheres=True,
43-
point_size=10,
44-
**kwargs
43+
element_colors=None,
44+
point_size=10
4545
):
46-
ids = surface_points.ids
47-
if ids.shape[0] == 0:
48-
return
49-
unique_values, first_indices = np.unique(ids, return_index=True) # Find the unique elements and their first indices
50-
unique_values_order = unique_values[np.argsort(first_indices)] # Sort the unique values by their first appearance in `a`
51-
52-
mapping_dict = {value: i for i, value in enumerate(unique_values_order)} # Use a dictionary to map the original numbers to new values
53-
mapped_array = np.vectorize(mapping_dict.get)(ids) # Map the original array to the new values
54-
5546
# Selecting the surfaces to plot
5647
xyz = surface_points.xyz
5748
if transfromed_data := False: # TODO: Expose this to user
5849
xyz = surface_points.model_transform.apply(xyz)
5950

6051
pv = require_pyvista()
6152
poly = pv.PolyData(xyz)
62-
poly['id'] = mapped_array
6353

64-
cmap = get_geo_model_cmap(
65-
elements_colors=np.array(elements_colors),
66-
reverse=False
67-
)
54+
ids = surface_points.ids
55+
if ids.shape[0] == 0:
56+
return
57+
vectorize_ids = _vectorize_ids(ids, ids)
58+
poly['id'] = vectorize_ids
6859

60+
custom_cmap = ListedColormap(element_colors)
61+
6962
gempy_vista.surface_points_mesh = poly
7063
gempy_vista.surface_points_actor = gempy_vista.p.add_mesh(
7164
mesh=poly,
72-
cmap=cmap, # TODO: Add colors
7365
scalars='id',
7466
render_points_as_spheres=render_points_as_spheres,
7567
point_size=point_size,
76-
show_scalar_bar=False
68+
show_scalar_bar=False,
69+
cmap=custom_cmap,
70+
clim=(0, np.unique(vectorize_ids).shape[0])
7771
)
7872

7973

8074
def plot_orientations(
8175
gempy_vista: GemPyToVista,
8276
orientations: OrientationsTable,
83-
elements_colors: list[str],
77+
surface_points: SurfacePointsTable,
8478
arrows_factor: float,
8579
):
8680
orientations_xyz = orientations.xyz
8781
orientations_grads = orientations.grads
88-
82+
8983
if orientations_xyz.shape[0] == 0:
9084
return
9185

9286
pv = require_pyvista()
9387
poly = pv.PolyData(orientations_xyz)
94-
95-
ids = orientations.ids
96-
if ids.shape[0] == 0:
97-
return
98-
unique_values, first_indices = np.unique(ids, return_index=True) # Find the unique elements and their first indices
99-
unique_values_order = unique_values[np.argsort(first_indices)] # Sort the unique values by their first appearance in `a`
100-
101-
mapping_dict = {value: i for i, value in enumerate(unique_values_order)} # Use a dictionary to map the original numbers to new values
102-
mapped_array = np.vectorize(mapping_dict.get)(ids) # Map the original array to the new values
103-
104-
poly['id'] = mapped_array
105-
poly['vectors'] = orientations_grads
10688

107-
# TODO: I am still trying to figure out colors and ids in orientations and surface points
108-
cmap = get_geo_model_cmap(
109-
elements_colors=np.array(elements_colors),
110-
reverse=False
89+
poly['id'] = _vectorize_ids(
90+
mapping_ids=surface_points.ids,
91+
ids_to_map=orientations.ids
11192
)
93+
poly['vectors'] = orientations_grads
11294

11395
arrows = poly.glyph(
11496
orient='vectors',
@@ -118,7 +100,21 @@ def plot_orientations(
118100

119101
gempy_vista.orientations_actor = gempy_vista.p.add_mesh(
120102
mesh=arrows,
121-
cmap=cmap,
103+
scalars='id',
122104
show_scalar_bar=False
123105
)
124106
gempy_vista.orientations_mesh = arrows
107+
108+
109+
def _vectorize_ids(mapping_ids, ids_to_map):
110+
def _mapping_dict(ids):
111+
unique_values, first_indices = np.unique(ids, return_index=True) # Find the unique elements and their first indices
112+
unique_values_order = unique_values[np.argsort(first_indices)] # Sort the unique values by their first appearance in `a`
113+
# Flip order to please pyvista vertical scalarbar
114+
unique_values_order = unique_values_order[::-1]
115+
mapping_dict = {value: i + 1 for i, value in enumerate(unique_values_order)} # Use a dictionary to map the original numbers to new values
116+
return mapping_dict
117+
118+
mapping_dict = _mapping_dict(mapping_ids)
119+
mapped_array = np.vectorize(mapping_dict.get)(ids_to_map) # Map the original array to the new values
120+
return mapped_array

gempy_viewer/modules/plot_3d/drawer_structured_grid_3d.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from gempy_engine.core.data.raw_arrays_solution import RawArraysSolution
77
from gempy_viewer.core.scalar_data_type import ScalarDataType
8-
from gempy.core.data.grid_modules import RegularGrid, Topography
8+
from gempy.core.data.grid_modules import Topography
99
from gempy_viewer.modules.plot_3d.vista import GemPyToVista
1010
from gempy_viewer.optional_dependencies import require_pyvista
1111

@@ -30,7 +30,6 @@ def plot_structured_grid(
3030
structured_grid = set_scalar_data(
3131
structured_grid=structured_grid,
3232
data=solution,
33-
resolution=resolution,
3433
scalar_data_type=scalar_data_type
3534
)
3635

@@ -64,18 +63,15 @@ def add_regular_grid_mesh(
6463
**kwargs
6564
):
6665
if isinstance(cmap, mcolors.Colormap):
67-
_clim = (0, cmap.N - 1)
66+
_clim = (0, cmap.N)
6867
else:
6968
_clim = None
7069
gempy_vista.regular_grid_actor = gempy_vista.p.add_mesh(
7170
mesh=structured_grid,
72-
cmap=cmap,
7371
# ? scalars=main_scalar, if we prepare the structured grid do we need this arg?
74-
show_scalar_bar=True,
75-
scalar_bar_args=gempy_vista.scalar_bar_arguments,
72+
show_scalar_bar=False,
7673
interpolate_before_map=True,
7774
opacity=opacity,
78-
clim=_clim,
7975
**kwargs
8076
)
8177

@@ -98,27 +94,26 @@ def _mask_topography(structured_grid: "pv.StructuredGrid", topography: Topograph
9894
def set_scalar_data(
9995
data: RawArraysSolution,
10096
structured_grid: "pv.StructuredGrid",
101-
resolution: np.ndarray,
10297
scalar_data_type: ScalarDataType,
10398
) -> "pv.StructuredGrid":
104-
def _convert_sol_array_to_fortran_order(array: np.ndarray) -> np.ndarray:
105-
# ? (Miguel Jun 24) Is this function deprecated?
106-
# return array.reshape(*resolution, order='C').ravel(order='F')
107-
return array
108-
99+
109100
# Substitute the madness of the previous if with match
110101
match scalar_data_type:
111102
case ScalarDataType.LITHOLOGY | ScalarDataType.ALL:
112-
structured_grid.cell_data['id'] = _convert_sol_array_to_fortran_order(data.lith_block - 1)
103+
max_lith = data.n_surfaces # (for basement)
104+
block_ = max_lith - (data.lith_block - 1)
105+
structured_grid.cell_data['id'] = block_
113106
case ScalarDataType.SCALAR_FIELD | ScalarDataType.ALL:
114107
scalar_field_ = 'sf_'
115108
for e in range(data.scalar_field_matrix.shape[0]):
116109
# TODO: Ideally we will have the group name instead the enumeration
117-
structured_grid[scalar_field_ + str(e)] = _convert_sol_array_to_fortran_order(data.scalar_field_matrix[e])
110+
array1 = data.scalar_field_matrix[e]
111+
structured_grid[scalar_field_ + str(e)] = array1
118112
case ScalarDataType.VALUES | ScalarDataType.ALL:
119113
scalar_field_ = 'values_'
120114
for e in range(data.values_matrix.shape[0]):
121-
structured_grid[scalar_field_ + str(e)] = _convert_sol_array_to_fortran_order(data.values_matrix[e])
115+
array2 = data.values_matrix[e]
116+
structured_grid[scalar_field_ + str(e)] = array2
122117
case _:
123118
raise ValueError(f'Unknown scalar data type: {scalar_data_type}')
124119

gempy_viewer/modules/plot_3d/plot_3d_utils.py

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44
import pandas as pd
5+
from matplotlib.colors import ListedColormap
56

67
from gempy_viewer.modules.plot_3d.vista import GemPyToVista
78

@@ -23,25 +24,68 @@ def select_surfaces_data(data_df: pd.DataFrame, surfaces: Union[str, list[str]]
2324
return geometric_data
2425

2526

26-
def set_scalar_bar(gempy_vista: GemPyToVista, elements_names: list[str], surfaces_ids: np.ndarray):
27+
def set_scalar_bar(gempy_vista: GemPyToVista, elements_names: list[str],
28+
surfaces_ids: np.ndarray, custom_colors: list = None):
29+
"""
30+
31+
LookupTable (0x7f3d1dc62e00)
32+
Table Range: (0.0, 2.0)
33+
N Values: 256
34+
Above Range Color: None
35+
Below Range Color: None
36+
NAN Color: Color(name='darkgray', hex='#a9a9a9ff', opacity=255)
37+
Log Scale: False
38+
Color Map: "viridis"
39+
"""
2740
import pyvista as pv
28-
41+
2942
# Get mapper actor
30-
if gempy_vista.surface_points_actor is not None:
31-
mapper_actor: pv.Actor = gempy_vista.surface_points_actor
32-
elif gempy_vista.regular_grid_actor is not None:
43+
if gempy_vista.regular_grid_actor is not None:
3344
mapper_actor = gempy_vista.regular_grid_actor
45+
elif gempy_vista.surface_points_actor is not None:
46+
mapper_actor: pv.Actor = gempy_vista.surface_points_actor
3447
else:
3548
return None # * Not a good mapper for the scalar bar
36-
49+
50+
# Get the lookup table from the mapper
51+
lut = mapper_actor.mapper.lookup_table
52+
53+
# Create annotations mapping integers to element names
3754
annotations = {}
38-
for e, name in enumerate(elements_names):
39-
annotations[e] = name
55+
for e, name in enumerate(elements_names[::-1]):
56+
# Convert integer to string for the annotation key
57+
annotations[str(e)] = name
58+
59+
# Apply annotations to the lookup table
60+
lut.annotations = annotations
61+
62+
# Set number of colors to match the number of categories
63+
n_colors = len(elements_names)
64+
lut.n_values = n_colors - 1
4065

41-
mapper_actor.mapper.lookup_table.annotations = annotations
42-
66+
# Apply custom colors if provided
67+
if custom_colors is not None:
68+
# Check if we have enough colors
69+
if len(custom_colors) < n_colors:
70+
raise ValueError(f"Not enough custom colors provided. Got {len(custom_colors)}, need {n_colors}")
71+
72+
custom_cmap = ListedColormap(custom_colors)
73+
# Apply the custom colormap to the lookup table
74+
lut.apply_cmap(cmap=custom_cmap, n_values=n_colors, flip=False)
75+
76+
else:
77+
# Apply a default colormap if no custom colors are provided
78+
lut.apply_cmap(cmap='Set1', n_values=n_colors)
79+
80+
# Configure scalar bar arguments
4381
sargs = gempy_vista.scalar_bar_arguments
82+
min_id, max_id = surfaces_ids.min(), surfaces_ids.max()
83+
mapper_actor.mapper.scalar_range = (min_id - .4, max_id + .5)
84+
4485
sargs["mapper"] = mapper_actor.mapper
45-
86+
sargs["n_labels"] = 0
87+
88+
# Add scalar bar
4689
gempy_vista.p.add_scalar_bar(**sargs)
47-
gempy_vista.p.update_scalar_bar_range((surfaces_ids.min(), surfaces_ids.max())) # * This has to be here to now screw the colors with the volumes
90+
91+
# Update scalar bar range to match surface IDs range

gempy_viewer/modules/plot_3d/vista.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,14 @@ def scalar_bar_arguments(self):
126126
sargs = dict(
127127
title_font_size=20,
128128
label_font_size=16,
129-
shadow=True,
129+
shadow=False,
130130
italic=True,
131+
bold=True,
131132
font_family="arial",
132133
height=0.25,
133134
vertical=True,
134-
position_x=0.15,
135-
title="id",
135+
position_x=0.1,
136+
title="Elements",
136137
fmt="%.0f",
137138
)
138139
return sargs

0 commit comments

Comments
 (0)