Skip to content

Commit 2aeaa48

Browse files
jcitrinTorax team
authored andcommitted
Extend and refactor MEQ mat file loading and add tests.
Updated the MEQ file loading function to handle both "LY" and "L" structured array wrappers commonly found in MEQ outputs. Added unit tests for the loader function. PiperOrigin-RevId: 848554269
1 parent 692ca3b commit 2aeaa48

File tree

3 files changed

+114
-16
lines changed

3 files changed

+114
-16
lines changed

torax/_src/geometry/geometry.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
# limitations under the License.
1414

1515
"""Classes for representing the problem geometry."""
16+
1617
from collections.abc import Sequence
1718
import dataclasses
1819
import enum
20+
from typing import TypeVar
1921

2022
import chex
2123
import jax
@@ -385,7 +387,10 @@ def z_magnetic_axis(self) -> chex.Numeric:
385387
raise ValueError('Geometry does not have a z magnetic axis.')
386388

387389

388-
def stack_geometries(geometries: Sequence[Geometry]) -> Geometry:
390+
GeometryT = TypeVar('GeometryT', bound='Geometry')
391+
392+
393+
def stack_geometries(geometries: Sequence[GeometryT]) -> GeometryT:
389394
"""Batch together a sequence of geometries.
390395
391396
Args:

torax/_src/geometry/geometry_loader.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -59,24 +59,29 @@ def _load_CHEASE_data( # pylint: disable=invalid-name
5959

6060

6161
def _load_fbt_data(file_path: str | IO[bytes]) -> dict[str, np.ndarray]:
62-
"""Loads data into a dictionary from an FBT-LY file or file path."""
62+
"""Loads data into a dictionary from an MEQ LY or L file or file path."""
6363

6464
meq_data = scipy.io.loadmat(file_path)
6565

66-
if "LY" not in meq_data:
67-
# L file or LY file not the output of MEQ meqlpack. Return as is.
68-
return meq_data
69-
else:
70-
# LY bundle file. numpy structured array likely resulting from
71-
# scipy.io.loadmat returning the output of MEQ meqlpack.m.
72-
# Reformat to return a dict. Shapes of 2D arrays are (radius, time)
73-
LY_bundle_dict = {} # pylint: disable=invalid-name
74-
LY_bundle_array = meq_data["LY"] # pylint: disable=invalid-name
75-
field_names = LY_bundle_array.dtype.names
76-
for name in field_names:
77-
LY_bundle_dict[name] = LY_bundle_array[name].item()
78-
79-
return LY_bundle_dict
66+
# Check for nested structured arrays common in MEQ outputs from MEQ meqlpack.
67+
# The actual data is often wrapped in a top-level "LY" or "L" key.
68+
for key in ("LY", "L"):
69+
if key in meq_data:
70+
data_array = meq_data[key]
71+
# Check if the array is a structured array (has field names).
72+
if data_array.dtype.names is None:
73+
raise ValueError(f"Provided MEQ {key} data not in expected format.")
74+
extracted_dict = {}
75+
for name in data_array.dtype.names:
76+
# structured_array[name] returns an array; .item() retrieves the
77+
# inner array object for that specific field.
78+
item = data_array[name].item()
79+
if not isinstance(item, np.ndarray):
80+
raise ValueError(f"MEQ data field '{name}' is not a numpy array.")
81+
extracted_dict[name] = item
82+
return extracted_dict
83+
# If no nested structured array is found, return the loaded mat dict as is.
84+
return meq_data
8085

8186

8287
def _load_eqdsk_data(file_path: str, cocos: int) -> dict[str, np.ndarray]:
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Copyright 2025 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for the geometry_loader module."""
16+
17+
from unittest import mock
18+
from absl.testing import absltest
19+
import numpy as np
20+
from torax._src.geometry import geometry_loader
21+
22+
# pylint: disable=invalid-name
23+
24+
25+
class GeometryLoaderTest(absltest.TestCase):
26+
27+
@mock.patch("scipy.io.loadmat")
28+
def test_load_data_flat(self, mock_loadmat):
29+
"""Tests loading MEQ data that is already flat in the mat file."""
30+
mock_data = {
31+
"rBt": np.array([5.0]),
32+
"aminor": np.linspace(0.1, 1.0, 10),
33+
}
34+
mock_loadmat.return_value = mock_data
35+
36+
result = geometry_loader._load_fbt_data("dummy_path.mat")
37+
self.assertEqual(result, mock_data)
38+
39+
@mock.patch("scipy.io.loadmat")
40+
def test_load_fbt_data_nested_LY(self, mock_loadmat):
41+
"""Tests loading MEQ data nested inside an 'LY' structured array."""
42+
# Build a structured array to mimic scipy.io.loadmat's behavior for
43+
# meqlpack outputs.
44+
aminor_data = np.linspace(0.1, 1.0, 10)
45+
dtype = [("rBt", "O"), ("aminor", "O")]
46+
# Structured array with 1 element, fields contain arrays.
47+
LY_array = np.array([(np.array([5.0]), aminor_data)], dtype=dtype)
48+
49+
mock_loadmat.return_value = {"LY": LY_array}
50+
51+
result = geometry_loader._load_fbt_data("dummy_path.mat")
52+
53+
self.assertIn("rBt", result)
54+
self.assertIn("aminor", result)
55+
np.testing.assert_array_equal(result["rBt"], np.array([5.0]))
56+
np.testing.assert_array_equal(result["aminor"], aminor_data)
57+
58+
@mock.patch("scipy.io.loadmat")
59+
def test_load_fbt_data_nested_L(self, mock_loadmat):
60+
"""Tests loading MEQ data nested inside an 'L' structured array."""
61+
pQ_data = np.linspace(0.0, 1.0, 20)
62+
dtype = [("pQ", "O")]
63+
L_array = np.array([(pQ_data,)], dtype=dtype)
64+
65+
mock_loadmat.return_value = {"L": L_array}
66+
67+
result = geometry_loader._load_fbt_data("dummy_path.mat")
68+
69+
self.assertIn("pQ", result)
70+
np.testing.assert_array_equal(result["pQ"], pQ_data)
71+
72+
@mock.patch("scipy.io.loadmat")
73+
def test_load_fbt_data_invalid_item_type(self, mock_loadmat):
74+
"""Tests that a ValueError is raised if a nested item is not an array."""
75+
dtype = [("bad_field", "O")]
76+
# Create a structured array where the item is NOT a numpy array (e.g., int)
77+
LY_array = np.array([(123,)], dtype=dtype)
78+
79+
mock_loadmat.return_value = {"LY": LY_array}
80+
81+
with self.assertRaisesRegex(
82+
ValueError, "MEQ data field 'bad_field' is not a numpy array"
83+
):
84+
geometry_loader._load_fbt_data("dummy_path.mat")
85+
86+
87+
if __name__ == "__main__":
88+
absltest.main()

0 commit comments

Comments
 (0)