@@ -23,7 +23,7 @@ def nugget_optimizer(
23
23
lr : float = 1e-2 ,
24
24
patience : int = 10 ,
25
25
min_impr : float = 0.01 ,
26
- ) -> float :
26
+ ) -> GeoModel :
27
27
"""
28
28
Optimize the nugget effect scalar to achieve a target condition number.
29
29
Returns the final nugget effect value.
@@ -59,9 +59,25 @@ def nugget_optimizer(
59
59
except ContinueEpoch :
60
60
# Keep only top 10% gradients
61
61
if False :
62
- _gradient_masking (nugget )
62
+ _gradient_masking (
63
+ nugget = nugget ,
64
+ focus = 0.01
65
+ )
66
+ elif True :
67
+ if epoch % 5 == 0 :
68
+ # if True:
69
+ grads = nugget .grad .abs ().view (- 1 )
70
+ q1 , q3 = grads .quantile (0.25 ), grads .quantile (0.75 )
71
+ iqr = q3 - q1
72
+ thresh = q3 + 1.5 * iqr
73
+ mask = grads > thresh
74
+
75
+ # print the indices of mask
76
+ print (f"Outliers: { torch .nonzero (mask )} " )
77
+
78
+ _gradient_foo (nugget_effect_scalar = nugget , mask = mask )
63
79
else :
64
- clip_grad_norm_ (parameters = [nugget ], max_norm = 1.0 )
80
+ clip_grad_norm_ (parameters = [nugget ], max_norm = 0.0001 )
65
81
66
82
# Step & clamp safely
67
83
opt .step ()
@@ -77,13 +93,20 @@ def nugget_optimizer(
77
93
prev_cond = cur_cond
78
94
79
95
model .interpolation_options .kernel_options .optimizing_condition_number = False
80
- return nugget . item ()
96
+ return model
81
97
82
98
83
- def _gradient_masking (nugget ):
99
+ def _gradient_foo (nugget_effect_scalar : torch .Tensor , mask ):
100
+
101
+ # amplify outliers if you want bigger jumps
102
+ nugget_effect_scalar .grad [mask ] *= 5.0
103
+ # zero all other gradients
104
+ nugget_effect_scalar .grad [~ mask ] = 0
105
+
106
+ def _gradient_masking (nugget , focus = 0.01 ):
84
107
"""Old way of avoiding exploding gradients."""
85
108
grads = nugget .grad .abs ()
86
- k = int (grads .numel () * 0.1 )
109
+ k = int (grads .numel () * focus )
87
110
top_vals , top_idx = torch .topk (grads , k , largest = True )
88
111
mask = torch .zeros_like (grads )
89
112
mask [top_idx ] = 1
@@ -105,3 +128,92 @@ def _has_converged(
105
128
rel_impr = abs (current - previous ) / max (previous , 1e-8 )
106
129
return rel_impr < min_improvement
107
130
return False
131
+
132
+
133
+ # region legacy
134
def nugget_optimizer__legacy(target_cond_num, engine_cfg, model, max_epochs):
    """Legacy nugget-effect optimizer, kept for reference.

    Tunes the per-surface-point nugget effect vector with Adam until the
    kernel condition number reaches ``target_cond_num`` (or plateaus, per
    ``_check_convergence_criterion``), then returns the mutated geo model.

    Args:
        target_cond_num: Condition-number target used as convergence criterion.
        engine_cfg: Engine configuration (backend, use_gpu, dtype).
        model: GeoModel to optimize; mutated in place and returned.
        max_epochs: Maximum number of optimization epochs.

    Returns:
        The same GeoModel instance, with optimized nugget effect values.
    """
    geo_model: GeoModel = model
    convergence_criteria = target_cond_num
    engine_config = engine_cfg

    BackendTensor.change_backend_gempy(
        engine_backend=engine_config.backend,
        use_gpu=engine_config.use_gpu,
        dtype=engine_config.dtype
    )
    import torch
    from gempy_engine.core.data.continue_epoch import ContinueEpoch

    # Tape the interpolation input so autograd can reach the nugget vector.
    interpolation_input: InterpolationInput = interpolation_input_from_structural_frame(geo_model)
    geo_model.taped_interpolation_input = interpolation_input
    nugget_effect_scalar: torch.Tensor = geo_model.taped_interpolation_input.surface_points.nugget_effect_scalar
    nugget_effect_scalar.requires_grad = True
    optimizer = torch.optim.Adam(
        params=[nugget_effect_scalar],
        lr=0.01,
    )

    # Flag tells the engine to raise ContinueEpoch once the condition
    # number has been computed, short-circuiting the rest of the model run.
    geo_model.interpolation_options.kernel_options.optimizing_condition_number = True

    previous_condition_number = 0
    for epoch in range(max_epochs):
        optimizer.zero_grad()
        try:
            gempy_engine.compute_model(
                interpolation_input=geo_model.taped_interpolation_input,
                options=geo_model.interpolation_options,
                data_descriptor=geo_model.input_data_descriptor,
                geophysics_input=geo_model.geophysics_input,
            )
        except ContinueEpoch:
            # Keep only the largest 10% of gradients (by magnitude):
            # find the smallest 90% and zero them out via a mask.
            grad_magnitudes = torch.abs(nugget_effect_scalar.grad)
            n_values = int(grad_magnitudes.size()[0] * 0.9)
            _, indices = torch.topk(grad_magnitudes, n_values, largest=False)

            mask = torch.ones_like(nugget_effect_scalar.grad)
            mask[indices] = 0
            nugget_effect_scalar.grad *= mask

            # Update the vector; nugget values must stay strictly positive.
            optimizer.step()
            nugget_effect_scalar.data = nugget_effect_scalar.data.clamp_(min=1e-7)

            # Monitor progress every epoch.
            print(f"Epoch {epoch}")

            if _check_convergence_criterion(
                conditional_number=geo_model.interpolation_options.kernel_options.condition_number,
                condition_number_old=previous_condition_number,
                conditional_number_target=convergence_criteria,
                epoch=epoch
            ):
                break
            previous_condition_number = geo_model.interpolation_options.kernel_options.condition_number
            continue
    geo_model.interpolation_options.kernel_options.optimizing_condition_number = False
    return geo_model
208
+
209
+
210
+ def _check_convergence_criterion (conditional_number : float , condition_number_old : float , conditional_number_target : float = 1e5 , epoch : int = 0 ):
211
+ import torch
212
+ reached_conditional_target = conditional_number < conditional_number_target
213
+ if reached_conditional_target == False and epoch > 10 :
214
+ condition_number_change = torch .abs (conditional_number - condition_number_old ) / condition_number_old
215
+ if condition_number_change < 0.01 :
216
+ reached_conditional_target = True
217
+ return reached_conditional_target
218
+
219
+ # endregion
0 commit comments