diff --git a/README.md b/README.md index dcfd4ff8..e3a83a68 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ With minimal code, you can perform evaluations on a computing cluster, display l Adaptive is most efficient for computations where each function evaluation takes at least ≈50ms due to the overhead of selecting potentially interesting points. -To see Adaptive in action, try the [example notebook on Binder](https://mybinder.org/v2/gh/python-adaptive/adaptive/main?filepath=example-notebook.ipynb) or explore the [tutorial on Read the Docs](https://adaptive.readthedocs.io/en/latest/tutorial/tutorial.html). +To see Adaptive in action, try the [example notebook on Binder](https://mybinder.org/v2/gh/python-adaptive/adaptive/main?filepath=example-notebook.ipynb) or explore the [tutorial on Read the Docs](https://adaptive.readthedocs.io/en/latest/tutorial/tutorial). diff --git a/adaptive/learner/learner2D.py b/adaptive/learner/learner2D.py index cb179a22..ea1e82af 100644 --- a/adaptive/learner/learner2D.py +++ b/adaptive/learner/learner2D.py @@ -33,7 +33,7 @@ # Learner2D and helper functions. -def deviations(ip: LinearNDInterpolator) -> list[np.ndarray]: +def deviations(ip: LinearNDInterpolator) -> np.ndarray: """Returns the deviation of the linear estimate. Is useful when defining custom loss functions. @@ -55,18 +55,14 @@ def deviations(ip: LinearNDInterpolator) -> list[np.ndarray]: vs = values[simplices] gs = gradients[simplices] - def deviation(p, v, g): - dev = 0 - for j in range(3): - vest = v[:, j, None] + ( - (p[:, :, :] - p[:, j, None, :]) * g[:, j, None, :] - ).sum(axis=-1) - dev += abs(vest - v).max(axis=1) - return dev - - n_levels = vs.shape[2] - devs = [deviation(p, vs[:, :, i], gs[:, :, i]) for i in range(n_levels)] - return devs + p = np.expand_dims(p, axis=2) + + p_diff = p[:, None] - p[:, :, None] + p_diff_scaled = p_diff * gs[:, :, None] + vest = vs[:, :, None] + p_diff_scaled.sum(axis=-1) + devs = np.sum(np.max(np.abs(vest - vs[:, None]), axis=2), axis=1) + + return np.swapaxes(devs, 0, 1) def areas(ip: LinearNDInterpolator) -> np.ndarray: