Skip to content

Commit bfcc46c

Browse files
committed
force positive Hessian
1 parent 5683074 commit bfcc46c

File tree

2 files changed

+32
-15
lines changed

2 files changed

+32
-15
lines changed

icenet/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import os
44
import psutil
55

6-
__version__ = '0.1.5.4'
6+
__version__ = '0.1.5.5'
77
__release__ = 'alpha'
8-
__date__ = '01/08/2025'
8+
__date__ = '05/08/2025'
99
__author__ = 'm.mieskolainen@imperial.ac.uk'
1010
__repository__ = 'github.com/mieskolainen/icenet'
1111
__asciiart__ = \

icenet/deep/autogradxgb.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Custom pytorch-driven autograd losses for XGBoost
2-
# with various Hessian diagonal approaches
2+
# with various Hessian diagonal approaches.
33
#
44
# m.mieskolainen@imperial.ac.uk, 2025
55

@@ -51,7 +51,7 @@ def __init__(self,
5151
hessian_const: float=1.0,
5252
hessian_gamma: float=0.9,
5353
hessian_eps: float=1e-8,
54-
hessian_absmax: float=10.0,
54+
hessian_limit: list=[1e-6, 20],
5555
hessian_slices: int=10,
5656
device: torch.device='cpu'
5757
):
@@ -65,7 +65,7 @@ def __init__(self,
6565
# Iterative mode
6666
self.hessian_gamma = hessian_gamma
6767
self.hessian_eps = hessian_eps
68-
self.hessian_absmax = hessian_absmax
68+
self.hessian_limit = hessian_limit
6969

7070
# Hutchinson mode
7171
self.hessian_slices = int(hessian_slices)
@@ -118,6 +118,19 @@ def torch_conversion(self, preds: np.ndarray, targets: xgboost.DMatrix):
118118

119119
return preds, targets, weights
120120

121+
def regulate_hess(self, hess: Tensor):
122+
"""
123+
Regulate to be positive definite (H_ii > hessian_limit[0])
124+
as required by second-order gradient descent.
125+
126+
Do not clip to zero, as that might result in zero denominators
127+
in the Hessian routines inside xgboost.
128+
"""
129+
hess = torch.abs(hess) # ~ negative weights
130+
hess = torch.clamp(hess, min=self.hessian_limit[0], max=self.hessian_limit[1])
131+
132+
return hess
133+
121134
@torch.no_grad
122135
def iterative_hessian_update(self, grad: Tensor, preds: Tensor):
123136
"""
@@ -142,8 +155,7 @@ def iterative_hessian_update(self, grad: Tensor, preds: Tensor):
142155
ds = preds - self.preds_prev
143156

144157
hess_diag_new = dg / (ds + self.hessian_eps)
145-
hess_diag_new = torch.clamp(hess_diag_new,
146-
min=-self.hessian_absmax, max=self.hessian_absmax)
158+
hess_diag_new = self.regulate_hess(hess_diag_new) # regulate
147159

148160
# Exponential Moving Average (EMA), approx filter size ~ 1 / (1 - gamma) steps
149161
self.hess_diag = self.hessian_gamma * self.hess_diag + \
@@ -160,7 +172,7 @@ def hessian_hutchinson(self, grad: Tensor, preds: Tensor):
160172
tic = time.time()
161173
print(f'Computing Hessian diag with Hutchinson MC (slices = {self.hessian_slices}) ... ')
162174

163-
grad2 = torch.zeros_like(preds)
175+
hess = torch.zeros_like(preds)
164176

165177
for _ in range(self.hessian_slices):
166178

@@ -172,21 +184,24 @@ def hessian_hutchinson(self, grad: Tensor, preds: Tensor):
172184
Hv = torch.autograd.grad(grad, preds, grad_outputs=v, retain_graph=True)[0]
173185

174186
# Accumulate element-wise product v * Hv to get the diagonal
175-
grad2 += v * Hv
176-
177-
print(f'Took {time.time()-tic:.2f} sec')
187+
hess += v * Hv
178188

179189
# Average over all samples
180-
return grad2 / self.hessian_slices
190+
hess = hess / self.hessian_slices
191+
hess = self.regulate_hess(hess) # regulate
192+
193+
print(f'Took {time.time()-tic:.2f} sec')
181194

195+
return hess
196+
182197
def hessian_exact(self, grad: Tensor, preds: Tensor):
183198
"""
184199
Hessian diagonal with exact autograd ~ O(data points) (time)
185200
"""
186201
tic = time.time()
187202
print('Computing Hessian diagonal with exact autograd ... ')
188203

189-
grad2 = torch.zeros_like(preds)
204+
hess = torch.zeros_like(preds)
190205

191206
for i in tqdm(range(len(preds))):
192207

@@ -195,11 +210,13 @@ def hessian_exact(self, grad: Tensor, preds: Tensor):
195210
e_i[i] = 1.0
196211

197212
# Compute the Hessian-vector product H e_i
198-
grad2[i] = torch.autograd.grad(grad, preds, grad_outputs=e_i, retain_graph=True)[0][i]
213+
hess[i] = torch.autograd.grad(grad, preds, grad_outputs=e_i, retain_graph=True)[0][i]
199214

215+
hess = self.regulate_hess(hess) # regulate
216+
200217
print(f'Took {time.time()-tic:.2f} sec')
201218

202-
return grad2
219+
return hess
203220

204221
def derivatives(self, loss: Tensor, preds: Tensor) -> Tuple[np.ndarray, np.ndarray]:
205222
"""

0 commit comments

Comments
 (0)