Skip to content

Commit e9911cd

Browse files
committed
[WIP] Playing with different optimizers
1 parent 4c5eef2 commit e9911cd

File tree

3 files changed

+155
-26
lines changed

3 files changed

+155
-26
lines changed

gempy/modules/optimize_nuggets/_optimizer.py

Lines changed: 118 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def nugget_optimizer(
2323
lr: float = 1e-2,
2424
patience: int = 10,
2525
min_impr: float = 0.01,
26-
) -> float:
26+
) -> GeoModel:
2727
"""
2828
Optimize the nugget effect scalar to achieve a target condition number.
2929
Returns the final nugget effect value.
@@ -59,9 +59,25 @@ def nugget_optimizer(
5959
except ContinueEpoch:
6060
# Keep only top 10% gradients
6161
if False:
62-
_gradient_masking(nugget)
62+
_gradient_masking(
63+
nugget=nugget,
64+
focus=0.01
65+
)
66+
elif True:
67+
if epoch % 5 == 0:
68+
# if True:
69+
grads = nugget.grad.abs().view(-1)
70+
q1, q3 = grads.quantile(0.25), grads.quantile(0.75)
71+
iqr = q3 - q1
72+
thresh = q3 + 1.5 * iqr
73+
mask = grads > thresh
74+
75+
# print the indices of mask
76+
print(f"Outliers: {torch.nonzero(mask)}")
77+
78+
_gradient_foo(nugget_effect_scalar=nugget, mask=mask)
6379
else:
64-
clip_grad_norm_(parameters=[nugget], max_norm=1.0)
80+
clip_grad_norm_(parameters=[nugget], max_norm=0.0001)
6581

6682
# Step & clamp safely
6783
opt.step()
@@ -77,13 +93,20 @@ def nugget_optimizer(
7793
prev_cond = cur_cond
7894

7995
model.interpolation_options.kernel_options.optimizing_condition_number = False
80-
return nugget.item()
96+
return model
8197

8298

83-
def _gradient_masking(nugget):
99+
def _gradient_foo(nugget_effect_scalar: torch.Tensor, mask):
100+
101+
# amplify outliers if you want bigger jumps
102+
nugget_effect_scalar.grad[mask] *= 5.0
103+
# zero all other gradients
104+
nugget_effect_scalar.grad[~mask] = 0
105+
106+
def _gradient_masking(nugget, focus = 0.01):
84107
"""Old way of avoiding exploding gradients."""
85108
grads = nugget.grad.abs()
86-
k = int(grads.numel() * 0.1)
109+
k = int(grads.numel() * focus)
87110
top_vals, top_idx = torch.topk(grads, k, largest=True)
88111
mask = torch.zeros_like(grads)
89112
mask[top_idx] = 1
@@ -105,3 +128,92 @@ def _has_converged(
105128
rel_impr = abs(current - previous) / max(previous, 1e-8)
106129
return rel_impr < min_improvement
107130
return False
131+
132+
133+
# region legacy
134+
def nugget_optimizer__legacy(target_cond_num, engine_cfg, model, max_epochs):
135+
geo_model: GeoModel = model
136+
convergence_criteria = target_cond_num
137+
engine_config = engine_cfg
138+
139+
BackendTensor.change_backend_gempy(
140+
engine_backend=engine_config.backend,
141+
use_gpu=engine_config.use_gpu,
142+
dtype=engine_config.dtype
143+
)
144+
import torch
145+
from gempy_engine.core.data.continue_epoch import ContinueEpoch
146+
147+
interpolation_input: InterpolationInput = interpolation_input_from_structural_frame(geo_model)
148+
geo_model.taped_interpolation_input = interpolation_input
149+
nugget_effect_scalar: torch.Tensor = geo_model.taped_interpolation_input.surface_points.nugget_effect_scalar
150+
nugget_effect_scalar.requires_grad = True
151+
optimizer = torch.optim.Adam(
152+
params=[nugget_effect_scalar],
153+
lr=0.01,
154+
)
155+
# Optimization loop
156+
geo_model.interpolation_options.kernel_options.optimizing_condition_number = True
157+
158+
previous_condition_number = 0
159+
for epoch in range(max_epochs):
160+
optimizer.zero_grad()
161+
try:
162+
# geo_model.taped_interpolation_input.grid = geo_model.interpolation_input_copy.grid
163+
164+
gempy_engine.compute_model(
165+
interpolation_input=geo_model.taped_interpolation_input,
166+
options=geo_model.interpolation_options,
167+
data_descriptor=geo_model.input_data_descriptor,
168+
geophysics_input=geo_model.geophysics_input,
169+
)
170+
except ContinueEpoch:
171+
# Get absolute values of gradients
172+
grad_magnitudes = torch.abs(nugget_effect_scalar.grad)
173+
174+
# Get indices of the 10 largest gradients
175+
grad_magnitudes.size
176+
177+
# * This ignores 90 percent of the gradients
178+
# To int
179+
n_values = int(grad_magnitudes.size()[0] * 0.9)
180+
_, indices = torch.topk(grad_magnitudes, n_values, largest=False)
181+
182+
# Zero out gradients that are not in the top 10
183+
mask = torch.ones_like(nugget_effect_scalar.grad)
184+
mask[indices] = 0
185+
nugget_effect_scalar.grad *= mask
186+
187+
# Update the vector
188+
optimizer.step()
189+
nugget_effect_scalar.data = nugget_effect_scalar.data.clamp_(min=1e-7) # Replace negative values with 0
190+
191+
# optimizer.zero_grad()
192+
# Monitor progress
193+
if epoch % 1 == 0:
194+
# print(f"Epoch {epoch}: Condition Number = {condition_number.item()}")
195+
print(f"Epoch {epoch}")
196+
197+
if _check_convergence_criterion(
198+
conditional_number=geo_model.interpolation_options.kernel_options.condition_number,
199+
condition_number_old=previous_condition_number,
200+
conditional_number_target=convergence_criteria,
201+
epoch=epoch
202+
):
203+
break
204+
previous_condition_number = geo_model.interpolation_options.kernel_options.condition_number
205+
continue
206+
geo_model.interpolation_options.kernel_options.optimizing_condition_number = False
207+
return geo_model
208+
209+
210+
def _check_convergence_criterion(conditional_number: float, condition_number_old: float, conditional_number_target: float = 1e5, epoch: int = 0):
211+
import torch
212+
reached_conditional_target = conditional_number < conditional_number_target
213+
if reached_conditional_target == False and epoch > 10:
214+
condition_number_change = torch.abs(conditional_number - condition_number_old) / condition_number_old
215+
if condition_number_change < 0.01:
216+
reached_conditional_target = True
217+
return reached_conditional_target
218+
219+
# endregion

test/test_private/test_terranigma/test_nuggets/_aux_func.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,22 @@ def initialize_geo_model(structural_elements: list[gp.data.StructuralElement], e
4343
structural_relation=gp.data.StackRelationType.ERODE
4444
)
4545

46-
structural_groups = [structural_group_intrusion, structural_group_green, structural_group_blue, structural_group_red]
46+
structural_groups = [
47+
# structural_group_intrusion,
48+
# structural_group_green,
49+
# structural_group_blue,
50+
structural_group_red
51+
]
4752
structural_frame = gp.data.StructuralFrame(
48-
structural_groups=structural_groups[2:],
53+
structural_groups=structural_groups,
4954
color_gen=gp.data.ColorsGenerator()
5055
)
5156
# TODO: If elements do not have color maybe loop them on structural frame constructor?
5257

5358
geo_model: gp.data.GeoModel = gp.create_geomodel(
5459
project_name='Tutorial_ch1_1_Basics',
5560
extent=extent,
56-
resolution=[20, 10, 20],
61+
# resolution=[20, 10, 20],
5762
refinement=5, # * Here we define the number of octree levels. If octree levels are defined, the resolution is ignored.
5863
structural_frame=structural_frame
5964
)

test/test_private/test_terranigma/test_nuggets/test_nugget_effect_optimization.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
dotenv.load_dotenv()
1414

15-
1615
# %%
1716
# Config
1817
seed = 123456
@@ -27,44 +26,57 @@
2726
# Load necessary configuration and paths from environment variables
2827
path = os.getenv("PATH_TO_NUGGET_TEST_MODEL")
2928

29+
3030
def test_optimize_nugget_effect():
3131
# Initialize lists to store structural elements for the geological model
3232
structural_elements = []
3333
global_extent = None
3434
color_gen = gp.data.ColorsGenerator()
35-
35+
3636
for filename in os.listdir(path):
3737
base, ext = os.path.splitext(filename)
3838
if ext == '.nc':
3939
structural_element, global_extent = process_file(os.path.join(path, filename), global_extent, color_gen)
4040
structural_elements.append(structural_element)
4141

42-
4342
import xarray as xr
4443
geo_model: gp.data.GeoModel = initialize_geo_model(
4544
structural_elements=structural_elements,
4645
extent=(np.array(global_extent)),
4746
topography=(xr.open_dataset(os.path.join(path, "Topography.nc")))
4847
)
49-
48+
5049
if False:
5150
gpv.plot_3d(geo_model, show_data=True, image=True)
5251

53-
5452
geo_model.interpolation_options.cache_mode = gp.data.InterpolationOptions.CacheMode.NO_CACHE
55-
gp.API.compute_API.optimize_and_compute(
56-
geo_model=geo_model,
57-
engine_config=gp.data.GemPyEngineConfig(
58-
backend=gp.data.AvailableBackends.PYTORCH,
59-
),
60-
max_epochs=100,
61-
convergence_criteria=1e5
62-
)
6353

64-
nugget_effect = geo_model.taped_interpolation_input.surface_points.nugget_effect_scalar.detach().numpy()
54+
if True:
55+
gp.API.compute_API.optimize_and_compute(
56+
geo_model=geo_model,
57+
engine_config=gp.data.GemPyEngineConfig(
58+
backend=gp.data.AvailableBackends.PYTORCH,
59+
),
60+
max_epochs=100,
61+
convergence_criteria=1e5
62+
)
63+
64+
print(f"Final cond number: {geo_model.interpolation_options.kernel_options.condition_number}")
65+
nugget_effect = geo_model.taped_interpolation_input.surface_points.nugget_effect_scalar.detach().numpy()
66+
67+
else:
68+
gp.compute_model(
69+
gempy_model=geo_model,
70+
engine_config=gp.data.GemPyEngineConfig(
71+
backend=gp.data.AvailableBackends.PYTORCH,
72+
),
73+
validate_serialization=False
74+
)
75+
6576

77+
nugget_effect = geo_model.taped_interpolation_input.surface_points.nugget_effect_scalar.detach().numpy()
6678

67-
if plot_evaluation:=True:
79+
if plot_evaluation := True:
6880
import matplotlib.pyplot as plt
6981

7082
plt.hist(nugget_effect, bins=50, color='black', alpha=0.7, log=True)
@@ -73,15 +85,15 @@ def test_optimize_nugget_effect():
7385
plt.title('Histogram of Eigenvalues (nugget-grad)')
7486
plt.show()
7587

76-
77-
if plot_result:=True:
88+
if plot_result := True:
7889
import gempy_viewer as gpv
7990
import pyvista as pv
8091

8192
gempy_vista = gpv.plot_3d(
8293
model=geo_model,
8394
show=False,
8495
show_boundaries=True,
96+
show_topography=False,
8597
kwargs_plot_structured_grid={'opacity': 0.3}
8698
)
8799

0 commit comments

Comments
 (0)