
Commit 9381017

feat: Differentiable smooth_min, smooth_max, and least_squares implementations
1 parent ecf2fd6 commit 9381017

3 files changed: +336 -99 lines changed


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 - Autograd support for local field projections using `FieldProjectionKSpaceMonitor`.
 - Function `components.geometry.utils.flatten_groups` now also flattens transformed groups when requested.
+- Differentiable `smooth_min`, `smooth_max`, and `least_squares` functions in `tidy3d.plugins.autograd`.

 ### Changed
 - `CustomMedium` design regions require far less data when performing inverse design by reducing adjoint field monitor size for dims with one pixel.
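Based only on the call signatures exercised in the tests added below, the new API can be used roughly as follows. This is a minimal sketch: `tau` is the smoothing temperature, and the expected outputs are inferred from the test tolerances, not from documentation.

import numpy as np
from tidy3d.plugins.autograd import least_squares, smooth_max, smooth_min

x = np.array([1.0, 3.0, 2.0])
print(smooth_max(x, tau=1e-3))  # close to 3.0; approaches max(x) as tau -> 0
print(smooth_min(x, tau=1e-3))  # close to 1.0; approaches min(x) as tau -> 0

# Fit a linear model a * x + b to data, starting from an initial guess.
model = lambda x, a, b: a * x + b
xs = np.linspace(0, 10, 50)
print(least_squares(model, xs, 2.0 * xs - 3.0, (0.0, 0.0)))  # close to [2.0, -3.0]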

tests/test_plugins/autograd/test_functions.py

Lines changed: 144 additions & 18 deletions
@@ -5,19 +5,22 @@
 import scipy.ndimage
 from autograd.test_util import check_grads
 from scipy.signal import convolve as convolve_sp
-from tidy3d.plugins.autograd.functions import (
+from tidy3d.plugins.autograd import (
     add_at,
     convolve,
     grey_closing,
     grey_dilation,
     grey_erosion,
     grey_opening,
     interpn,
+    least_squares,
     morphological_gradient,
     morphological_gradient_external,
     morphological_gradient_internal,
     pad,
     rescale,
+    smooth_max,
+    smooth_min,
     threshold,
     trapz,
 )
@@ -55,7 +58,7 @@ def test_pad_val(self, rng, mode, size, pad_width, axis):
     def test_pad_grad(self, rng, mode, size, pad_width, axis):
         """Test gradients of padding function for various modes, sizes, pad widths, and axes."""
         x = rng.random(size)
-        check_grads(pad, modes=["fwd", "rev"], order=1)(x, pad_width, mode=mode, axis=axis)
+        check_grads(pad, modes=["fwd", "rev"], order=2)(x, pad_width, mode=mode, axis=axis)


 class TestPadExceptions:
@@ -136,7 +139,7 @@ def test_convolve_grad(self, rng, mode, padding, ary_size, kernel_size, square_k
         )

         x, k = self._ary_and_kernel(rng, ary_size, kernel_size, square_kernel)
-        check_grads(convolve, modes=["rev"], order=1)(x, k, padding=padding, mode=mode)
+        check_grads(convolve, modes=["rev"], order=2)(x, k, padding=padding, mode=mode)


 class TestConvolveExceptions:
@@ -194,7 +197,7 @@ def test_morphology_val_size(self, rng, op, sp_op, mode, ary_size, kernel_size):
     def test_morphology_val_grad(self, rng, op, sp_op, mode, ary_size, kernel_size):
         """Test gradients of morphological operations for various modes, array sizes, and kernel sizes."""
         x = rng.random(ary_size)
-        check_grads(op, modes=["rev"], order=1)(x, size=kernel_size, mode=mode)
+        check_grads(op, modes=["rev"], order=2)(x, size=kernel_size, mode=mode)

     @pytest.mark.parametrize(
         "full",
@@ -238,7 +241,7 @@ def test_morphology_val_structure_grad(
     ):
         """Test gradients of morphological operations for various kernel structures."""
         x, k = self._ary_and_kernel(rng, ary_size, kernel_size, full, square, flat)
-        check_grads(op, modes=["rev"], order=1)(x, size=kernel_size, mode=mode)
+        check_grads(op, modes=["rev"], order=2)(x, size=kernel_size, mode=mode)


     @pytest.mark.parametrize(
@@ -317,11 +320,9 @@ def test_interpn_val(self, rng, dim, method):
         result_scipy = scipy.interpolate.interpn(points, values, tuple(xi_grid), method=method)
         npt.assert_allclose(result_custom, result_scipy)

-    @pytest.mark.parametrize("order", [1, 2])
-    @pytest.mark.parametrize("mode", ["fwd", "rev"])
-    def test_interpn_values_grad(self, rng, dim, method, order, mode):
+    def test_interpn_values_grad(self, rng, dim, method):
         points, values, xi = self.generate_points_values_xi(rng, dim)
-        check_grads(lambda v: interpn(points, v, xi, method=method), modes=[mode], order=order)(
+        check_grads(lambda v: interpn(points, v, xi, method=method), modes=["fwd", "rev"], order=2)(
             values
         )

@@ -356,12 +357,10 @@ def test_trapz_val(self, rng, shape, axis, use_x):
         result_numpy = np.trapz(y, x=x, dx=dx, axis=axis)
         npt.assert_allclose(result_custom, result_numpy)

-    @pytest.mark.parametrize("order", [1, 2])
-    @pytest.mark.parametrize("mode", ["fwd", "rev"])
-    def test_trapz_grad(self, rng, shape, axis, use_x, order, mode):
+    def test_trapz_grad(self, rng, shape, axis, use_x):
         """Test gradients of trapz function for different array dimensions and integration axes."""
         y, x, dx = self.generate_y_x_dx(rng, shape, use_x)
-        check_grads(lambda y: trapz(y, x=x, dx=dx, axis=axis), modes=[mode], order=order)(y)
+        check_grads(lambda y: trapz(y, x=x, dx=dx, axis=axis), modes=["fwd", "rev"], order=2)(y)


 @pytest.mark.parametrize("shape", [(10,), (10, 10)])
@@ -381,10 +380,137 @@ def test_add_at_val(self, rng, shape, indices):
         result_numpy[indices] += y
         npt.assert_allclose(result_custom, result_numpy)

-    @pytest.mark.parametrize("order", [1, 2])
-    @pytest.mark.parametrize("mode", ["fwd", "rev"])
-    def test_add_at_grad(self, rng, shape, indices, order, mode):
+    def test_add_at_grad(self, rng, shape, indices):
         """Test gradients of add_at function for different array dimensions and indices."""
         x, y = self.generate_x_y(rng, shape, indices)
-        check_grads(lambda x: add_at(x, indices, y), modes=[mode], order=order)(x)
-        check_grads(lambda y: add_at(x, indices, y), modes=[mode], order=order)(y)
+        check_grads(lambda x: add_at(x, indices, y), modes=["fwd", "rev"], order=2)(x)
+        check_grads(lambda y: add_at(x, indices, y), modes=["fwd", "rev"], order=2)(y)
+
+
+@pytest.mark.parametrize("shape", [(5,), (5, 5), (5, 5, 5)])
+@pytest.mark.parametrize("tau", [1e-3, 1.0])
+@pytest.mark.parametrize("axis", [None, 0, 1, -1])
+class TestSmoothMax:
+    def test_smooth_max_values(self, rng, shape, tau, axis):
+        """Test `smooth_max` values for various shapes, tau, and axes."""
+
+        if axis == 1 and len(shape) == 1:
+            pytest.skip()
+
+        x = rng.uniform(-10, 10, size=shape)
+        result = smooth_max(x, tau=tau, axis=axis)
+
+        expected = np.max(x, axis=axis)
+        npt.assert_allclose(result, expected, atol=10 * tau)
+
+    def test_smooth_max_grad(self, rng, shape, tau, axis):
+        """Test gradients of `smooth_max` for various parameters."""
+
+        if axis == 1 and len(shape) == 1:
+            pytest.skip()
+
+        x = rng.uniform(-1, 1, size=shape)
+        func = lambda x: smooth_max(x, tau=tau, axis=axis)
+        check_grads(func, modes=["fwd", "rev"], order=2)(x)
+
+
+@pytest.mark.parametrize("shape", [(5,), (5, 5), (5, 5, 5)])
+@pytest.mark.parametrize("tau", [1e-3, 1.0])
+@pytest.mark.parametrize("axis", [None, 0, 1, -1])
+class TestSmoothMin:
+    def test_smooth_min_values(self, rng, shape, tau, axis):
+        """Test `smooth_min` values for various shapes, tau, and axes."""
+
+        if axis == 1 and len(shape) == 1:
+            pytest.skip()
+
+        x = rng.uniform(-10, 10, size=shape)
+        result = smooth_min(x, tau=tau, axis=axis)
+
+        expected = np.min(x, axis=axis)
+        npt.assert_allclose(result, expected, atol=10 * tau)
+
+    def test_smooth_min_grad(self, rng, shape, tau, axis):
+        """Test gradients of `smooth_min` for various parameters."""
+
+        if axis == 1 and len(shape) == 1:
+            pytest.skip()
+
+        x = rng.uniform(-1, 1, size=shape)
+        func = lambda x: smooth_min(x, tau=tau, axis=axis)
+        check_grads(func, modes=["fwd", "rev"], order=2)(x)
+
+
+class TestLeastSquares:
+    @pytest.mark.parametrize(
+        "model, params_true, initial_guess, x, y",
+        [
+            (
+                lambda x, a, b: a * x + b,
+                np.array([2.0, -3.0]),
+                (0.0, 0.0),
+                np.linspace(0, 10, 50),
+                2.0 * np.linspace(0, 10, 50) - 3.0,
+            ),
+            (
+                lambda x, a, b, c: a * x**2 + b * x + c,
+                np.array([1.0, -2.0, 1.0]),
+                (0.0, 0.0, 0.0),
+                np.linspace(-5, 5, 100),
+                1.0 * np.linspace(-5, 5, 100) ** 2 - 2.0 * np.linspace(-5, 5, 100) + 1.0,
+            ),
+            (
+                lambda x, a, b: a * np.exp(b * x),
+                np.array([1.5, 0.5]),
+                (1.0, 0.0),
+                np.linspace(0, 2, 50),
+                1.5 * np.exp(0.5 * np.linspace(0, 2, 50)),
+            ),
+        ],
+    )
+    def test_least_squares(self, model, params_true, initial_guess, x, y):
+        """Test least_squares function with different models."""
+        params_estimated = least_squares(model, x, y, initial_guess)
+        npt.assert_allclose(params_estimated, params_true, rtol=1e-5)
+
+    def test_least_squares_with_noise(self, rng):
+        """Test least_squares function with noisy data."""
+
+        model = lambda x, a, b: a * x + b
+        a_true, b_true = -1.0, 4.0
+        params_true = np.array([a_true, b_true])
+        x = np.linspace(0, 10, 100)
+        noise = rng.normal(scale=0.1, size=x.shape)
+        y = a_true * x + b_true + noise
+        initial_guess = (0.0, 0.0)
+
+        params_estimated = least_squares(model, x, y, initial_guess)
+
+        npt.assert_allclose(params_estimated, params_true, rtol=1e-1)
+
+    def test_least_squares_no_convergence(self):
+        """Test that least_squares function raises an error when not converging."""
+
+        def constant_model(x, a):
+            return a
+
+        x = np.linspace(0, 10, 50)
+        y = 2.0 * x - 3.0  # Linear data
+        initial_guess = (0.0,)
+
+        with pytest.raises(np.linalg.LinAlgError):
+            least_squares(constant_model, x, y, initial_guess, max_iterations=10, tol=1e-12)
+
+    def test_least_squares_gradient(self):
+        """Test gradients of least_squares function with respect to parameters."""
+
+        def linear_model(x, a, b):
+            return a * x + b
+
+        x = np.linspace(0, 10, 50)
+        y = 2.0 * x - 3.0
+        initial_guess = (1.0, 0.0)
+
+        check_grads(
+            lambda params: least_squares(linear_model, x, y, params), modes=["fwd", "rev"], order=2
+        )(initial_guess)
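For reference, the value and gradient tests above are consistent with the standard temperature-scaled logsumexp formulation of a smooth extremum. The commit's actual implementation is not shown in this diff, so the following is a minimal sketch under that assumption (`smooth_max_sketch` and `smooth_min_sketch` are hypothetical names):

import autograd.numpy as anp
from autograd.scipy.special import logsumexp

def smooth_max_sketch(x, tau=1.0, axis=None):
    # tau * logsumexp(x / tau) approaches max(x, axis) from above as tau -> 0,
    # which matches the atol = 10 * tau tolerance used in the value tests,
    # and it is smooth, consistent with the order-2 check_grads tests.
    return tau * logsumexp(x / tau, axis=axis)

def smooth_min_sketch(x, tau=1.0, axis=None):
    # min(x) == -max(-x), so the smooth minimum reuses the smooth maximum.
    return -smooth_max_sketch(-x, tau=tau, axis=axis)

Smaller tau tracks the hard extremum more closely but makes the gradients sharper, which is presumably why the value tests scale their tolerance with tau.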

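The least_squares tests pin down the interface (`model, x, y, initial_guess` plus `max_iterations` and `tol` keywords) and the `np.linalg.LinAlgError` raised on non-convergence. A Gauss-Newton style iteration matching that interface could look as follows; this is a hypothetical sketch with a finite-difference Jacobian, not the commit's solver, and it omits the custom gradient rules that make the real function differentiable through autograd:

import numpy as np

def least_squares_sketch(model, x, y, initial_guess, max_iterations=100, tol=1e-8):
    """Hypothetical Gauss-Newton fit of model(x, *params) to y."""
    params = np.asarray(initial_guess, dtype=float)
    eps = 1e-7  # finite-difference step for the Jacobian columns
    for _ in range(max_iterations):
        residual = y - model(x, *params)
        # Jacobian of the model w.r.t. each parameter, one column at a time.
        jac = np.stack(
            [
                (model(x, *(params + eps * basis)) - model(x, *params)) / eps
                for basis in np.eye(params.size)
            ],
            axis=-1,
        )
        step, *_ = np.linalg.lstsq(jac, residual, rcond=None)
        params = params + step
        if np.linalg.norm(step) < tol:
            return params
    raise np.linalg.LinAlgError("Gauss-Newton iteration did not converge")

The gradient test above differentiates through the solve with respect to `initial_guess`, so the shipped implementation must propagate derivatives through (or around) this iteration, e.g. via autograd-compatible primitives; a plain NumPy loop like this sketch does not.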