|
9 | 9 | from scipy import interpolate
|
10 | 10 |
|
11 | 11 | from adaptive.learner.base_learner import BaseLearner
|
| 12 | +from adaptive.learner.triangulation import simplex_volume_in_embedding |
12 | 13 | from adaptive.notebook_integration import ensure_holoviews
|
13 | 14 | from adaptive.utils import cache_latest
|
14 | 15 |
|
@@ -212,6 +213,45 @@ def choose_point_in_triangle(triangle, max_badness):
|
212 | 213 | return point
|
213 | 214 |
|
214 | 215 |
|
| 216 | +def triangle_loss(ip): |
| 217 | + r"""Computes the average of the volumes of the simplex combined with each |
| 218 | + neighbouring point. |
| 219 | +
|
| 220 | + Parameters |
| 221 | + ---------- |
| 222 | + ip : `scipy.interpolate.LinearNDInterpolator` instance |
| 223 | +
|
| 224 | + Returns |
| 225 | + ------- |
| 226 | + triangle_loss : list |
| 227 | + The mean volume per triangle. |
| 228 | +
|
| 229 | + Notes |
| 230 | + ----- |
| 231 | + This loss function is *extremely* slow. It is here because it gives the |
| 232 | + same result as the `adaptive.LearnerND`\s |
| 233 | + `~adaptive.learner.learnerND.triangle_loss`. |
| 234 | + """ |
| 235 | + tri = ip.tri |
| 236 | + |
| 237 | + def get_neighbors(i, ip): |
| 238 | + n = np.array([tri.simplices[n] for n in tri.neighbors[i] if n != -1]) |
| 239 | + # remove the vertices that are in the simplex |
| 240 | + c = np.setdiff1d(n.reshape(-1), tri.simplices[i]) |
| 241 | + return np.concatenate((tri.points[c], ip.values[c]), axis=-1) |
| 242 | + |
| 243 | + simplices = np.concatenate( |
| 244 | + [tri.points[tri.simplices], ip.values[tri.simplices]], axis=-1 |
| 245 | + ) |
| 246 | + neighbors = [get_neighbors(i, ip) for i in range(len(tri.simplices))] |
| 247 | + |
| 248 | + return [ |
| 249 | + sum(simplex_volume_in_embedding(np.vstack([simplex, n])) for n in neighbors[i]) |
| 250 | + / len(neighbors[i]) |
| 251 | + for i, simplex in enumerate(simplices) |
| 252 | + ] |
| 253 | + |
| 254 | + |
215 | 255 | class Learner2D(BaseLearner):
|
216 | 256 | """Learns and predicts a function 'f: ℝ^2 → ℝ^N'.
|
217 | 257 |
|
|
0 commit comments