@@ -1,10 +1,10 @@
-import torch
-
 import gempy_engine
+from gempy.optional_dependencies import require_torch
 from gempy_engine.core.data.continue_epoch import ContinueEpoch
 
 
 def run_optimization(lr, max_epochs, min_impr, model, nugget, patience, target_cond_num):
+    torch = require_torch()
     opt = torch.optim.Adam(
         params=[
             nugget,
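This change defers the torch import to call time: the module now imports even when PyTorch is not installed, and require_torch() only fails once a torch-dependent path such as run_optimization is actually executed. As a rough idea of what such an optional-dependency guard typically looks like (the real implementation in gempy.optional_dependencies is not shown in this diff and may differ):

    def require_torch():
        # Hypothetical sketch of an optional-dependency guard; not the actual
        # gempy.optional_dependencies implementation.
        try:
            import torch  # imported only when a torch-dependent path runs
        except ImportError as exc:
            raise ImportError(
                "PyTorch is required for nugget optimization; install it with 'pip install torch'."
            ) from exc
        return torch

Note that for Adam to update it, nugget must be a leaf tensor created with requires_grad=True.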
@@ -32,7 +32,7 @@ def run_optimization(lr, max_epochs, min_impr, model, nugget, patience, target_cond_num):
         _apply_outlier_gradients(tensor=nugget, mask=mask_sp)
 
         # Step & clamp safely
-        opt.step()
+        opt.step()
         with torch.no_grad():
             nugget.clamp_(min=1e-7)
 
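The step-and-clamp pattern above is the key detail of the loop: the optimizer step runs first, then the in-place clamp_ is applied under torch.no_grad() so the projection back onto strictly positive nugget values is not recorded by autograd. A minimal, self-contained sketch of the same pattern, with a placeholder quadratic loss standing in for the project's conditioning objective:

    import torch

    nugget = torch.full((10,), 0.01, requires_grad=True)   # leaf tensor optimized directly
    opt = torch.optim.Adam(params=[nugget], lr=1e-2)

    for _ in range(100):
        opt.zero_grad()
        loss = (nugget ** 2).sum()        # stand-in objective
        loss.backward()
        opt.step()
        with torch.no_grad():             # keep the projection out of the autograd graph
            nugget.clamp_(min=1e-7)       # nugget values must stay strictly positive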
@@ -47,14 +47,14 @@ def run_optimization(lr, max_epochs, min_impr, model, nugget, patience, target_cond_num):
     return nugget
 
 
-def _mask_iqr(grads, multiplier: float = 1.5) -> torch.BoolTensor:
+def _mask_iqr(grads, multiplier: float = 1.5) -> "torch.BoolTensor":
     q1, q3 = grads.quantile(0.25), grads.quantile(0.75)
     thresh = q3 + multiplier * (q3 - q1)
     return grads > thresh
 
 def _apply_outlier_gradients(
-        tensor: torch.Tensor,
-        mask: torch.BoolTensor,
+        tensor: "torch.Tensor",
+        mask: "torch.BoolTensor",
         amplification: float = 1.0,
 ):
     # wrap in no_grad if you prefer, but .grad modifications are fine
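Quoting the annotations as "torch.Tensor" and "torch.BoolTensor" turns them into forward-reference strings, so the hints remain available to type checkers while the module no longer needs torch at import time. The IQR rule in _mask_iqr flags gradient entries above q3 + 1.5 * (q3 - q1), the usual Tukey fence for outliers. A small illustrative run; how the mask is consumed inside _apply_outlier_gradients is only partially visible in this diff, so the damping in the last line is an assumption rather than the project's behavior:

    import torch

    grads = torch.tensor([0.10, 0.20, 0.15, 5.00, 0.12])   # one outlier gradient
    q1, q3 = grads.quantile(0.25), grads.quantile(0.75)     # 0.12 and 0.20
    mask = grads > q3 + 1.5 * (q3 - q1)                     # [False, False, False, True, False]
    grads = torch.where(mask, grads * 0.1, grads)           # e.g. damp the flagged entries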
@@ -65,6 +65,7 @@ def _apply_outlier_gradients(
 
 def _gradient_masking(nugget, focus=0.01):
     """Old way of avoiding exploding gradients."""
+    torch = require_torch()
     grads = nugget.grad.abs()
     k = int(grads.numel() * focus)
     top_vals, top_idx = torch.topk(grads, k, largest=True)
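The body of _gradient_masking is truncated in this diff after the torch.topk call, so the continuation below is only an illustrative sketch of the general top-k idea, keeping the focus fraction of entries with the largest absolute gradients and zeroing the rest; it is not the project's actual code:

    import torch

    nugget = torch.rand(100, requires_grad=True)
    (nugget ** 2).sum().backward()             # toy backward pass so .grad is populated

    focus = 0.01
    grads = nugget.grad.abs()
    k = int(grads.numel() * focus)             # top 1% of entries
    top_vals, top_idx = torch.topk(grads, k, largest=True)

    mask = torch.zeros_like(grads, dtype=torch.bool)
    mask[top_idx] = True
    nugget.grad[~mask] = 0.0                   # zero everything outside the focused set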