Skip to content

Commit e8dfd93

Browse files
authored
Merge pull request #221 from python-adaptive/triangle_loss_2D
2D: add triangle_loss
2 parents 1a64676 + f56257a commit e8dfd93

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

adaptive/learner/learner2D.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from scipy import interpolate
1111

1212
from adaptive.learner.base_learner import BaseLearner
13+
from adaptive.learner.triangulation import simplex_volume_in_embedding
1314
from adaptive.notebook_integration import ensure_holoviews
1415
from adaptive.utils import cache_latest
1516

@@ -247,6 +248,45 @@ def choose_point_in_triangle(triangle, max_badness):
247248
return point
248249

249250

251+
def triangle_loss(ip):
252+
r"""Computes the average of the volumes of the simplex combined with each
253+
neighbouring point.
254+
255+
Parameters
256+
----------
257+
ip : `scipy.interpolate.LinearNDInterpolator` instance
258+
259+
Returns
260+
-------
261+
triangle_loss : list
262+
The mean volume per triangle.
263+
264+
Notes
265+
-----
266+
This loss function is *extremely* slow. It is here because it gives the
267+
same result as the `adaptive.LearnerND`\s
268+
`~adaptive.learner.learnerND.triangle_loss`.
269+
"""
270+
tri = ip.tri
271+
272+
def get_neighbors(i, ip):
273+
n = np.array([tri.simplices[n] for n in tri.neighbors[i] if n != -1])
274+
# remove the vertices that are in the simplex
275+
c = np.setdiff1d(n.reshape(-1), tri.simplices[i])
276+
return np.concatenate((tri.points[c], ip.values[c]), axis=-1)
277+
278+
simplices = np.concatenate(
279+
[tri.points[tri.simplices], ip.values[tri.simplices]], axis=-1
280+
)
281+
neighbors = [get_neighbors(i, ip) for i in range(len(tri.simplices))]
282+
283+
return [
284+
sum(simplex_volume_in_embedding(np.vstack([simplex, n])) for n in neighbors[i])
285+
/ len(neighbors[i])
286+
for i, simplex in enumerate(simplices)
287+
]
288+
289+
250290
class Learner2D(BaseLearner):
251291
"""Learns and predicts a function 'f: ℝ^2 → ℝ^N'.
252292

0 commit comments

Comments
 (0)