|
10 | 10 | from scipy import interpolate
|
11 | 11 |
|
12 | 12 | from adaptive.learner.base_learner import BaseLearner
|
| 13 | +from adaptive.learner.triangulation import simplex_volume_in_embedding |
13 | 14 | from adaptive.notebook_integration import ensure_holoviews
|
14 | 15 | from adaptive.utils import cache_latest
|
15 | 16 |
|
@@ -247,6 +248,45 @@ def choose_point_in_triangle(triangle, max_badness):
|
247 | 248 | return point
|
248 | 249 |
|
249 | 250 |
|
| 251 | +def triangle_loss(ip): |
| 252 | + r"""Computes the average of the volumes of the simplex combined with each |
| 253 | + neighbouring point. |
| 254 | +
|
| 255 | + Parameters |
| 256 | + ---------- |
| 257 | + ip : `scipy.interpolate.LinearNDInterpolator` instance |
| 258 | +
|
| 259 | + Returns |
| 260 | + ------- |
| 261 | + triangle_loss : list |
| 262 | + The mean volume per triangle. |
| 263 | +
|
| 264 | + Notes |
| 265 | + ----- |
| 266 | + This loss function is *extremely* slow. It is here because it gives the |
| 267 | + same result as the `adaptive.LearnerND`\s |
| 268 | + `~adaptive.learner.learnerND.triangle_loss`. |
| 269 | + """ |
| 270 | + tri = ip.tri |
| 271 | + |
| 272 | + def get_neighbors(i, ip): |
| 273 | + n = np.array([tri.simplices[n] for n in tri.neighbors[i] if n != -1]) |
| 274 | + # remove the vertices that are in the simplex |
| 275 | + c = np.setdiff1d(n.reshape(-1), tri.simplices[i]) |
| 276 | + return np.concatenate((tri.points[c], ip.values[c]), axis=-1) |
| 277 | + |
| 278 | + simplices = np.concatenate( |
| 279 | + [tri.points[tri.simplices], ip.values[tri.simplices]], axis=-1 |
| 280 | + ) |
| 281 | + neighbors = [get_neighbors(i, ip) for i in range(len(tri.simplices))] |
| 282 | + |
| 283 | + return [ |
| 284 | + sum(simplex_volume_in_embedding(np.vstack([simplex, n])) for n in neighbors[i]) |
| 285 | + / len(neighbors[i]) |
| 286 | + for i, simplex in enumerate(simplices) |
| 287 | + ] |
| 288 | + |
| 289 | + |
250 | 290 | class Learner2D(BaseLearner):
|
251 | 291 | """Learns and predicts a function 'f: ℝ^2 → ℝ^N'.
|
252 | 292 |
|
|
0 commit comments