Skip to content

Commit c85d5f6

Browse files
committed
[ENH] Refactor nugget optimization for modularity and group handling
Streamlined nugget optimization logic by delegating it to a standalone `run_optimization` function, improving reusability and clarity. Added support for optimizing specific structural groups, enhancing flexibility in nugget effect optimization workflows.
1 parent 4d8080f commit c85d5f6

File tree

6 files changed

+180
-229
lines changed

6 files changed

+180
-229
lines changed

gempy/API/compute_API.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
import dotenv
2+
import numpy as np
23
import os
3-
44
from typing import Optional
55

6-
import numpy as np
7-
86
import gempy_engine
9-
from gempy_engine.core.backend_tensor import BackendTensor
107
from gempy.API.gp2_gp3_compatibility.gp3_to_gp2_input import gempy3_to_gempy2
118
from gempy_engine.config import AvailableBackends
9+
from gempy_engine.core.backend_tensor import BackendTensor
1210
from gempy_engine.core.data import Solutions
13-
from gempy_engine.core.data.interpolation_input import InterpolationInput
1411
from .grid_API import set_custom_grid
12+
from ..core.data import StructuralGroup
1513
from ..core.data.gempy_engine_config import GemPyEngineConfig
1614
from ..core.data.geo_model import GeoModel
1715
from ..modules.data_manipulation import interpolation_input_from_structural_frame
@@ -96,18 +94,29 @@ def compute_model_at(gempy_model: GeoModel, at: np.ndarray,
9694
return sol.raw_arrays.custom
9795

9896

99-
def optimize_and_compute(geo_model: GeoModel, engine_config: GemPyEngineConfig, max_epochs: int = 10,
100-
convergence_criteria: float = 1e5):
97+
def optimize_nuggets(geo_model: GeoModel, engine_config: GemPyEngineConfig, max_epochs: int = 10,
98+
convergence_criteria: float = 1e5, only_groups:list[StructuralGroup] | None = None) -> GeoModel:
99+
"""
100+
Optimize the nuggets of the interpolation input of the provided model.
101+
"""
102+
101103
if engine_config.backend != AvailableBackends.PYTORCH:
102104
raise ValueError(f'Only PyTorch backend is supported for optimization. Received {engine_config.backend}')
103-
105+
104106
geo_model = nugget_optimizer(
105107
target_cond_num=convergence_criteria,
106108
engine_cfg=engine_config,
107109
model=geo_model,
108110
max_epochs=max_epochs,
111+
only_groups=only_groups
109112
)
110113

114+
return geo_model
115+
116+
def optimize_and_compute(geo_model: GeoModel, engine_config: GemPyEngineConfig, max_epochs: int = 10,
117+
convergence_criteria: float = 1e5):
118+
119+
optimize_nuggets(geo_model, engine_config, max_epochs, convergence_criteria)
111120
geo_model.solutions = gempy_engine.compute_model(
112121
interpolation_input=geo_model.taped_interpolation_input,
113122
options=geo_model.interpolation_options,

gempy/core/data/geo_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from gempy_engine.core.data.interpolation_input import InterpolationInput
1616
from gempy_engine.core.data.raw_arrays_solution import RawArraysSolution
1717
from gempy_engine.core.data.transforms import Transform, GlobalAnisotropy
18-
from gempy_engine.modules.geophysics.gravity_gradient import calculate_gravity_gradient
1918
from .encoders.converters import instantiate_if_necessary
2019
from .encoders.json_geomodel_encoder import encode_numpy_array
2120
from .grid import Grid
@@ -319,6 +318,7 @@ def deserialize_properties(cls, data: Union["GeoModel", dict], constructor: Mode
319318
# * Reset geophysics if necessary
320319
centered_grid = instance.grid.centered_grid
321320
if centered_grid is not None and instance.geophysics_input is not None:
321+
from gempy_engine.modules.geophysics.gravity_gradient import calculate_gravity_gradient
322322
instance.geophysics_input.tz = calculate_gravity_gradient(centered_grid)
323323

324324
return instance
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import torch
2+
3+
import gempy_engine
4+
from gempy_engine.core.data.continue_epoch import ContinueEpoch
5+
6+
7+
def run_optimization(lr, max_epochs, min_impr, model, nugget, patience, target_cond_num):
8+
opt = torch.optim.Adam(
9+
params=[
10+
nugget,
11+
],
12+
lr=lr
13+
)
14+
prev_cond = float('inf')
15+
for epoch in range(max_epochs):
16+
opt.zero_grad()
17+
try:
18+
gempy_engine.compute_model(
19+
interpolation_input=model.taped_interpolation_input,
20+
options=model.interpolation_options,
21+
data_descriptor=model.input_data_descriptor,
22+
geophysics_input=model.geophysics_input,
23+
)
24+
except ContinueEpoch:
25+
if True:
26+
# Keep only top 10% gradients
27+
_gradient_masking(nugget, focus=0.01)
28+
else:
29+
if epoch % 1 == 0:
30+
mask_sp = _mask_iqr(nugget.grad.abs().view(-1), multiplier=3)
31+
print(f"Outliers sp: {torch.nonzero(mask_sp)}")
32+
_apply_outlier_gradients(tensor=nugget, mask=mask_sp)
33+
34+
# Step & clamp safely
35+
opt.step()
36+
with torch.no_grad():
37+
nugget.clamp_(min=1e-7)
38+
39+
# Evaluate condition number
40+
cur_cond = model.interpolation_options.kernel_options.condition_number
41+
print(f"[Epoch {epoch}] cond. num. = {cur_cond:.2e}")
42+
43+
if _has_converged(cur_cond, prev_cond, target_cond_num, epoch, min_impr, patience):
44+
break
45+
prev_cond = cur_cond
46+
47+
return nugget
48+
49+
50+
def _mask_iqr(grads, multiplier: float = 1.5) -> torch.BoolTensor:
51+
q1, q3 = grads.quantile(0.25), grads.quantile(0.75)
52+
thresh = q3 + multiplier * (q3 - q1)
53+
return grads > thresh
54+
55+
def _apply_outlier_gradients(
56+
tensor: torch.Tensor,
57+
mask: torch.BoolTensor,
58+
amplification: float = 1.0,
59+
):
60+
# wrap in no_grad if you prefer, but .grad modifications are fine
61+
tensor.grad.view(-1)[mask] *= amplification
62+
tensor.grad.view(-1)[~mask] = 0
63+
64+
65+
66+
def _gradient_masking(nugget, focus=0.01):
67+
"""Old way of avoiding exploding gradients."""
68+
grads = nugget.grad.abs()
69+
k = int(grads.numel() * focus)
70+
top_vals, top_idx = torch.topk(grads, k, largest=True)
71+
mask = torch.zeros_like(grads)
72+
mask[top_idx] = 1
73+
nugget.grad.mul_(mask)
74+
75+
76+
def _has_converged(
77+
current: float,
78+
previous: float,
79+
target: float = 1e5,
80+
epoch: int = 0,
81+
min_improvement: float = 0.01,
82+
patience: int = 10,
83+
) -> bool:
84+
if current < target:
85+
return True
86+
if epoch > patience:
87+
# relative improvement
88+
rel_impr = abs(current - previous) / max(previous, 1e-8)
89+
return rel_impr < min_improvement
90+
return False
91+

0 commit comments

Comments
 (0)