Skip to content

Remove for loops in deviations function body #482

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ With minimal code, you can perform evaluations on a computing cluster, display l

Adaptive is most efficient for computations where each function evaluation takes at least ≈50ms due to the overhead of selecting potentially interesting points.

To see Adaptive in action, try the [example notebook on Binder](https://mybinder.org/v2/gh/python-adaptive/adaptive/main?filepath=example-notebook.ipynb) or explore the [tutorial on Read the Docs](https://adaptive.readthedocs.io/en/latest/tutorial/tutorial.html).
To see Adaptive in action, try the [example notebook on Binder](https://mybinder.org/v2/gh/python-adaptive/adaptive/main?filepath=example-notebook.ipynb) or explore the [tutorial on Read the Docs](https://adaptive.readthedocs.io/en/latest/tutorial/tutorial).

<!-- summary-end -->

Expand Down
22 changes: 9 additions & 13 deletions adaptive/learner/learner2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
# Learner2D and helper functions.


def deviations(ip: LinearNDInterpolator) -> list[np.ndarray]:
def deviations(ip: LinearNDInterpolator) -> np.ndarray:
"""Returns the deviation of the linear estimate.

Is useful when defining custom loss functions.
Expand All @@ -55,18 +55,14 @@ def deviations(ip: LinearNDInterpolator) -> list[np.ndarray]:
vs = values[simplices]
gs = gradients[simplices]

def deviation(p, v, g):
dev = 0
for j in range(3):
vest = v[:, j, None] + (
(p[:, :, :] - p[:, j, None, :]) * g[:, j, None, :]
).sum(axis=-1)
dev += abs(vest - v).max(axis=1)
return dev

n_levels = vs.shape[2]
devs = [deviation(p, vs[:, :, i], gs[:, :, i]) for i in range(n_levels)]
return devs
p = np.expand_dims(p, axis=2)

p_diff = p[:, None] - p[:, :, None]
p_diff_scaled = p_diff * gs[:, :, None]
vest = vs[:, :, None] + p_diff_scaled.sum(axis=-1)
devs = np.sum(np.max(np.abs(vest - vs[:, None]), axis=2), axis=1)

return np.swapaxes(devs, 0, 1)


def areas(ip: LinearNDInterpolator) -> np.ndarray:
Expand Down
Loading