Skip to content

Commit 4bc545b

Browse files
committed
Merge branch 'master' of https://github.com/ahay/src
2 parents d6f3a5e + 11cb871 commit 4bc545b

File tree

6 files changed

+785
-391
lines changed

6 files changed

+785
-391
lines changed

.github/workflows/dockerimage.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,4 +13,4 @@ jobs:
1313
name: madagascar/m8r
1414
username: ${{ secrets.DOCKER_USERNAME }}
1515
password: ${{ secrets.DOCKER_PASSWORD }}
16-
dockerfile: admin/docker/2.0-dev-tex/Dockerfile
16+
dockerfile: admin/docker/4.2-dev-tex/Dockerfile

admin/docker/4.2-dev-tex/Dockerfile

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@ RUN apt-get update && apt-get install -y \
3939
&& rm -rf /var/lib/apt/lists/*
4040

4141
# install python packages
42-
RUN pip install numpy scipy
42+
RUN pip install --break-system-packages numpy scipy
4343

4444
# get code from github
4545
RUN git clone https://github.com/ahay/src.git $HOME/RSFSRC
@@ -59,7 +59,6 @@ RUN apt-get update && apt-get install -y \
5959
texlive-fonts-recommended \
6060
texlive-bibtex-extra \
6161
texlive-lang-english \
62-
texlive-generic-extra \
6362
biber \
6463
--no-install-recommends \
6564
&& rm -rf /var/lib/apt/lists/*

api/python/mxarray.py

Lines changed: 179 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,10 @@
1616
## along with this program; if not, write to the Free Software
1717
## Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
1818

19+
import os
1920
import m8r
2021
import numpy as np
22+
import subprocess
2123

2224
try:
2325
import xarray as xr
@@ -31,17 +33,15 @@
3133

3234
def rsf_to_xarray(path, chunks = "auto"):
3335
"""Convert an RSF file to an xarray DataArray with optimal Dask lazy loading."""
34-
3536
if xr is None:
3637
raise ImportError("xarray is required.")
3738

3839
f = m8r.Input(path)
39-
4040
shape = f.shape()
4141
ndim = len(shape)
4242

4343
dims = []
44-
# units = []
44+
units = []
4545
coords = {}
4646
for axis in range(1, ndim+1):
4747
n = f.int(f"n{axis}")
@@ -50,21 +50,35 @@ def rsf_to_xarray(path, chunks = "auto"):
5050
label = f.string(f"label{axis}")
5151
if label == None:
5252
label = f"dim{axis}"
53-
# unit = f.string(f"unit{axis}")
53+
unit = f.string(f"unit{axis}")
54+
if unit == None:
55+
unit = "??"
5456

5557
coords[label] = np.arange(n) * d + o
5658
dims.append(label)
57-
# units.append(unit)
59+
units.append(unit)
5860

5961
# data = np.asarray(f)
6062
binFile = f.string("in")
61-
dtype = f.string("type")
62-
if dtype is None or dtype == "float":
63-
dtype = np.float32
64-
elif dtype == "int":
65-
dtype = np.int32
66-
else:
67-
raise ValueError(f"Unsupported data type: {dtype}")
63+
label = f.string("label")
64+
unit0 = f.string("unit")
65+
66+
rsf_type = getattr(f, 'type', None)
67+
if not rsf_type:
68+
fmt = f.string("data_format")
69+
if fmt and "complex" in fmt: rsf_type = 'complex'
70+
elif fmt and "int" in fmt: rsf_type = 'int'
71+
else: rsf_type = 'float'
72+
73+
dtype_map = {
74+
'float': np.float32,
75+
'int': np.int32,
76+
'complex': np.complex64,
77+
'uchar': np.uint8,
78+
'char': np.int8
79+
}
80+
dtype = dtype_map.get(rsf_type, np.float32)
81+
# -------------------------------
6882

6983
mm = np.memmap(
7084
binFile,
@@ -80,33 +94,61 @@ def rsf_to_xarray(path, chunks = "auto"):
8094

8195

8296
# covert c order to python order
83-
# data = data.reshape(shape_py)
8497
data = data.transpose(*reversed(range(data.ndim)))
8598

8699
ds = xr.DataArray(
87100
data,
88101
dims=dims,
89-
coords=coords
102+
coords=coords,
90103
)
104+
if label:
105+
ds.attrs['long_name'] = label
106+
if unit0:
107+
ds.attrs['units'] = unit0
108+
109+
for dim, unit in zip(dims, units):
110+
ds.coords[dim].attrs['units'] = unit
91111

92112
return ds
93113

94114
def xarray_to_rsf(ds, outpath):
95115
"""Convert an xarray Dataset to an RSF file."""
96-
97116
if xr is None:
98117
raise ImportError("xarray is required.")
99118

100119
if not isinstance(ds, xr.DataArray):
101120
raise ValueError("Input must be an xarray DataArray.")
102121

122+
if np.issubdtype(ds.dtype, np.complexfloating):
123+
rsf_type = 'complex'
124+
out_dtype = np.complex64
125+
fmt_str = "native_complex"
126+
elif np.issubdtype(ds.dtype, np.integer):
127+
rsf_type = 'int'
128+
out_dtype = np.int32
129+
fmt_str = "native_int"
130+
else:
131+
rsf_type = 'float'
132+
out_dtype = np.float32
133+
fmt_str = "native_float"
134+
# -------------------------------
135+
103136
data = ds.values
104137
data = data.transpose(*reversed(range(data.ndim)))
105138

106139
dims = ds.dims
140+
out = m8r.Output(outpath)
141+
label = ds.attrs.get('long_name', None)
142+
unit0 = ds.attrs.get('units', None)
107143

144+
# Set Type
145+
out.settype(rsf_type)
146+
out.put("data_format", fmt_str)
108147

109-
out = m8r.Output(outpath)
148+
if label:
149+
out.put("label", label)
150+
if unit0:
151+
out.put("unit", unit0)
110152

111153
for i, dim in enumerate(dims, start=1):
112154

@@ -123,8 +165,9 @@ def xarray_to_rsf(ds, outpath):
123165
out.put(f"o{i}", o)
124166
out.put(f"d{i}", d)
125167
out.put(f"label{i}", str(dim))
126-
127-
out.write(data.astype(np.float32))
168+
out.put(f"unit{i}", str(ds.coords[dim].attrs.get('units', '??')))
169+
170+
out.write(data.astype(out_dtype))
128171
out.close()
129172

130173
def rsf_to_xarrayds(path, chunks = "auto"):
@@ -133,4 +176,121 @@ def rsf_to_xarrayds(path, chunks = "auto"):
133176
da = rsf_to_xarray(path, chunks=chunks)
134177
ds = da.to_dataset(name="data")
135178

136-
return ds
179+
return ds
180+
181+
## Monkey patching m8r.Filter to handle xarray inputs/outputs
182+
183+
def _patched_setcommand(self, kw, args=None):
    """
    Patched version of Filter.setcommand to handle auxiliary xarrays.

    Every xarray.DataArray found among the keyword values or positional
    arguments is written to a temporary RSF file with xarray_to_rsf, and the
    temporary file name is substituted in its place. The call is then
    delegated to the original (unpatched) Filter.setcommand.

    Example: Filter('sfprog')(velocity=my_xarray)

    Parameters
    ----------
    kw : dict
        Keyword parameters of the command. DataArray values are replaced
        in place by temporary file names.
    args : list, optional
        Positional arguments of the command. Defaults to no arguments.

    Returns
    -------
    Whatever the original Filter.setcommand returns.
    """
    # None instead of a shared mutable ``[]`` default.
    if args is None:
        args = []

    # Keep the temporary m8r.File objects referenced by the filter so they
    # live as long as the filter does (they are created with temp=True;
    # presumably m8r removes the files when they are released — confirm
    # against m8r.File).
    if not hasattr(self, '_mx_aux_refs'):
        self._mx_aux_refs = []

    def _to_rsf_name(val):
        """Return val unchanged, or the name of a temp RSF copy if it is a DataArray."""
        # Guard on xr: when the optional xarray import failed, xr is None and
        # no value can be a DataArray.
        if xr is None or not isinstance(val, xr.DataArray):
            return val
        tmp_name = m8r.Temp()
        xarray_to_rsf(val, tmp_name)
        self._mx_aux_refs.append(m8r.File(tmp_name, temp=True))
        return str(tmp_name)

    # Only existing keys are reassigned, so iterating a snapshot is safe.
    for key, val in list(kw.items()):
        kw[key] = _to_rsf_name(val)

    new_args = [_to_rsf_name(val) for val in args]

    # Call the original function with strings/files only. It is stored on the
    # class when the patch is installed; fall back to the real module's
    # Filter class if it is not reachable through self.__class__.
    original_func = getattr(self.__class__, '_original_setcommand_func', None)
    if not original_func:
        real_module = m8r.wrapped if hasattr(m8r, 'wrapped') else m8r
        original_func = real_module.Filter._original_setcommand_func

    return original_func(self, kw, new_args)
220+
221+
222+
def _patched_apply(self, *srcs):
    """
    Patched version of Filter.apply to handle xarray inputs/outputs.

    If no element of ``srcs`` is an xarray.DataArray (or xarray is not
    installed), the call is delegated unchanged to the original Filter.apply.

    Otherwise:
    - each DataArray input is written to a temporary RSF file;
    - ``self.command`` is run through the shell, with the first input on
      stdin and any further inputs appended as arguments of the first
      program in the pipe;
    - stdout is redirected into a temporary output file.

    Every temporary file created here is removed before returning, except
    a plot output, which is handed to the returned Vplot.

    Returns
    -------
    m8r.Vplot
        If ``self.plot`` is true, wrapping the temporary output file.
    xarray.DataArray
        Otherwise, the output read back with rsf_to_xarray.

    Raises
    ------
    RuntimeError
        If the shell command exits with a non-zero status. The message
        contains the command line and its stderr.
    """
    import shlex

    # xr is None when the optional xarray import failed; then nothing can be
    # a DataArray and the original implementation applies.
    has_xarray_input = xr is not None and any(
        isinstance(s, xr.DataArray) for s in srcs)
    if not has_xarray_input:
        original_func = getattr(self.__class__, '_original_apply_func', None)
        if original_func is None:
            real_module = m8r.wrapped if hasattr(m8r, 'wrapped') else m8r
            original_func = real_module.Filter._original_apply_func
        return original_func(self, *srcs)

    # Header paths and their '@' companions to delete on exit.
    # NOTE(review): assumes the binary of a temp file lands at '<name>@';
    # verify against how m8r.Output places binaries (DATAPATH).
    clean = []
    try:
        rsf_inputs = []
        for s in srcs:
            if isinstance(s, xr.DataArray):
                tmp = m8r.Temp()
                # Register before writing so a failed write is cleaned too.
                clean.extend([str(tmp), str(tmp) + '@'])
                xarray_to_rsf(s, tmp)
                # Quote generated paths: the command runs through the shell.
                rsf_inputs.append(shlex.quote(str(tmp)))
            else:
                # Caller-supplied inputs are passed through verbatim.
                rsf_inputs.append(str(s))

        out_file = m8r.Temp()
        # Register immediately: the shell redirect creates the file even if
        # the command fails, and reading it back may also fail.
        out_paths = [str(out_file), str(out_file) + '@']
        clean.extend(out_paths)

        # Split at the first pipe so the inputs go to the first program only.
        first, pipe_char, rest = self.command.partition('|')
        cmd_parts = [first]
        if rsf_inputs:
            cmd_parts.append(f"< {rsf_inputs[0]}")
            cmd_parts.extend(rsf_inputs[1:])
        if pipe_char:
            cmd_parts.append(pipe_char)
            cmd_parts.append(rest)
        cmd_parts.append(f"> {shlex.quote(str(out_file))}")

        full_cmd = " ".join(cmd_parts)

        result = subprocess.run(full_cmd, shell=True, stderr=subprocess.PIPE,
                                stdout=subprocess.PIPE, universal_newlines=True)
        if result.returncode != 0:
            raise RuntimeError(f"{full_cmd}\n{result.stderr}")

        if self.plot:
            # The returned Vplot (temp=True) takes over the output file.
            for path in out_paths:
                clean.remove(path)
            return m8r.Vplot(out_file, temp=True)

        return rsf_to_xarray(out_file)

    finally:
        # Best-effort cleanup: a missing or locked file must not mask the
        # result or the original exception, but do not swallow
        # KeyboardInterrupt/SystemExit as a bare except would.
        for path in clean:
            try:
                os.unlink(path)
            except OSError:
                pass
284+
285+
# Install the xarray-aware Filter.apply / Filter.setcommand patches.
#
# The patch targets the real module object: m8r.wrapped when m8r exposes
# one, otherwise m8r itself. Each original method is saved on the Filter
# class (_original_apply_func / _original_setcommand_func) so the patched
# functions can delegate to it. The hasattr guards save an original only
# once, so executing this block again never stores a patched function as
# the "original".
#
# Patching is skipped when xarray is not available (xr is None): there is
# nothing for the patches to convert, and Filter keeps its stock behavior.
real_m8r_module = m8r.wrapped if hasattr(m8r, 'wrapped') else m8r

if xr is not None:
    if not hasattr(real_m8r_module.Filter, '_original_apply_func'):
        real_m8r_module.Filter._original_apply_func = real_m8r_module.Filter.apply
        real_m8r_module.Filter.apply = _patched_apply

    if not hasattr(real_m8r_module.Filter, '_original_setcommand_func'):
        real_m8r_module.Filter._original_setcommand_func = real_m8r_module.Filter.setcommand
        real_m8r_module.Filter.setcommand = _patched_setcommand

0 commit comments

Comments
 (0)