Skip to content

Commit 00b0330

Browse files
committed
greg comments
1 parent c43c07c commit 00b0330

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed

tests/test_plugins/autograd/test_functions.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,46 @@ def test_morphology_val_structure_grad(
248248
check_grads(op, modes=["rev"], order=1)(x, structure=k, mode=mode)
249249

250250

251+
class TestMorphology1D:
252+
"""Test morphological operations with 1D-like structuring elements."""
253+
254+
@pytest.mark.parametrize("h, w", [(1, 3), (3, 1), (1, 5), (5, 1)])
255+
def test_1d_structuring_elements(self, rng, h, w):
256+
"""Test grey dilation with 1D-like structuring elements on 2D arrays."""
257+
x = rng.random((8, 8))
258+
259+
# Test with size parameter
260+
size_tuple = (h, w)
261+
result_size = grey_dilation(x, size=size_tuple)
262+
263+
# Verify output shape matches input
264+
assert result_size.shape == x.shape
265+
266+
# Verify that dilation actually increases values (or keeps them the same)
267+
assert np.all(result_size >= x)
268+
269+
# Test that we can also use structure parameter with 1D-like arrays
270+
structure = np.ones((h, w))
271+
result_struct = grey_dilation(x, structure=structure)
272+
assert result_struct.shape == x.shape
273+
274+
def test_1d_gradient_flow(self, rng):
275+
"""Test gradient flow through 1D-like structuring elements."""
276+
x = rng.random((6, 6))
277+
278+
# Test horizontal 1D structure
279+
check_grads(lambda x: grey_dilation(x, size=(1, 3)), modes=["rev"], order=1)(x)
280+
281+
# Test vertical 1D structure
282+
check_grads(lambda x: grey_dilation(x, size=(3, 1)), modes=["rev"], order=1)(x)
283+
284+
# Test with structure parameter
285+
struct_h = np.ones((1, 3))
286+
struct_v = np.ones((3, 1))
287+
check_grads(lambda x: grey_dilation(x, structure=struct_h), modes=["rev"], order=1)(x)
288+
check_grads(lambda x: grey_dilation(x, structure=struct_v), modes=["rev"], order=1)(x)
289+
290+
251291
class TestMorphologyExceptions:
252292
"""Test exceptions in morphological operations."""
253293

@@ -264,6 +304,13 @@ def test_even_structure_dimensions(self, rng):
264304
with pytest.raises(ValueError, match="Structuring element dimensions must be odd"):
265305
grey_dilation(x, structure=k_even)
266306

307+
def test_both_size_and_structure(self, rng):
308+
"""Test that an exception is raised when both size and structure are provided."""
309+
x = rng.random((5, 5))
310+
k = np.ones((3, 3))
311+
with pytest.raises(ValueError, match="Cannot specify both size and structure"):
312+
grey_dilation(x, size=3, structure=k)
313+
267314

268315
@pytest.mark.parametrize(
269316
"array, out_min, out_max, in_min, in_max, expected",

tidy3d/plugins/autograd/functions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,8 @@ def _get_footprint(size, structure, maxval):
208208
"""Helper to generate the morphological footprint from size or structure."""
209209
if size is None and structure is None:
210210
raise ValueError("Either size or structure must be provided.")
211+
if size is not None and structure is not None:
212+
raise ValueError("Cannot specify both size and structure.")
211213
if structure is None:
212214
size_np = onp.atleast_1d(size)
213215
shape = (size_np[0], size_np[-1]) if size_np.size > 1 else (size_np[0], size_np[0])
@@ -238,8 +240,11 @@ def grey_dilation(
238240
The input array to perform grey dilation on.
239241
size : Union[Union[int, tuple[int, int]], None] = None
240242
The size of the structuring element. If None, `structure` must be provided.
243+
If a single integer is provided, a square structuring element is created.
244+
For 1D arrays, use a tuple (size, 1) or (1, size) for horizontal or vertical operations.
241245
structure : Union[np.ndarray, None] = None
242246
The structuring element. If None, `size` must be provided.
247+
For 1D operations on 2D arrays, use a 2D structure with one dimension being 1.
243248
mode : PaddingType = "reflect"
244249
The padding mode to use.
245250
maxval : float = 1e4
@@ -286,6 +291,10 @@ def _vjp_maker_dilation(ans, array, size=None, structure=None, *, mode="reflect"
286291
is_max_mask = (dilated_windows == output_reshaped).astype(onp.float64)
287292

288293
# normalize the gradient for cases where multiple elements are the maximum.
294+
# When multiple elements in a window equal the maximum value, the gradient
295+
# is distributed equally among them. This ensures gradient conservation.
296+
# Note: Values can never exceed maxval in the output since we add structure
297+
# values (capped at maxval) to the input array values.
289298
multiplicity = onp.sum(is_max_mask, axis=(-2, -1), keepdims=True)
290299
is_max_mask /= onp.maximum(multiplicity, 1)
291300

0 commit comments

Comments
 (0)