14
14
torch = None
15
15
16
16
17
-
18
17
def nugget_optimizer (
19
18
target_cond_num : float ,
20
19
engine_cfg : GemPyEngineConfig ,
21
20
model : GeoModel ,
22
21
max_epochs : int ,
23
- lr : float = 1e-2 ,
22
+ lr : float = .1 ,
24
23
patience : int = 10 ,
25
24
min_impr : float = 0.01 ,
26
25
) -> GeoModel :
@@ -40,9 +39,12 @@ def nugget_optimizer(
40
39
interp_in : InterpolationInput = interpolation_input_from_structural_frame (model )
41
40
model .taped_interpolation_input = interp_in
42
41
nugget : torch .Tensor = interp_in .surface_points .nugget_effect_scalar
42
+ nugget_ori : torch .Tensor = interp_in .orientations .nugget_effect_grad
43
+
43
44
nugget .requires_grad_ (True )
45
+ nugget_ori .requires_grad_ (True )
44
46
45
- opt = torch .optim .Adam (params = [nugget ], lr = lr )
47
+ opt = torch .optim .Adam (params = [nugget , nugget_ori ], lr = lr )
46
48
47
49
model .interpolation_options .kernel_options .optimizing_condition_number = True
48
50
@@ -58,26 +60,16 @@ def nugget_optimizer(
58
60
)
59
61
except ContinueEpoch :
60
62
# Keep only top 10% gradients
61
- if False :
62
- _gradient_masking (
63
- nugget = nugget ,
64
- focus = 0.01
65
- )
66
- elif True :
67
- if epoch % 5 == 0 :
68
- # if True:
69
- grads = nugget .grad .abs ().view (- 1 )
70
- q1 , q3 = grads .quantile (0.25 ), grads .quantile (0.75 )
71
- iqr = q3 - q1
72
- thresh = q3 + 1.5 * iqr
73
- mask = grads > thresh
74
-
75
- # print the indices of mask
76
- print (f"Outliers: { torch .nonzero (mask )} " )
77
-
78
- _gradient_foo (nugget_effect_scalar = nugget , mask = mask )
79
- else :
80
- clip_grad_norm_ (parameters = [nugget ], max_norm = 0.0001 )
63
+ if epoch % 5 == 0 :
64
+ mask_sp = _mask_iqr (nugget .grad .abs ().view (- 1 ))
65
+ mask_ori = _mask_iqr (nugget_ori .grad .abs ().view (- 1 ))
66
+
67
+ # print the indices of mask
68
+ print (f"Outliers sp: { torch .nonzero (mask_sp )} " )
69
+ print (f"Outliers ori: { torch .nonzero (mask_ori )} " )
70
+
71
+ _gradient_foo (nugget_effect_scalar = nugget , mask = mask_sp )
72
+ _gradient_foo (nugget_effect_scalar = nugget_ori , mask = mask_ori )
81
73
82
74
# Step & clamp safely
83
75
opt .step ()
@@ -96,14 +88,22 @@ def nugget_optimizer(
96
88
return model
97
89
98
90
91
+ def _mask_iqr (grads ):
92
+ q1 , q3 = grads .quantile (0.25 ), grads .quantile (0.75 )
93
+ iqr = q3 - q1
94
+ thresh = q3 + 1.5 * iqr
95
+ mask = grads > thresh
96
+ return mask
97
+
98
+
99
99
def _gradient_foo (nugget_effect_scalar : torch .Tensor , mask ):
100
-
101
100
# amplify outliers if you want bigger jumps
102
101
nugget_effect_scalar .grad [mask ] *= 5.0
103
102
# zero all other gradients
104
103
nugget_effect_scalar .grad [~ mask ] = 0
105
104
106
- def _gradient_masking (nugget , focus = 0.01 ):
105
+
106
+ def _gradient_masking (nugget , focus = 0.01 ):
107
107
"""Old way of avoiding exploding gradients."""
108
108
grads = nugget .grad .abs ()
109
109
k = int (grads .numel () * focus )
@@ -135,15 +135,15 @@ def nugget_optimizer__legacy(target_cond_num, engine_cfg, model, max_epochs):
135
135
geo_model : GeoModel = model
136
136
convergence_criteria = target_cond_num
137
137
engine_config = engine_cfg
138
-
138
+
139
139
BackendTensor .change_backend_gempy (
140
140
engine_backend = engine_config .backend ,
141
141
use_gpu = engine_config .use_gpu ,
142
142
dtype = engine_config .dtype
143
143
)
144
144
import torch
145
145
from gempy_engine .core .data .continue_epoch import ContinueEpoch
146
-
146
+
147
147
interpolation_input : InterpolationInput = interpolation_input_from_structural_frame (geo_model )
148
148
geo_model .taped_interpolation_input = interpolation_input
149
149
nugget_effect_scalar : torch .Tensor = geo_model .taped_interpolation_input .surface_points .nugget_effect_scalar
0 commit comments