Skip to content

Commit ceed33f

Browse files
committed
Refactor get_susceptibilities function for improved clarity and error handling; add comprehensive tests for various input scenarios and edge cases
1 parent 2114e22 commit ceed33f

File tree

2 files changed

+212
-54
lines changed

2 files changed

+212
-54
lines changed

src/magpylib_material_response/demag.py

Lines changed: 53 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -38,46 +38,75 @@ def get_susceptibilities(sources, susceptibility=None):
3838
the top level of the tree."""
3939
n = len(sources)
4040

41-
# susceptibilities from source attributes
4241
if susceptibility is None:
43-
susis = []
42+
# Get susceptibilities from source attributes
43+
susceptibilities = []
4444
for src in sources:
45-
susceptibility = getattr(src, "susceptibility", None)
46-
if susceptibility is None:
45+
src_susceptibility = getattr(src, "susceptibility", None)
46+
if src_susceptibility is None:
4747
if src.parent is None:
4848
msg = "No susceptibility defined in any parent collection"
4949
raise ValueError(msg)
50-
susis.extend(get_susceptibilities(src.parent))
51-
elif not hasattr(susceptibility, "__len__"):
52-
susis.append((susceptibility, susceptibility, susceptibility))
53-
elif len(susceptibility) == 3:
54-
susis.append(susceptibility)
55-
else:
56-
msg = "susceptibility is not scalar or array of length 3"
57-
raise ValueError(msg)
58-
# susceptibilities as input to demag function
59-
elif np.isscalar(susceptibility):
60-
susis = np.ones((n, 3)) * susceptibility
61-
elif len(susceptibility) == 3:
50+
src_susceptibility = _get_susceptibility_from_hierarchy(src.parent)
51+
susceptibilities.append(src_susceptibility)
52+
53+
susis = _convert_to_array(susceptibilities, n)
54+
else:
55+
# Use function input susceptibility
56+
susis = _convert_to_array(susceptibility, n)
57+
58+
return np.reshape(susis, 3 * n, order="F")
59+
60+
61+
def _convert_to_array(susceptibility, n):
62+
"""Convert susceptibility input(s) to (n, 3) array format"""
63+
# Handle single values (scalar or 3-vector) applied to all sources
64+
if np.isscalar(susceptibility):
65+
return np.ones((n, 3)) * susceptibility
66+
elif hasattr(susceptibility, '__len__') and len(susceptibility) == 3 and not isinstance(susceptibility[0], (list, tuple, np.ndarray)):
67+
# This is a 3-vector, not a list of 3 items
6268
susis = np.tile(susceptibility, (n, 1))
6369
if n == 3:
6470
msg = (
6571
"Apply_demag input susceptibility is ambiguous - either scalar list or vector single entry. "
6672
"Please choose different means of input or change the number of cells in the Collection."
6773
)
6874
raise ValueError(msg)
69-
else:
70-
if len(susceptibility) != n:
71-
msg = "Apply_demag input susceptibility must be scalar, 3-vector, or same length as input Collection."
75+
return susis
76+
77+
# Handle list of susceptibilities (one per source)
78+
susceptibility_list = list(susceptibility) if not isinstance(susceptibility, list) else susceptibility
79+
80+
if len(susceptibility_list) != n:
81+
msg = "Apply_demag input susceptibility must be scalar, 3-vector, or same length as input Collection."
82+
raise ValueError(msg)
83+
84+
# Convert each susceptibility to 3-tuple format
85+
susis = []
86+
for sus in susceptibility_list:
87+
if np.isscalar(sus):
88+
susis.append((float(sus), float(sus), float(sus)))
89+
elif hasattr(sus, '__len__') and len(sus) == 3:
90+
susis.append(tuple(sus))
91+
else:
92+
msg = "susceptibility is not scalar or array of length 3"
7293
raise ValueError(msg)
73-
susis = np.array(susceptibility)
74-
if susis.ndim == 1:
75-
susis = np.repeat(susis, 3).reshape(n, 3)
76-
77-
susis = np.reshape(susis, 3 * n, order="F")
94+
7895
return np.array(susis)
7996

8097

98+
def _get_susceptibility_from_hierarchy(source):
99+
"""Helper function to get susceptibility value from source or its parent hierarchy.
100+
Returns the raw susceptibility value (scalar or 3-tuple), not the reshaped array."""
101+
susceptibility = getattr(source, "susceptibility", None)
102+
if susceptibility is not None:
103+
return susceptibility
104+
if source.parent is None:
105+
msg = "No susceptibility defined in any parent collection"
106+
raise ValueError(msg)
107+
return _get_susceptibility_from_hierarchy(source.parent)
108+
109+
81110
def get_H_ext(*sources, H_ext=None):
82111
"""Return a list of length (len(sources)) with H_ext values
83112
Priority is given at the source level, however if value is not found, it is searched up the

tests/test_basic.py

Lines changed: 159 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,49 +2,178 @@
22

33
import magpylib as magpy
44
import numpy as np
5+
import pytest
56

67
import magpylib_material_response
7-
from magpylib_material_response.demag import apply_demag
8+
from magpylib_material_response.demag import apply_demag, get_susceptibilities
89
from magpylib_material_response.meshing import mesh_Cuboid
910

1011

1112
def test_version():
1213
assert isinstance(magpylib_material_response.__version__, str)
1314

1415

15-
def test_susceptibility_inputs():
16-
"""
17-
test if different xi inputs give the same result
18-
"""
19-
20-
zone = magpy.magnet.Cuboid(
21-
dimension=(1, 1, 1),
22-
polarization=(0, 0, 1),
23-
)
16+
def test_apply_demag_integration():
17+
"""Integration test: verify get_susceptibilities works correctly with apply_demag"""
18+
zone = magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1))
2419
mesh = mesh_Cuboid(zone, (2, 2, 2))
2520

21+
# Test that different equivalent susceptibility inputs give same result
2622
dm1 = apply_demag(mesh, susceptibility=4)
2723
dm2 = apply_demag(mesh, susceptibility=(4, 4, 4))
28-
dm3 = apply_demag(mesh, susceptibility=[4] * 8)
29-
dm4 = apply_demag(mesh, susceptibility=[(4, 4, 4)] * 8)
30-
31-
zone = magpy.magnet.Cuboid(
32-
dimension=(1, 1, 1),
33-
polarization=(0, 0, 1),
34-
)
24+
3525
zone.susceptibility = 4
36-
mesh = mesh_Cuboid(zone, (2, 2, 2))
37-
dm5 = apply_demag(mesh)
26+
mesh_with_attr = mesh_Cuboid(zone, (2, 2, 2))
27+
dm3 = apply_demag(mesh_with_attr)
28+
29+
# All should give same magnetic field result
30+
b_ref = dm1.getB((1, 2, 3))
31+
np.testing.assert_allclose(dm2.getB((1, 2, 3)), b_ref)
32+
np.testing.assert_allclose(dm3.getB((1, 2, 3)), b_ref)
33+
34+
35+
@pytest.mark.parametrize(
36+
"test_case,susceptibility_input,expected_output",
37+
[
38+
pytest.param(
39+
"source_scalar",
40+
[(2.5,), (3.0,)],
41+
np.array([2.5, 3.0, 2.5, 3.0, 2.5, 3.0]),
42+
id="source_scalar"
43+
),
44+
pytest.param(
45+
"source_vector",
46+
[(1.0, 2.0, 3.0), (4.0, 5.0, 6.0)],
47+
np.array([1.0, 4.0, 2.0, 5.0, 3.0, 6.0]),
48+
id="source_vector"
49+
),
50+
pytest.param(
51+
"function_scalar",
52+
1.5,
53+
np.array([1.5, 1.5, 1.5, 1.5, 1.5, 1.5]),
54+
id="function_scalar"
55+
),
56+
pytest.param(
57+
"function_vector",
58+
(2.0, 3.0, 4.0),
59+
np.array([2.0, 2.0, 3.0, 3.0, 4.0, 4.0]),
60+
id="function_vector"
61+
),
62+
pytest.param(
63+
"function_list",
64+
[1.5, 2.5],
65+
np.array([1.5, 2.5, 1.5, 2.5, 1.5, 2.5]),
66+
id="function_list"
67+
),
68+
]
69+
)
70+
def test_get_susceptibilities_basic(test_case, susceptibility_input, expected_output):
71+
"""Test basic get_susceptibilities functionality with source attributes and function inputs"""
72+
sources = []
73+
for _ in range(2):
74+
zone = magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1))
75+
sources.append(zone)
76+
77+
if test_case.startswith("source"):
78+
# Set susceptibility on sources
79+
for i, sus_val in enumerate(susceptibility_input):
80+
if len(sus_val) == 1:
81+
sources[i].susceptibility = sus_val[0]
82+
else:
83+
sources[i].susceptibility = sus_val
84+
result = get_susceptibilities(sources)
85+
else:
86+
# Use function input
87+
result = get_susceptibilities(sources, susceptibility=susceptibility_input)
88+
89+
np.testing.assert_allclose(result, expected_output)
90+
91+
92+
def test_get_susceptibilities_hierarchy():
93+
"""Test susceptibility inheritance from parent collections and mixed scenarios"""
94+
# Create collection with susceptibility
95+
collection = magpy.Collection()
96+
collection.susceptibility = 2.0
97+
98+
# Source with its own susceptibility
99+
zone_own = magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1))
100+
zone_own.susceptibility = 5.0
101+
102+
# Source inheriting from parent
103+
zone_inherit = magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1))
104+
collection.add(zone_inherit)
105+
106+
# Test mixed sources (critical edge case)
107+
result = get_susceptibilities([zone_own, zone_inherit])
108+
expected = np.array([5.0, 2.0, 5.0, 2.0, 5.0, 2.0])
109+
np.testing.assert_allclose(result, expected)
110+
111+
# Test single inheritance
112+
result_single = get_susceptibilities([zone_inherit])
113+
expected_single = np.array([2.0, 2.0, 2.0])
114+
np.testing.assert_allclose(result_single, expected_single)
115+
116+
117+
@pytest.mark.parametrize(
118+
"error_case,setup_func,error_message",
119+
[
120+
pytest.param(
121+
"no_susceptibility",
122+
lambda: [magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1))],
123+
"No susceptibility defined in any parent collection",
124+
id="no_susceptibility"
125+
),
126+
pytest.param(
127+
"invalid_format",
128+
lambda: [_create_zone_with_bad_susceptibility()],
129+
"susceptibility is not scalar or array of length 3",
130+
id="invalid_format"
131+
),
132+
pytest.param(
133+
"wrong_length",
134+
lambda: [magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1)) for _ in range(4)],
135+
"Apply_demag input susceptibility must be scalar, 3-vector, or same length as input Collection",
136+
id="wrong_length"
137+
),
138+
pytest.param(
139+
"ambiguous_input",
140+
lambda: [magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1)) for _ in range(3)],
141+
"Apply_demag input susceptibility is ambiguous",
142+
id="ambiguous_input"
143+
),
144+
]
145+
)
146+
def test_get_susceptibilities_errors(error_case, setup_func, error_message):
147+
"""Test error cases for get_susceptibilities function"""
148+
sources = setup_func()
149+
150+
if error_case == "wrong_length":
151+
with pytest.raises(ValueError, match=error_message):
152+
get_susceptibilities(sources, susceptibility=[1.0, 2.0, 3.0, 4.0, 5.0])
153+
elif error_case == "ambiguous_input":
154+
with pytest.raises(ValueError, match=error_message):
155+
get_susceptibilities(sources, susceptibility=(1.0, 2.0, 3.0))
156+
else:
157+
with pytest.raises(ValueError, match=error_message):
158+
get_susceptibilities(sources)
159+
160+
161+
def _create_zone_with_bad_susceptibility():
162+
"""Helper to create a zone with invalid susceptibility format"""
163+
zone = magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1))
164+
zone.susceptibility = (1, 2) # Invalid: should be scalar or length 3
165+
return zone
38166

39-
zone = magpy.magnet.Cuboid(
40-
dimension=(1, 1, 1),
41-
polarization=(0, 0, 1),
42-
)
43-
zone.susceptibility = (4, 4, 4)
44-
mesh = mesh_Cuboid(zone, (2, 2, 2))
45-
dm6 = apply_demag(mesh)
46167

47-
b1 = dm1.getB((1, 2, 3))
48-
for dm in [dm2, dm3, dm4, dm5, dm6]:
49-
bb = dm.getB((1, 2, 3))
50-
np.testing.assert_allclose(b1, bb)
168+
def test_get_susceptibilities_edge_cases():
169+
"""Test edge cases: empty list, single source"""
170+
# Empty sources
171+
result = get_susceptibilities([])
172+
assert len(result) == 0
173+
174+
# Single source
175+
zone = magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1))
176+
zone.susceptibility = 3.0
177+
result = get_susceptibilities([zone])
178+
expected = np.array([3.0, 3.0, 3.0])
179+
np.testing.assert_allclose(result, expected)

0 commit comments

Comments
 (0)