Skip to content

Commit 83333c7

Browse files
committed
xarray interface
1 parent 1fa0792 commit 83333c7

File tree

4 files changed

+262
-1
lines changed

4 files changed

+262
-1
lines changed

api/python/SConstruct

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@ import sys, os, string, re
22
sys.path.append('../../framework')
33
import bldutil
44

5-
modules = Split('m8r vplot las')
5+
modules = Split('m8r vplot las mxarray')
6+
7+
try:
8+
import xarray as xr
9+
except ImportError:
10+
modules.remove('mxarray')
611

712
try: # distribution version
813
Import('env root libdir incdir pkgdir bindir')
@@ -33,6 +38,7 @@ if root:
3338
pass
3439

3540
modules.remove('las')
41+
# 'mxarray' may already have been dropped from the list when xarray could not be imported
if 'mxarray' in modules: modules.remove('mxarray')
3642

3743
if env.get('SWIG') and env.get('NUMPY'):
3844
# try:

api/python/mxarray.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
## xarray interface for RSF files
2+
##
3+
## Copyright (C) 2025 University of Texas at Austin
4+
##
5+
## This program is free software; you can redistribute it and/or modify
6+
## it under the terms of the GNU General Public License as published by
7+
## the Free Software Foundation; either version 2 of the License, or
8+
## (at your option) any later version.
9+
##
10+
## This program is distributed in the hope that it will be useful,
11+
## but WITHOUT ANY WARRANTY; without even the implied warranty of
12+
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13+
## GNU General Public License for more details.
14+
##
15+
## You should have received a copy of the GNU General Public License
16+
## along with this program; if not, write to the Free Software
17+
## Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
18+
19+
import m8r
20+
import numpy as np
21+
22+
try:
23+
import xarray as xr
24+
except ImportError:
25+
xr = None
26+
27+
try:
28+
import dask.array as da
29+
except ImportError:
30+
da = None
31+
32+
def _rsf_dtype(data_format, legacy_type=None):
    """Translate RSF header type information into a NumPy dtype.

    data_format -- value of the header's "data_format" key, written as
                   "<form>_<type>" (for example "native_float" or "xdr_int").
    legacy_type -- value of the bare "type" key, used only when
                   data_format is absent.

    Falls back to single-precision floats when neither value is given.
    Raises ValueError for ASCII data or an unknown element type.
    """
    kinds = {
        "float": np.float32,
        "int": np.int32,
        "complex": np.complex64,
        "short": np.int16,
        "uchar": np.uint8,
        "double": np.float64,
    }

    if data_format and "_" in data_format:
        form, _, kind = data_format.partition("_")
    else:
        form, kind = "native", data_format or legacy_type or "float"

    if form == "ascii":
        # text data cannot be memory-mapped as a binary array
        raise ValueError(f"Unsupported data format: {data_format}")
    if kind not in kinds:
        raise ValueError(f"Unsupported data type: {kind}")

    dtype = np.dtype(kinds[kind])
    if form == "xdr":
        # XDR stores values big-endian regardless of the host byte order
        dtype = dtype.newbyteorder(">")
    return dtype


def rsf_to_xarray(path, chunks = "auto"):
    """Open an RSF file as an xarray DataArray backed by a memory map.

    path   -- name of the RSF header file.
    chunks -- chunk specification handed to dask.array.from_array;
              only used when dask is installed.

    The returned DataArray has one dimension per RSF axis, in RSF order
    (axis 1 first).  Each dimension is named after "label#" (or "dim#"
    when the label is missing) and gets the coordinate o# + d# * arange(n#).
    Missing o#/d# default to 0 and 1.  Repeated labels are made unique by
    appending "_<axis>".

    Raises ImportError when xarray is unavailable and ValueError when the
    data type is unsupported or the binary data is not in a regular file.
    """

    if xr is None:
        raise ImportError("xarray is required.")

    f = m8r.Input(path)
    try:
        # m8r reports the shape in numpy order (slowest axis first)
        shape = f.shape()

        dims = []
        coords = {}
        for axis in range(1, len(shape) + 1):
            n = f.int(f"n{axis}")
            o = f.float(f"o{axis}")
            d = f.float(f"d{axis}")
            # a header without origin/sampling must not turn into None arithmetic
            if o is None:
                o = 0.0
            if d is None:
                d = 1.0

            label = f.string(f"label{axis}")
            if not label:
                label = f"dim{axis}"
            # several axes may share one label (e.g. "Distance");
            # dimension names have to be unique
            while label in coords:
                label = f"{label}_{axis}"

            coords[label] = np.arange(n) * d + o
            dims.append(label)

        binFile = f.string("in")
        dtype = _rsf_dtype(f.string("data_format"), f.string("type"))
    finally:
        # only header information is needed; the samples are memory-mapped below
        f.close()

    if not binFile or binFile == "stdin":
        raise ValueError(
            f"Cannot memory-map the data of {path!r}: in={binFile!r}")

    mm = np.memmap(
        binFile,
        dtype=dtype,
        mode='r',
        shape=shape
    )

    if da is not None:
        data = da.from_array(mm, chunks=chunks)
    else:
        data = np.asarray(mm)

    # convert numpy (C) axis order back to RSF axis order
    data = data.transpose(*reversed(range(data.ndim)))

    return xr.DataArray(
        data,
        dims=dims,
        coords=coords
    )
93+
94+
def xarray_to_rsf(ds, outpath):
    """Write an xarray DataArray to an RSF file.

    ds      -- the xarray.DataArray to write; its first dimension becomes
               RSF axis 1, the second axis 2, and so on.
    outpath -- name of the RSF header file to create.

    For every dimension n#, o#, d# and label# are written.  o# and d# are
    taken from the first coordinate value and the first coordinate step;
    a dimension without coordinate values gets o#=0, and d#=1 is used when
    fewer than two values are available.  The samples are stored as
    single-precision floats.

    Raises ImportError when xarray is unavailable and ValueError when ds
    is not a DataArray.
    """

    if xr is None:
        raise ImportError("xarray is required.")

    if not isinstance(ds, xr.DataArray):
        raise ValueError("Input must be an xarray DataArray.")

    # reverse the axes so that the first dimension is written as RSF axis 1
    data = ds.values
    data = data.transpose(*reversed(range(data.ndim)))

    out = m8r.Output(outpath)
    try:
        for i, (dim, n) in enumerate(zip(ds.dims, ds.shape), start=1):
            o = 0.0
            d = 1.0
            if n > 0 and dim in ds.coords:
                coord = ds.coords[dim].values
                o = float(coord[0])
                if n > 1:
                    d = float(coord[1] - coord[0])

            # hand plain Python numbers to m8r instead of numpy scalars
            out.put(f"n{i}", int(n))
            out.put(f"o{i}", o)
            out.put(f"d{i}", d)
            out.put(f"label{i}", str(dim))

        out.write(data.astype(np.float32))
    finally:
        # close the output even when writing fails
        out.close()
129+
130+
def rsf_to_xarrayds(path, chunks = "auto"):
    """Load an RSF file as an xarray Dataset with a single variable named "data".

    path and chunks are passed unchanged to rsf_to_xarray.
    """

    return rsf_to_xarray(path, chunks=chunks).to_dataset(name="data")

api/python/test/SConstruct

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,10 @@ Flow('test.attr','clip.py',
99
binary = Command('afdm.exe','afdm.py','cp $SOURCE $TARGET')
1010
AddPostAction(binary,Chmod(binary,0o755))
1111

12+
Flow('test.mxarray','test_mxarray.py',
13+
'''
14+
./$SOURCE
15+
''',stdin=0)
16+
1217
End()
1318

api/python/test/test_mxarray.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
#!/usr/bin/env python
2+
3+
import mxarray as mx
4+
import numpy as np
5+
import xarray as xr
6+
import sys
7+
8+
def test_roundtrip(path):
    """Write a small 4-D DataArray to RSF, read it back and compare.

    Also exports the result to netCDF and re-reads it.  All files created
    here (RSF header, RSF binary, test.nc) are removed at the end.

    path -- name of the temporary RSF header file to create.
    """
    import os
    import m8r

    axes = ("t", "x", "y", "z")

    # create xarray
    arr = xr.DataArray(
        np.arange(3*4*5*6, dtype=np.float32).reshape(3,4,5,6),
        dims=axes,
        coords={
            "t": [0,1,2],
            "x": [0.0, 1.0, 2.0, 3.0],
            "y": [0.0, 1.0, 2.0, 3.0, 4.0],
            "z": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
        }
    )

    # write to RSF
    mx.xarray_to_rsf(arr, str(path))

    # read back
    da = mx.rsf_to_xarray(str(path))
    diff = da - arr

    assert da.shape == (3, 4, 5, 6)
    assert da.values.dtype == arr.values.dtype == np.float32
    assert da.dims == arr.dims
    for axis in axes:
        assert da.coords[axis].values.tolist() == arr.coords[axis].values.tolist()
    assert da.values.tolist() == arr.values.tolist()
    assert diff.values.tolist() == np.zeros((3, 4, 5, 6)).tolist()

    # read as xarray Dataset
    ds = mx.rsf_to_xarrayds(path)

    # write to netcdf (can be loaded in paraview)
    ds.to_netcdf("test.nc")
    arr2 = xr.open_dataarray("test.nc")
    try:
        assert ds.sizes == arr2.sizes
        for axis in axes:
            assert ds.coords[axis].values.tolist() == arr2.coords[axis].values.tolist()
        assert ds["data"].values.tolist() == arr2.values.tolist()
    finally:
        # release the netCDF file before it is deleted below
        arr2.close()

    # clean up using the path that was actually written,
    # not a module-level global
    tmp = m8r.Input(path)
    binFile = tmp.string("in")
    os.remove(binFile)
    os.remove(path)
    os.remove("test.nc")
59+
60+
def test_cunks(path):
    """Time reading a (100, 100, 100, 5) RSF file with several chunk settings.

    Writes the array to `path`, then for each chunk specification opens
    the file, prints the mean and the elapsed time.  Both the RSF header
    and its binary data file are removed afterwards.

    path -- name of the temporary RSF header file to create.
    """
    import os
    import time
    import m8r

    arr = xr.DataArray(
        np.arange(100*100*100*5, dtype=np.float32).reshape(100,100,100,5),
        dims=("t","x","y","z"),
        coords={
            "t" : np.arange(100),
            "x" : np.arange(100),
            "y" : np.arange(100),
            "z" : np.arange(5)
        }
    )

    mx.xarray_to_rsf(arr, path)
    del arr

    for chunks in ((2, 3, 5, 1), (1, 1, 1, 1), "auto"):
        start_time = time.time()
        da = mx.rsf_to_xarray(path, chunks=chunks)
        print(f"mean: {da.mean().compute().values}")
        end_time = time.time()
        print(f"read with chuks {chunks} took {end_time - start_time} seconds")
        del da

    # look up the binary data file before the header is deleted,
    # so the large data file is not left behind
    binFile = m8r.Input(path).string("in")
    os.remove(binFile)
    os.remove(path)
97+
98+
99+
if __name__ == "__main__":
    # Script entry point: run the round-trip check on a scratch file.
    # `file` is a module-level name here.
    file = "tmp.rsf"
    test_roundtrip(path=file)
    # The chunking benchmark is slow, so it stays disabled by default.
    # test_cunks(path="./tmp_large.rsf")
    sys.stdout.write("All tests passed.\n")
104+
## test chunks
105+
####################################################################################
106+
############# output for data with shape (100, 100, 100, 5) ########################
107+
####################################################################################
108+
# mean: 2499999.5
109+
# read with chuks (2, 3, 5, 1) took 21.33984923362732 seconds
110+
# mean: 2499999.5
111+
# read with chuks (1, 1, 1, 1) took 559.3831532001495 seconds
112+
# mean: 2499999.5
113+
# read with chuks auto took 0.01817178726196289 seconds
114+
####################################################################################

0 commit comments

Comments
 (0)