# Custom pytorch-driven autograd losses for XGBoost
# with various Hessian diagonal approaches.
#
# m.mieskolainen@imperial.ac.uk, 2025

@@ -51,7 +51,7 @@ def __init__(self,
5151 hessian_const : float = 1.0 ,
5252 hessian_gamma : float = 0.9 ,
5353 hessian_eps : float = 1e-8 ,
54- hessian_absmax : float = 10.0 ,
54+ hessian_limit : list = [ 1e-6 , 20 ] ,
5555 hessian_slices : int = 10 ,
5656 device : torch .device = 'cpu'
5757 ):
@@ -65,7 +65,7 @@ def __init__(self,
6565 # Iterative mode
6666 self .hessian_gamma = hessian_gamma
6767 self .hessian_eps = hessian_eps
68- self .hessian_absmax = hessian_absmax
68+ self .hessian_limit = hessian_limit
6969
7070 # Hutchinson mode
7171 self .hessian_slices = int (hessian_slices )
@@ -118,6 +118,19 @@ def torch_conversion(self, preds: np.ndarray, targets: xgboost.DMatrix):
118118
119119 return preds , targets , weights
120120
def regulate_hess(self, hess: Tensor):
    """
    Constrain the Hessian diagonal so second order gradient descent
    stays well-posed: every entry ends up strictly positive and bounded,
    i.e. hessian_limit[0] <= H_ii <= hessian_limit[1].

    The lower bound is deliberately non-zero — clipping entries all the
    way to zero could produce zero denominators inside xgboost's
    Hessian-based update routines.

    Parameters
    ----------
    hess : Tensor
        Raw (possibly negative / unbounded) Hessian diagonal estimate.

    Returns
    -------
    Tensor
        Regulated Hessian diagonal.
    """
    lo, hi = self.hessian_limit

    # Fold negative curvature to its magnitude, then bound away from
    # zero (lo) and cap runaway values (hi).
    return hess.abs().clamp(min=lo, max=hi)
121134 @torch .no_grad
122135 def iterative_hessian_update (self , grad : Tensor , preds : Tensor ):
123136 """
@@ -142,8 +155,7 @@ def iterative_hessian_update(self, grad: Tensor, preds: Tensor):
142155 ds = preds - self .preds_prev
143156
144157 hess_diag_new = dg / (ds + self .hessian_eps )
145- hess_diag_new = torch .clamp (hess_diag_new ,
146- min = - self .hessian_absmax , max = self .hessian_absmax )
158+ hess_diag_new = self .regulate_hess (hess_diag_new ) # regulate
147159
148160 # Exponential Moving Average (EMA), approx filter size ~ 1 / (1 - gamma) steps
149161 self .hess_diag = self .hessian_gamma * self .hess_diag + \
@@ -160,7 +172,7 @@ def hessian_hutchinson(self, grad: Tensor, preds: Tensor):
160172 tic = time .time ()
161173 print (f'Computing Hessian diag with Hutchinson MC (slices = { self .hessian_slices } ) ... ' )
162174
163- grad2 = torch .zeros_like (preds )
175+ hess = torch .zeros_like (preds )
164176
165177 for _ in range (self .hessian_slices ):
166178
@@ -172,21 +184,24 @@ def hessian_hutchinson(self, grad: Tensor, preds: Tensor):
172184 Hv = torch .autograd .grad (grad , preds , grad_outputs = v , retain_graph = True )[0 ]
173185
174186 # Accumulate element-wise product v * Hv to get the diagonal
175- grad2 += v * Hv
176-
177- print (f'Took { time .time ()- tic :.2f} sec' )
187+ hess += v * Hv
178188
179189 # Average over all samples
180- return grad2 / self .hessian_slices
190+ hess = hess / self .hessian_slices
191+ hess = self .regulate_hess (hess ) # regulate
192+
193+ print (f'Took { time .time ()- tic :.2f} sec' )
181194
195+ return hess
196+
182197 def hessian_exact (self , grad : Tensor , preds : Tensor ):
183198 """
184199 Hessian diagonal with exact autograd ~ O(data points) (time)
185200 """
186201 tic = time .time ()
187202 print ('Computing Hessian diagonal with exact autograd ... ' )
188203
189- grad2 = torch .zeros_like (preds )
204+ hess = torch .zeros_like (preds )
190205
191206 for i in tqdm (range (len (preds ))):
192207
@@ -195,11 +210,13 @@ def hessian_exact(self, grad: Tensor, preds: Tensor):
195210 e_i [i ] = 1.0
196211
197212 # Compute the Hessian-vector product H e_i
198- grad2 [i ] = torch .autograd .grad (grad , preds , grad_outputs = e_i , retain_graph = True )[0 ][i ]
213+ hess [i ] = torch .autograd .grad (grad , preds , grad_outputs = e_i , retain_graph = True )[0 ][i ]
199214
215+ hess = self .regulate_hess (hess ) # regulate
216+
200217 print (f'Took { time .time ()- tic :.2f} sec' )
201218
202- return grad2
219+ return hess
203220
204221 def derivatives (self , loss : Tensor , preds : Tensor ) -> Tuple [np .ndarray , np .ndarray ]:
205222 """
0 commit comments