Skip to content

Commit d1efe28

Browse files
committed
[WIP] Better configuration of parameters
1 parent e9911cd commit d1efe28

File tree

2 files changed

+38
-27
lines changed

2 files changed

+38
-27
lines changed

gempy/modules/optimize_nuggets/_optimizer.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,12 @@
1414
torch = None
1515

1616

17-
1817
def nugget_optimizer(
1918
target_cond_num: float,
2019
engine_cfg: GemPyEngineConfig,
2120
model: GeoModel,
2221
max_epochs: int,
23-
lr: float = 1e-2,
22+
lr: float = .1,
2423
patience: int = 10,
2524
min_impr: float = 0.01,
2625
) -> GeoModel:
@@ -40,9 +39,12 @@ def nugget_optimizer(
4039
interp_in: InterpolationInput = interpolation_input_from_structural_frame(model)
4140
model.taped_interpolation_input = interp_in
4241
nugget: torch.Tensor = interp_in.surface_points.nugget_effect_scalar
42+
nugget_ori: torch.Tensor = interp_in.orientations.nugget_effect_grad
43+
4344
nugget.requires_grad_(True)
45+
nugget_ori.requires_grad_(True)
4446

45-
opt = torch.optim.Adam(params=[nugget], lr=lr)
47+
opt = torch.optim.Adam(params=[nugget, nugget_ori], lr=lr)
4648

4749
model.interpolation_options.kernel_options.optimizing_condition_number = True
4850

@@ -58,26 +60,16 @@ def nugget_optimizer(
5860
)
5961
except ContinueEpoch:
6062
# Keep only top 10% gradients
61-
if False:
62-
_gradient_masking(
63-
nugget=nugget,
64-
focus=0.01
65-
)
66-
elif True:
67-
if epoch % 5 == 0:
68-
# if True:
69-
grads = nugget.grad.abs().view(-1)
70-
q1, q3 = grads.quantile(0.25), grads.quantile(0.75)
71-
iqr = q3 - q1
72-
thresh = q3 + 1.5 * iqr
73-
mask = grads > thresh
74-
75-
# print the indices of mask
76-
print(f"Outliers: {torch.nonzero(mask)}")
77-
78-
_gradient_foo(nugget_effect_scalar=nugget, mask=mask)
79-
else:
80-
clip_grad_norm_(parameters=[nugget], max_norm=0.0001)
63+
if epoch % 5 == 0:
64+
mask_sp = _mask_iqr(nugget.grad.abs().view(-1))
65+
mask_ori = _mask_iqr(nugget_ori.grad.abs().view(-1))
66+
67+
# print the indices of mask
68+
print(f"Outliers sp: {torch.nonzero(mask_sp)}")
69+
print(f"Outliers ori: {torch.nonzero(mask_ori)}")
70+
71+
_gradient_foo(nugget_effect_scalar=nugget, mask=mask_sp)
72+
_gradient_foo(nugget_effect_scalar=nugget_ori, mask=mask_ori)
8173

8274
# Step & clamp safely
8375
opt.step()
@@ -96,14 +88,22 @@ def nugget_optimizer(
9688
return model
9789

9890

91+
def _mask_iqr(grads):
92+
q1, q3 = grads.quantile(0.25), grads.quantile(0.75)
93+
iqr = q3 - q1
94+
thresh = q3 + 1.5 * iqr
95+
mask = grads > thresh
96+
return mask
97+
98+
9999
def _gradient_foo(nugget_effect_scalar: torch.Tensor, mask):
100-
101100
# amplify outliers if you want bigger jumps
102101
nugget_effect_scalar.grad[mask] *= 5.0
103102
# zero all other gradients
104103
nugget_effect_scalar.grad[~mask] = 0
105104

106-
def _gradient_masking(nugget, focus = 0.01):
105+
106+
def _gradient_masking(nugget, focus=0.01):
107107
"""Old way of avoiding exploding gradients."""
108108
grads = nugget.grad.abs()
109109
k = int(grads.numel() * focus)
@@ -135,15 +135,15 @@ def nugget_optimizer__legacy(target_cond_num, engine_cfg, model, max_epochs):
135135
geo_model: GeoModel = model
136136
convergence_criteria = target_cond_num
137137
engine_config = engine_cfg
138-
138+
139139
BackendTensor.change_backend_gempy(
140140
engine_backend=engine_config.backend,
141141
use_gpu=engine_config.use_gpu,
142142
dtype=engine_config.dtype
143143
)
144144
import torch
145145
from gempy_engine.core.data.continue_epoch import ContinueEpoch
146-
146+
147147
interpolation_input: InterpolationInput = interpolation_input_from_structural_frame(geo_model)
148148
geo_model.taped_interpolation_input = interpolation_input
149149
nugget_effect_scalar: torch.Tensor = geo_model.taped_interpolation_input.surface_points.nugget_effect_scalar

test/test_private/test_terranigma/test_nuggets/test_nugget_effect_optimization.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,4 +110,15 @@ def test_optimize_nugget_effect():
110110
point_size=25,
111111
)
112112

113+
if False:
114+
ori_cloud = pv.PolyData(geo_model.orientations_copy.df[['X', 'Y', 'Z']].to_numpy())
115+
ori_cloud['values2'] = geo_model.taped_interpolation_input.orientations.nugget_effect_grad.detach().numpy()
116+
117+
gempy_vista.p.add_mesh(
118+
ori_cloud,
119+
scalars='values2',
120+
cmap='viridis',
121+
point_size=20,
122+
)
123+
113124
gempy_vista.p.show()

0 commit comments

Comments
 (0)