Skip to content

Commit 90df454

Browse files
Alexboiboipre-commit-ci[bot]Copilot
authored
Fix susceptibility parent tree traversal (#26)
* Add VSCode settings for pytest configuration * Refactor get_susceptibilities function for improved clarity and error handling; add comprehensive tests for various input scenarios and edge cases * style: pre-commit fixes * linting * Update src/magpylib_material_response/demag.py Co-authored-by: Copilot <[email protected]> * Update src/magpylib_material_response/demag.py Co-authored-by: Copilot <[email protected]> * Refactor demag.py for improved clarity and organization; enhance logging configuration --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <[email protected]>
1 parent bb8fba6 commit 90df454

File tree

3 files changed

+232
-52
lines changed

3 files changed

+232
-52
lines changed

.vscode/settings.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"python.testing.pytestArgs": ["tests"],
3+
"python.testing.unittestEnabled": false,
4+
"python.testing.pytestEnabled": true
5+
}

src/magpylib_material_response/demag.py

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -38,46 +38,86 @@ def get_susceptibilities(sources, susceptibility=None):
3838
the top level of the tree."""
3939
n = len(sources)
4040

41-
# susceptibilities from source attributes
4241
if susceptibility is None:
43-
susis = []
42+
# Get susceptibilities from source attributes
43+
susceptibilities = []
4444
for src in sources:
45-
susceptibility = getattr(src, "susceptibility", None)
46-
if susceptibility is None:
45+
src_susceptibility = getattr(src, "susceptibility", None)
46+
if src_susceptibility is None:
4747
if src.parent is None:
4848
msg = "No susceptibility defined in any parent collection"
4949
raise ValueError(msg)
50-
susis.extend(get_susceptibilities(src.parent))
51-
elif not hasattr(susceptibility, "__len__"):
52-
susis.append((susceptibility, susceptibility, susceptibility))
53-
elif len(susceptibility) == 3:
54-
susis.append(susceptibility)
55-
else:
56-
msg = "susceptibility is not scalar or array of length 3"
57-
raise ValueError(msg)
58-
# susceptibilities as input to demag function
59-
elif np.isscalar(susceptibility):
60-
susis = np.ones((n, 3)) * susceptibility
61-
elif len(susceptibility) == 3:
50+
src_susceptibility = _get_susceptibility_from_hierarchy(src.parent)
51+
susceptibilities.append(src_susceptibility)
52+
53+
susis = _convert_to_array(susceptibilities, n)
54+
else:
55+
# Use function input susceptibility
56+
susis = _convert_to_array(susceptibility, n)
57+
58+
return np.reshape(susis, 3 * n, order="F")
59+
60+
61+
def _convert_to_array(susceptibility, n):
62+
"""Convert susceptibility input(s) to (n, 3) array format"""
63+
# Handle single values (scalar or 3-vector) applied to all sources
64+
if np.isscalar(susceptibility):
65+
return np.ones((n, 3)) * susceptibility
66+
if (
67+
hasattr(susceptibility, "__len__")
68+
and len(susceptibility) == 3
69+
and all(not isinstance(x, list | tuple | np.ndarray) for x in susceptibility)
70+
):
71+
# This is a 3-vector, not a list of 3 items
6272
susis = np.tile(susceptibility, (n, 1))
6373
if n == 3:
6474
msg = (
6575
"Apply_demag input susceptibility is ambiguous - either scalar list or vector single entry. "
6676
"Please choose different means of input or change the number of cells in the Collection."
6777
)
6878
raise ValueError(msg)
69-
else:
70-
if len(susceptibility) != n:
71-
msg = "Apply_demag input susceptibility must be scalar, 3-vector, or same length as input Collection."
79+
return susis
80+
81+
# Handle list of susceptibilities (one per source)
82+
susceptibility_list = (
83+
list(susceptibility) if not isinstance(susceptibility, list) else susceptibility
84+
)
85+
86+
if len(susceptibility_list) != n:
87+
msg = "Apply_demag input susceptibility must be scalar, 3-vector, or same length as input Collection."
88+
raise ValueError(msg)
89+
90+
# Convert each susceptibility to 3-tuple format
91+
susis = []
92+
for sus in susceptibility_list:
93+
if np.isscalar(sus):
94+
susis.append((float(sus), float(sus), float(sus)))
95+
elif hasattr(sus, "__len__") and len(sus) == 3:
96+
try:
97+
sus_tuple = tuple(float(x) for x in sus)
98+
except Exception as e:
99+
msg = f"Each element of susceptibility 3-vector must be numeric. Got: {sus!r} ({e})"
100+
raise ValueError(msg) from e
101+
susis.append(sus_tuple)
102+
else:
103+
msg = "susceptibility is not scalar or array of length 3"
72104
raise ValueError(msg)
73-
susis = np.array(susceptibility)
74-
if susis.ndim == 1:
75-
susis = np.repeat(susis, 3).reshape(n, 3)
76105

77-
susis = np.reshape(susis, 3 * n, order="F")
78106
return np.array(susis)
79107

80108

109+
def _get_susceptibility_from_hierarchy(source):
110+
"""Helper function to get susceptibility value from source or its parent hierarchy.
111+
Returns the raw susceptibility value (scalar or 3-tuple), not the reshaped array."""
112+
susceptibility = getattr(source, "susceptibility", None)
113+
if susceptibility is not None:
114+
return susceptibility
115+
if source.parent is None:
116+
msg = "No susceptibility defined in any parent collection"
117+
raise ValueError(msg)
118+
return _get_susceptibility_from_hierarchy(source.parent)
119+
120+
81121
def get_H_ext(*sources, H_ext=None):
82122
"""Return a list of length (len(sources)) with H_ext values
83123
Priority is given at the source level, however if value is not found, it is searched up the

tests/test_basic.py

Lines changed: 164 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,49 +2,184 @@
22

33
import magpylib as magpy
44
import numpy as np
5+
import pytest
56

67
import magpylib_material_response
7-
from magpylib_material_response.demag import apply_demag
8+
from magpylib_material_response.demag import apply_demag, get_susceptibilities
89
from magpylib_material_response.meshing import mesh_Cuboid
910

1011

1112
def test_version():
1213
assert isinstance(magpylib_material_response.__version__, str)
1314

1415

15-
def test_susceptibility_inputs():
16-
"""
17-
test if different xi inputs give the same result
18-
"""
19-
20-
zone = magpy.magnet.Cuboid(
21-
dimension=(1, 1, 1),
22-
polarization=(0, 0, 1),
23-
)
16+
def test_apply_demag_integration():
17+
"""Integration test: verify get_susceptibilities works correctly with apply_demag"""
18+
zone = magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1))
2419
mesh = mesh_Cuboid(zone, (2, 2, 2))
2520

21+
# Test that different equivalent susceptibility inputs give same result
2622
dm1 = apply_demag(mesh, susceptibility=4)
2723
dm2 = apply_demag(mesh, susceptibility=(4, 4, 4))
28-
dm3 = apply_demag(mesh, susceptibility=[4] * 8)
29-
dm4 = apply_demag(mesh, susceptibility=[(4, 4, 4)] * 8)
3024

31-
zone = magpy.magnet.Cuboid(
32-
dimension=(1, 1, 1),
33-
polarization=(0, 0, 1),
34-
)
3525
zone.susceptibility = 4
36-
mesh = mesh_Cuboid(zone, (2, 2, 2))
37-
dm5 = apply_demag(mesh)
26+
mesh_with_attr = mesh_Cuboid(zone, (2, 2, 2))
27+
dm3 = apply_demag(mesh_with_attr)
3828

39-
zone = magpy.magnet.Cuboid(
40-
dimension=(1, 1, 1),
41-
polarization=(0, 0, 1),
42-
)
43-
zone.susceptibility = (4, 4, 4)
44-
mesh = mesh_Cuboid(zone, (2, 2, 2))
45-
dm6 = apply_demag(mesh)
29+
# All should give same magnetic field result
30+
b_ref = dm1.getB((1, 2, 3))
31+
np.testing.assert_allclose(dm2.getB((1, 2, 3)), b_ref)
32+
np.testing.assert_allclose(dm3.getB((1, 2, 3)), b_ref)
33+
34+
35+
@pytest.mark.parametrize(
36+
("test_case", "susceptibility_input", "expected_output"),
37+
[
38+
pytest.param(
39+
"source_scalar",
40+
[(2.5,), (3.0,)],
41+
np.array([2.5, 3.0, 2.5, 3.0, 2.5, 3.0]),
42+
id="source_scalar",
43+
),
44+
pytest.param(
45+
"source_vector",
46+
[(1.0, 2.0, 3.0), (4.0, 5.0, 6.0)],
47+
np.array([1.0, 4.0, 2.0, 5.0, 3.0, 6.0]),
48+
id="source_vector",
49+
),
50+
pytest.param(
51+
"function_scalar",
52+
1.5,
53+
np.array([1.5, 1.5, 1.5, 1.5, 1.5, 1.5]),
54+
id="function_scalar",
55+
),
56+
pytest.param(
57+
"function_vector",
58+
(2.0, 3.0, 4.0),
59+
np.array([2.0, 2.0, 3.0, 3.0, 4.0, 4.0]),
60+
id="function_vector",
61+
),
62+
pytest.param(
63+
"function_list",
64+
[1.5, 2.5],
65+
np.array([1.5, 2.5, 1.5, 2.5, 1.5, 2.5]),
66+
id="function_list",
67+
),
68+
],
69+
)
70+
def test_get_susceptibilities_basic(test_case, susceptibility_input, expected_output):
71+
"""Test basic get_susceptibilities functionality with source attributes and function inputs"""
72+
sources = []
73+
for _ in range(2):
74+
zone = magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1))
75+
sources.append(zone)
76+
77+
if test_case.startswith("source"):
78+
# Set susceptibility on sources
79+
for i, sus_val in enumerate(susceptibility_input):
80+
if len(sus_val) == 1:
81+
sources[i].susceptibility = sus_val[0]
82+
else:
83+
sources[i].susceptibility = sus_val
84+
result = get_susceptibilities(sources)
85+
else:
86+
# Use function input
87+
result = get_susceptibilities(sources, susceptibility=susceptibility_input)
88+
89+
np.testing.assert_allclose(result, expected_output)
90+
91+
92+
def test_get_susceptibilities_hierarchy():
93+
"""Test susceptibility inheritance from parent collections and mixed scenarios"""
94+
# Create collection with susceptibility
95+
collection = magpy.Collection()
96+
collection.susceptibility = 2.0
97+
98+
# Source with its own susceptibility
99+
zone_own = magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1))
100+
zone_own.susceptibility = 5.0
101+
102+
# Source inheriting from parent
103+
zone_inherit = magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1))
104+
collection.add(zone_inherit)
105+
106+
# Test mixed sources (critical edge case)
107+
result = get_susceptibilities([zone_own, zone_inherit])
108+
expected = np.array([5.0, 2.0, 5.0, 2.0, 5.0, 2.0])
109+
np.testing.assert_allclose(result, expected)
110+
111+
# Test single inheritance
112+
result_single = get_susceptibilities([zone_inherit])
113+
expected_single = np.array([2.0, 2.0, 2.0])
114+
np.testing.assert_allclose(result_single, expected_single)
115+
116+
117+
@pytest.mark.parametrize(
118+
("error_case", "setup_func", "error_message"),
119+
[
120+
pytest.param(
121+
"no_susceptibility",
122+
lambda: [magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1))],
123+
"No susceptibility defined in any parent collection",
124+
id="no_susceptibility",
125+
),
126+
pytest.param(
127+
"invalid_format",
128+
lambda: [_create_zone_with_bad_susceptibility()],
129+
"susceptibility is not scalar or array of length 3",
130+
id="invalid_format",
131+
),
132+
pytest.param(
133+
"wrong_length",
134+
lambda: [
135+
magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1))
136+
for _ in range(4)
137+
],
138+
"Apply_demag input susceptibility must be scalar, 3-vector, or same length as input Collection",
139+
id="wrong_length",
140+
),
141+
pytest.param(
142+
"ambiguous_input",
143+
lambda: [
144+
magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1))
145+
for _ in range(3)
146+
],
147+
"Apply_demag input susceptibility is ambiguous",
148+
id="ambiguous_input",
149+
),
150+
],
151+
)
152+
def test_get_susceptibilities_errors(error_case, setup_func, error_message):
153+
"""Test error cases for get_susceptibilities function"""
154+
sources = setup_func()
155+
156+
if error_case == "wrong_length":
157+
with pytest.raises(ValueError, match=error_message):
158+
get_susceptibilities(sources, susceptibility=[1.0, 2.0, 3.0, 4.0, 5.0])
159+
elif error_case == "ambiguous_input":
160+
with pytest.raises(ValueError, match=error_message):
161+
get_susceptibilities(sources, susceptibility=(1.0, 2.0, 3.0))
162+
else:
163+
with pytest.raises(ValueError, match=error_message):
164+
get_susceptibilities(sources)
165+
166+
167+
def _create_zone_with_bad_susceptibility():
168+
"""Helper to create a zone with invalid susceptibility format"""
169+
zone = magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1))
170+
zone.susceptibility = (1, 2) # Invalid: should be scalar or length 3
171+
return zone
172+
173+
174+
def test_get_susceptibilities_edge_cases():
175+
"""Test edge cases: empty list, single source"""
176+
# Empty sources
177+
result = get_susceptibilities([])
178+
assert len(result) == 0
46179

47-
b1 = dm1.getB((1, 2, 3))
48-
for dm in [dm2, dm3, dm4, dm5, dm6]:
49-
bb = dm.getB((1, 2, 3))
50-
np.testing.assert_allclose(b1, bb)
180+
# Single source
181+
zone = magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1))
182+
zone.susceptibility = 3.0
183+
result = get_susceptibilities([zone])
184+
expected = np.array([3.0, 3.0, 3.0])
185+
np.testing.assert_allclose(result, expected)

0 commit comments

Comments
 (0)