Commit 4d8080f

[WIP/CLN] Improve naming
1 parent: d1efe28

File tree: 1 file changed (+14, -13)


gempy/modules/optimize_nuggets/_optimizer.py

Lines changed: 14 additions & 13 deletions
@@ -68,8 +68,8 @@ def nugget_optimizer(
     print(f"Outliers sp: {torch.nonzero(mask_sp)}")
     print(f"Outliers ori: {torch.nonzero(mask_ori)}")

-    _gradient_foo(nugget_effect_scalar=nugget, mask=mask_sp)
-    _gradient_foo(nugget_effect_scalar=nugget_ori, mask=mask_ori)
+    _apply_outlier_gradients(tensor=nugget, mask=mask_sp)
+    _apply_outlier_gradients(tensor=nugget_ori, mask=mask_ori)

     # Step & clamp safely
     opt.step()
@@ -88,20 +88,21 @@ def nugget_optimizer(
     return model


-def _mask_iqr(grads):
+def _mask_iqr(grads, multiplier: float = 1.5) -> torch.BoolTensor:
     q1, q3 = grads.quantile(0.25), grads.quantile(0.75)
-    iqr = q3 - q1
-    thresh = q3 + 1.5 * iqr
-    mask = grads > thresh
-    return mask
+    thresh = q3 + multiplier * (q3 - q1)
+    return grads > thresh
+
+def _apply_outlier_gradients(
+        tensor: torch.Tensor,
+        mask: torch.BoolTensor,
+        amplification: float = 5.0,
+):
+    # wrap in no_grad if you prefer, but .grad modifications are fine
+    tensor.grad.view(-1)[mask] *= amplification
+    tensor.grad.view(-1)[~mask] = 0


-def _gradient_foo(nugget_effect_scalar: torch.Tensor, mask):
-    # amplify outliers if you want bigger jumps
-    nugget_effect_scalar.grad[mask] *= 5.0
-    # zero all other gradients
-    nugget_effect_scalar.grad[~mask] = 0
-

 def _gradient_masking(nugget, focus=0.01):
     """Old way of avoiding exploding gradients."""
