@@ -68,8 +68,8 @@ def nugget_optimizer(
        print(f"Outliers sp: {torch.nonzero(mask_sp)}")
        print(f"Outliers ori: {torch.nonzero(mask_ori)}")

-        _gradient_foo(nugget_effect_scalar=nugget, mask=mask_sp)
-        _gradient_foo(nugget_effect_scalar=nugget_ori, mask=mask_ori)
+        _apply_outlier_gradients(tensor=nugget, mask=mask_sp)
+        _apply_outlier_gradients(tensor=nugget_ori, mask=mask_ori)

        # Step & clamp safely
        opt.step()
@@ -88,20 +88,21 @@ def nugget_optimizer(
    return model


-def _mask_iqr(grads):
+def _mask_iqr(grads, multiplier: float = 1.5) -> torch.BoolTensor:
    q1, q3 = grads.quantile(0.25), grads.quantile(0.75)
-    iqr = q3 - q1
-    thresh = q3 + 1.5 * iqr
-    mask = grads > thresh
-    return mask
+    thresh = q3 + multiplier * (q3 - q1)
+    return grads > thresh
+
+def _apply_outlier_gradients(
+    tensor: torch.Tensor,
+    mask: torch.BoolTensor,
+    amplification: float = 5.0,
+):
+    # wrap in no_grad if you prefer, but .grad modifications are fine
+    tensor.grad.view(-1)[mask] *= amplification
+    tensor.grad.view(-1)[~mask] = 0


-def _gradient_foo(nugget_effect_scalar: torch.Tensor, mask):
-    # amplify outliers if you want bigger jumps
-    nugget_effect_scalar.grad[mask] *= 5.0
-    # zero all other gradients
-    nugget_effect_scalar.grad[~mask] = 0
-

def _gradient_masking(nugget, focus=0.01):
    """Old way of avoiding exploding gradients."""