Skip to content

Commit 3dd214f

Browse files
krokosik and basnijholt authored
Remove for loops in deviations function body (#482)
* Remove for loops in deviations function body
* Fix tutorial link in readme
* Add test for Learner2D with vector-valued functions
* Change type of bounds
* No need for import in test

---------

Co-authored-by: Bas Nijholt <[email protected]>
1 parent d0aab31 commit 3dd214f

File tree

3 files changed

+72
-15
lines changed

3 files changed

+72
-15
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ With minimal code, you can perform evaluations on a computing cluster, display l
2222

2323
Adaptive is most efficient for computations where each function evaluation takes at least ≈50ms due to the overhead of selecting potentially interesting points.
2424

25-
To see Adaptive in action, try the [example notebook on Binder](https://mybinder.org/v2/gh/python-adaptive/adaptive/main?filepath=example-notebook.ipynb) or explore the [tutorial on Read the Docs](https://adaptive.readthedocs.io/en/latest/tutorial/tutorial.html).
25+
To see Adaptive in action, try the [example notebook on Binder](https://mybinder.org/v2/gh/python-adaptive/adaptive/main?filepath=example-notebook.ipynb) or explore the [tutorial on Read the Docs](https://adaptive.readthedocs.io/en/latest/tutorial/tutorial).
2626

2727
<!-- summary-end -->
2828

adaptive/learner/learner2D.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
# Learner2D and helper functions.
3434

3535

36-
def deviations(ip: LinearNDInterpolator) -> list[np.ndarray]:
36+
def deviations(ip: LinearNDInterpolator) -> np.ndarray:
3737
"""Returns the deviation of the linear estimate.
3838
3939
Is useful when defining custom loss functions.
@@ -44,7 +44,7 @@ def deviations(ip: LinearNDInterpolator) -> list[np.ndarray]:
4444
4545
Returns
4646
-------
47-
deviations : list
47+
deviations : numpy.ndarray
4848
The deviation per triangle.
4949
"""
5050
values = ip.values / (np.ptp(ip.values, axis=0).max() or 1)
@@ -55,18 +55,14 @@ def deviations(ip: LinearNDInterpolator) -> list[np.ndarray]:
5555
vs = values[simplices]
5656
gs = gradients[simplices]
5757

58-
def deviation(p, v, g):
59-
dev = 0
60-
for j in range(3):
61-
vest = v[:, j, None] + (
62-
(p[:, :, :] - p[:, j, None, :]) * g[:, j, None, :]
63-
).sum(axis=-1)
64-
dev += abs(vest - v).max(axis=1)
65-
return dev
66-
67-
n_levels = vs.shape[2]
68-
devs = [deviation(p, vs[:, :, i], gs[:, :, i]) for i in range(n_levels)]
69-
return devs
58+
p = np.expand_dims(p, axis=2)
59+
60+
p_diff = p[:, None] - p[:, :, None]
61+
p_diff_scaled = p_diff * gs[:, :, None]
62+
vest = vs[:, :, None] + p_diff_scaled.sum(axis=-1)
63+
devs = np.sum(np.max(np.abs(vest - vs[:, None]), axis=2), axis=1)
64+
65+
return np.swapaxes(devs, 0, 1)
7066

7167

7268
def areas(ip: LinearNDInterpolator) -> np.ndarray:

adaptive/tests/test_learners.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,67 @@ def f(x):
279279
simple_run(learner, 10)
280280

281281

282+
def test_learner2d_vector_valued_function():
283+
"""Test that Learner2D handles vector-valued functions correctly.
284+
285+
This test verifies that the deviations function works properly when
286+
the function returns a vector (array/list) of values instead of a scalar.
287+
"""
288+
289+
def vector_function(xy):
290+
"""A 2D function that returns a 3-element vector."""
291+
x, y = xy
292+
return [x + y, x * y, x - y] # Returns 3-element vector
293+
294+
# Create learner with vector-valued function
295+
learner = Learner2D(vector_function, bounds=((-1, 1), (-1, 1)))
296+
297+
# Add some initial points
298+
points = [
299+
(0.0, 0.0),
300+
(1.0, 0.0),
301+
(0.0, 1.0),
302+
(1.0, 1.0),
303+
(0.5, 0.5),
304+
(-0.5, 0.5),
305+
(0.5, -0.5),
306+
(-1.0, -1.0),
307+
]
308+
309+
for point in points:
310+
value = vector_function(point)
311+
learner.tell(point, value)
312+
313+
# Run the learner to trigger deviations calculation
314+
# This should not raise any errors
315+
learner.ask(10)
316+
317+
# Verify that the interpolator is created (ip is a property that may return a function)
318+
assert hasattr(learner, "ip")
319+
320+
# Check the internal interpolator if it exists
321+
if hasattr(learner, "_ip") and learner._ip is not None:
322+
# Check that values have the correct shape
323+
assert learner._ip.values.shape[1] == 3 # 3 output dimensions
324+
325+
# Test that we can evaluate the interpolated function
326+
test_point = (0.25, 0.25)
327+
ip_func = learner.interpolator(scaled=True) # Get the interpolator function
328+
if ip_func is not None:
329+
interpolated_value = ip_func(test_point)
330+
assert len(interpolated_value) == 3
331+
332+
# Run more iterations to ensure deviations are computed correctly
333+
simple_run(learner, 20)
334+
335+
# Final verification
336+
assert len(learner.data) > len(points) # Learner added more points
337+
338+
# Check that all values in data are vectors
339+
for _point, value in learner.data.items():
340+
assert len(value) == 3, f"Expected 3-element vector, got {value}"
341+
342+
282343
@run_with(Learner1D, Learner2D, LearnerND, SequenceLearner, AverageLearner1D)
283344
def test_adding_existing_data_is_idempotent(learner_type, f, learner_kwargs):
284345
"""Adding already existing data is an idempotent operation.

0 commit comments

Comments (0)