diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..d969f96 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "python.testing.pytestArgs": ["tests"], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} diff --git a/src/magpylib_material_response/demag.py b/src/magpylib_material_response/demag.py index bffcdd3..f0f80b4 100644 --- a/src/magpylib_material_response/demag.py +++ b/src/magpylib_material_response/demag.py @@ -38,27 +38,37 @@ def get_susceptibilities(sources, susceptibility=None): the top level of the tree.""" n = len(sources) - # susceptibilities from source attributes if susceptibility is None: - susis = [] + # Get susceptibilities from source attributes + susceptibilities = [] for src in sources: - susceptibility = getattr(src, "susceptibility", None) - if susceptibility is None: + src_susceptibility = getattr(src, "susceptibility", None) + if src_susceptibility is None: if src.parent is None: msg = "No susceptibility defined in any parent collection" raise ValueError(msg) - susis.extend(get_susceptibilities(src.parent)) - elif not hasattr(susceptibility, "__len__"): - susis.append((susceptibility, susceptibility, susceptibility)) - elif len(susceptibility) == 3: - susis.append(susceptibility) - else: - msg = "susceptibility is not scalar or array of length 3" - raise ValueError(msg) - # susceptibilities as input to demag function - elif np.isscalar(susceptibility): - susis = np.ones((n, 3)) * susceptibility - elif len(susceptibility) == 3: + src_susceptibility = _get_susceptibility_from_hierarchy(src.parent) + susceptibilities.append(src_susceptibility) + + susis = _convert_to_array(susceptibilities, n) + else: + # Use function input susceptibility + susis = _convert_to_array(susceptibility, n) + + return np.reshape(susis, 3 * n, order="F") + + +def _convert_to_array(susceptibility, n): + """Convert susceptibility input(s) to (n, 3) array format""" + # Handle single values (scalar or 3-vector) applied to all sources + if np.isscalar(susceptibility): + return np.ones((n, 3)) * susceptibility + if ( + hasattr(susceptibility, "__len__") + and len(susceptibility) == 3 + and all(not isinstance(x, list | tuple | np.ndarray) for x in susceptibility) + ): + # This is a 3-vector, not a list of 3 items susis = np.tile(susceptibility, (n, 1)) if n == 3: msg = ( @@ -66,18 +76,48 @@ def get_susceptibilities(sources, susceptibility=None): "Please choose different means of input or change the number of cells in the Collection." ) raise ValueError(msg) - else: - if len(susceptibility) != n: - msg = "Apply_demag input susceptibility must be scalar, 3-vector, or same length as input Collection." + return susis + + # Handle list of susceptibilities (one per source) + susceptibility_list = ( + list(susceptibility) if not isinstance(susceptibility, list) else susceptibility + ) + + if len(susceptibility_list) != n: + msg = "Apply_demag input susceptibility must be scalar, 3-vector, or same length as input Collection." + raise ValueError(msg) + + # Convert each susceptibility to 3-tuple format + susis = [] + for sus in susceptibility_list: + if np.isscalar(sus): + susis.append((float(sus), float(sus), float(sus))) + elif hasattr(sus, "__len__") and len(sus) == 3: + try: + sus_tuple = tuple(float(x) for x in sus) + except Exception as e: + msg = f"Each element of susceptibility 3-vector must be numeric. Got: {sus!r} ({e})" + raise ValueError(msg) from e + susis.append(sus_tuple) + else: + msg = "susceptibility is not scalar or array of length 3" raise ValueError(msg) - susis = np.array(susceptibility) - if susis.ndim == 1: - susis = np.repeat(susis, 3).reshape(n, 3) - susis = np.reshape(susis, 3 * n, order="F") return np.array(susis) +def _get_susceptibility_from_hierarchy(source): + """Helper function to get susceptibility value from source or its parent hierarchy. + Returns the raw susceptibility value (scalar or 3-tuple), not the reshaped array.""" + susceptibility = getattr(source, "susceptibility", None) + if susceptibility is not None: + return susceptibility + if source.parent is None: + msg = "No susceptibility defined in any parent collection" + raise ValueError(msg) + return _get_susceptibility_from_hierarchy(source.parent) + + def get_H_ext(*sources, H_ext=None): """Return a list of length (len(sources)) with H_ext values Priority is given at the source level, however if value is not found, it is searched up the diff --git a/tests/test_basic.py b/tests/test_basic.py index d0d74b4..fbc0dfb 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -2,9 +2,10 @@ import magpylib as magpy import numpy as np +import pytest import magpylib_material_response -from magpylib_material_response.demag import apply_demag +from magpylib_material_response.demag import apply_demag, get_susceptibilities from magpylib_material_response.meshing import mesh_Cuboid @@ -12,39 +13,173 @@ def test_version(): assert isinstance(magpylib_material_response.__version__, str) -def test_susceptibility_inputs(): - """ - test if different xi inputs give the same result - """ - - zone = magpy.magnet.Cuboid( - dimension=(1, 1, 1), - polarization=(0, 0, 1), - ) +def test_apply_demag_integration(): + """Integration test: verify get_susceptibilities works correctly with apply_demag""" + zone = magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1)) mesh = mesh_Cuboid(zone, (2, 2, 2)) + # Test that different equivalent susceptibility inputs give same result dm1 = apply_demag(mesh, susceptibility=4) dm2 = apply_demag(mesh, susceptibility=(4, 4, 4)) - dm3 = apply_demag(mesh, susceptibility=[4] * 8) - dm4 = apply_demag(mesh, susceptibility=[(4, 4, 4)] * 8) - zone = magpy.magnet.Cuboid( - dimension=(1, 1, 1), - polarization=(0, 0, 1), - ) zone.susceptibility = 4 - mesh = mesh_Cuboid(zone, (2, 2, 2)) - dm5 = apply_demag(mesh) + mesh_with_attr = mesh_Cuboid(zone, (2, 2, 2)) + dm3 = apply_demag(mesh_with_attr) - zone = magpy.magnet.Cuboid( - dimension=(1, 1, 1), - polarization=(0, 0, 1), - ) - zone.susceptibility = (4, 4, 4) - mesh = mesh_Cuboid(zone, (2, 2, 2)) - dm6 = apply_demag(mesh) + # All should give same magnetic field result + b_ref = dm1.getB((1, 2, 3)) + np.testing.assert_allclose(dm2.getB((1, 2, 3)), b_ref) + np.testing.assert_allclose(dm3.getB((1, 2, 3)), b_ref) + + +@pytest.mark.parametrize( + ("test_case", "susceptibility_input", "expected_output"), + [ + pytest.param( + "source_scalar", + [(2.5,), (3.0,)], + np.array([2.5, 3.0, 2.5, 3.0, 2.5, 3.0]), + id="source_scalar", + ), + pytest.param( + "source_vector", + [(1.0, 2.0, 3.0), (4.0, 5.0, 6.0)], + np.array([1.0, 4.0, 2.0, 5.0, 3.0, 6.0]), + id="source_vector", + ), + pytest.param( + "function_scalar", + 1.5, + np.array([1.5, 1.5, 1.5, 1.5, 1.5, 1.5]), + id="function_scalar", + ), + pytest.param( + "function_vector", + (2.0, 3.0, 4.0), + np.array([2.0, 2.0, 3.0, 3.0, 4.0, 4.0]), + id="function_vector", + ), + pytest.param( + "function_list", + [1.5, 2.5], + np.array([1.5, 2.5, 1.5, 2.5, 1.5, 2.5]), + id="function_list", + ), + ], +) +def test_get_susceptibilities_basic(test_case, susceptibility_input, expected_output): + """Test basic get_susceptibilities functionality with source attributes and function inputs""" + sources = [] + for _ in range(2): + zone = magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1)) + sources.append(zone) + + if test_case.startswith("source"): + # Set susceptibility on sources + for i, sus_val in enumerate(susceptibility_input): + if len(sus_val) == 1: + sources[i].susceptibility = sus_val[0] + else: + sources[i].susceptibility = sus_val + result = get_susceptibilities(sources) + else: + # Use function input + result = get_susceptibilities(sources, susceptibility=susceptibility_input) + + np.testing.assert_allclose(result, expected_output) + + +def test_get_susceptibilities_hierarchy(): + """Test susceptibility inheritance from parent collections and mixed scenarios""" + # Create collection with susceptibility + collection = magpy.Collection() + collection.susceptibility = 2.0 + + # Source with its own susceptibility + zone_own = magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1)) + zone_own.susceptibility = 5.0 + + # Source inheriting from parent + zone_inherit = magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1)) + collection.add(zone_inherit) + + # Test mixed sources (critical edge case) + result = get_susceptibilities([zone_own, zone_inherit]) + expected = np.array([5.0, 2.0, 5.0, 2.0, 5.0, 2.0]) + np.testing.assert_allclose(result, expected) + + # Test single inheritance + result_single = get_susceptibilities([zone_inherit]) + expected_single = np.array([2.0, 2.0, 2.0]) + np.testing.assert_allclose(result_single, expected_single) + + +@pytest.mark.parametrize( + ("error_case", "setup_func", "error_message"), + [ + pytest.param( + "no_susceptibility", + lambda: [magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1))], + "No susceptibility defined in any parent collection", + id="no_susceptibility", + ), + pytest.param( + "invalid_format", + lambda: [_create_zone_with_bad_susceptibility()], + "susceptibility is not scalar or array of length 3", + id="invalid_format", + ), + pytest.param( + "wrong_length", + lambda: [ + magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1)) + for _ in range(4) + ], + "Apply_demag input susceptibility must be scalar, 3-vector, or same length as input Collection", + id="wrong_length", + ), + pytest.param( + "ambiguous_input", + lambda: [ + magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1)) + for _ in range(3) + ], + "Apply_demag input susceptibility is ambiguous", + id="ambiguous_input", + ), + ], +) +def test_get_susceptibilities_errors(error_case, setup_func, error_message): + """Test error cases for get_susceptibilities function""" + sources = setup_func() + + if error_case == "wrong_length": + with pytest.raises(ValueError, match=error_message): + get_susceptibilities(sources, susceptibility=[1.0, 2.0, 3.0, 4.0, 5.0]) + elif error_case == "ambiguous_input": + with pytest.raises(ValueError, match=error_message): + get_susceptibilities(sources, susceptibility=(1.0, 2.0, 3.0)) + else: + with pytest.raises(ValueError, match=error_message): + get_susceptibilities(sources) + + +def _create_zone_with_bad_susceptibility(): + """Helper to create a zone with invalid susceptibility format""" + zone = magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1)) + zone.susceptibility = (1, 2) # Invalid: should be scalar or length 3 + return zone + + +def test_get_susceptibilities_edge_cases(): + """Test edge cases: empty list, single source""" + # Empty sources + result = get_susceptibilities([]) + assert len(result) == 0 - b1 = dm1.getB((1, 2, 3)) - for dm in [dm2, dm3, dm4, dm5, dm6]: - bb = dm.getB((1, 2, 3)) - np.testing.assert_allclose(b1, bb) + # Single source + zone = magpy.magnet.Cuboid(dimension=(1, 1, 1), polarization=(0, 0, 1)) + zone.susceptibility = 3.0 + result = get_susceptibilities([zone]) + expected = np.array([3.0, 3.0, 3.0]) + np.testing.assert_allclose(result, expected)