1616## along with this program; if not, write to the Free Software
1717## Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
1818
19+ import os
1920import m8r
2021import numpy as np
22+ import subprocess
2123
2224try :
2325 import xarray as xr
3133
3234def rsf_to_xarray (path , chunks = "auto" ):
3335 """Convert an RSF file to an xarray DataArray with optimal Dask lazy loading."""
34-
3536 if xr is None :
3637 raise ImportError ("xarray is required." )
3738
3839 f = m8r .Input (path )
39-
4040 shape = f .shape ()
4141 ndim = len (shape )
4242
4343 dims = []
44- # units = []
44+ units = []
4545 coords = {}
4646 for axis in range (1 , ndim + 1 ):
4747 n = f .int (f"n{ axis } " )
@@ -50,21 +50,35 @@ def rsf_to_xarray(path, chunks = "auto"):
5050 label = f .string (f"label{ axis } " )
5151 if label == None :
5252 label = f"dim{ axis } "
53- # unit = f.string(f"unit{axis}")
53+ unit = f .string (f"unit{ axis } " )
54+ if unit == None :
55+ unit = "??"
5456
5557 coords [label ] = np .arange (n ) * d + o
5658 dims .append (label )
57- # units.append(unit)
59+ units .append (unit )
5860
5961 # data = np.asarray(f)
6062 binFile = f .string ("in" )
61- dtype = f .string ("type" )
62- if dtype is None or dtype == "float" :
63- dtype = np .float32
64- elif dtype == "int" :
65- dtype = np .int32
66- else :
67- raise ValueError (f"Unsupported data type: { dtype } " )
63+ label = f .string ("label" )
64+ unit0 = f .string ("unit" )
65+
66+ rsf_type = getattr (f , 'type' , None )
67+ if not rsf_type :
68+ fmt = f .string ("data_format" )
69+ if fmt and "complex" in fmt : rsf_type = 'complex'
70+ elif fmt and "int" in fmt : rsf_type = 'int'
71+ else : rsf_type = 'float'
72+
73+ dtype_map = {
74+ 'float' : np .float32 ,
75+ 'int' : np .int32 ,
76+ 'complex' : np .complex64 ,
77+ 'uchar' : np .uint8 ,
78+ 'char' : np .int8
79+ }
80+ dtype = dtype_map .get (rsf_type , np .float32 )
81+ # -------------------------------
6882
6983 mm = np .memmap (
7084 binFile ,
@@ -80,33 +94,61 @@ def rsf_to_xarray(path, chunks = "auto"):
8094
8195
8296 # covert c order to python order
83- # data = data.reshape(shape_py)
8497 data = data .transpose (* reversed (range (data .ndim )))
8598
8699 ds = xr .DataArray (
87100 data ,
88101 dims = dims ,
89- coords = coords
102+ coords = coords ,
90103 )
104+ if label :
105+ ds .attrs ['long_name' ] = label
106+ if unit0 :
107+ ds .attrs ['units' ] = unit0
108+
109+ for dim , unit in zip (dims , units ):
110+ ds .coords [dim ].attrs ['units' ] = unit
91111
92112 return ds
93113
94114def xarray_to_rsf (ds , outpath ):
95115 """Convert an xarray Dataset to an RSF file."""
96-
97116 if xr is None :
98117 raise ImportError ("xarray is required." )
99118
100119 if not isinstance (ds , xr .DataArray ):
101120 raise ValueError ("Input must be an xarray DataArray." )
102121
122+ if np .issubdtype (ds .dtype , np .complexfloating ):
123+ rsf_type = 'complex'
124+ out_dtype = np .complex64
125+ fmt_str = "native_complex"
126+ elif np .issubdtype (ds .dtype , np .integer ):
127+ rsf_type = 'int'
128+ out_dtype = np .int32
129+ fmt_str = "native_int"
130+ else :
131+ rsf_type = 'float'
132+ out_dtype = np .float32
133+ fmt_str = "native_float"
134+ # -------------------------------
135+
103136 data = ds .values
104137 data = data .transpose (* reversed (range (data .ndim )))
105138
106139 dims = ds .dims
140+ out = m8r .Output (outpath )
141+ label = ds .attrs .get ('long_name' , None )
142+ unit0 = ds .attrs .get ('units' , None )
107143
144+ # Set Type
145+ out .settype (rsf_type )
146+ out .put ("data_format" , fmt_str )
108147
109- out = m8r .Output (outpath )
148+ if label :
149+ out .put ("label" , label )
150+ if unit0 :
151+ out .put ("unit" , unit0 )
110152
111153 for i , dim in enumerate (dims , start = 1 ):
112154
@@ -123,8 +165,9 @@ def xarray_to_rsf(ds, outpath):
123165 out .put (f"o{ i } " , o )
124166 out .put (f"d{ i } " , d )
125167 out .put (f"label{ i } " , str (dim ))
126-
127- out .write (data .astype (np .float32 ))
168+ out .put (f"unit{ i } " , str (ds .coords [dim ].attrs .get ('units' , '??' )))
169+
170+ out .write (data .astype (out_dtype ))
128171 out .close ()
129172
130173def rsf_to_xarrayds (path , chunks = "auto" ):
@@ -133,4 +176,121 @@ def rsf_to_xarrayds(path, chunks = "auto"):
133176 da = rsf_to_xarray (path , chunks = chunks )
134177 ds = da .to_dataset (name = "data" )
135178
136- return ds
179+ return ds
180+
181+ ## Monkey patching m8r.Filter to handle xarray inputs/outputs
182+
def _mx_aux_to_rsf(owner, val):
    """Return *val* unchanged, or the name of a temp RSF file holding it if it is a DataArray.

    The temporary file is wrapped in ``m8r.File(..., temp=True)`` and the handle is
    appended to ``owner._mx_aux_refs`` so it lives as long as the filter does
    (presumably m8r removes ``temp=True`` files when the handle is released -- confirm).
    """
    if not isinstance(val, xr.DataArray):
        return val
    tmp_name = m8r.Temp()
    xarray_to_rsf(val, tmp_name)
    owner._mx_aux_refs.append(m8r.File(tmp_name, temp=True))
    return str(tmp_name)


def _patched_setcommand(self, kw, args=None):
    """Replacement for ``Filter.setcommand`` that accepts auxiliary xarray inputs.

    Every ``xarray.DataArray`` passed as a keyword parameter or as a positional
    argument is written to a temporary RSF file (via ``xarray_to_rsf``) and replaced
    by that file's name before delegating to the original ``setcommand``.

    Example: Filter('sfprog')(velocity=my_xarray)

    Args:
        kw: keyword parameters for the command line. The caller's dict is not
            mutated; a converted copy is forwarded.
        args: extra positional arguments for the command line (defaults to none).

    Returns:
        Whatever the original ``Filter.setcommand`` returns.
    """
    args = [] if args is None else list(args)

    # Without xarray nothing can be a DataArray; skip conversion instead of
    # crashing on ``xr.DataArray`` when ``xr`` is None.
    if xr is not None:
        if not hasattr(self, '_mx_aux_refs'):
            self._mx_aux_refs = []
        kw = {key: _mx_aux_to_rsf(self, val) for key, val in kw.items()}
        args = [_mx_aux_to_rsf(self, val) for val in args]

    # The original is stashed on the class when the patch is installed; fall back
    # to the real m8r module (``m8r.wrapped`` when m8r is a wrapper) if missing.
    original_func = getattr(self.__class__, '_original_setcommand_func', None)
    if original_func is None:
        real_module = getattr(m8r, 'wrapped', m8r)
        original_func = real_module.Filter._original_setcommand_func

    # Call the original function with strings/files only
    return original_func(self, kw, args)
220+
221+
222+ def _patched_apply (self , * srcs ):
223+ """
224+ Patched version of Filter.apply to handle xarray inputs/outputs.
225+ """
226+
227+ # Check if any input is an xarray
228+ if not any (isinstance (s , xr .DataArray ) for s in srcs ):
229+ # Fallback to original
230+ if hasattr (self .__class__ , '_original_apply_func' ):
231+ return self .__class__ ._original_apply_func (self , * srcs )
232+ real_module = m8r .wrapped if hasattr (m8r , 'wrapped' ) else m8r
233+ return real_module .Filter ._original_apply_func (self , * srcs )
234+
235+ # Handle xarray Input
236+ clean = []
237+ rsf_inputs = []
238+ try :
239+ for s in srcs :
240+ if isinstance (s , xr .DataArray ):
241+ tmp = m8r .Temp ()
242+ xarray_to_rsf (s , tmp )
243+ rsf_inputs .append (str (tmp ))
244+ clean .extend ([str (tmp ), str (tmp )+ '@' ])
245+ else :
246+ rsf_inputs .append (str (s ))
247+
248+ out_file = m8r .Temp ()
249+
250+ # Split command for correct pipe handling
251+ first , pipe_char , rest = self .command .partition ('|' )
252+ cmd_parts = [first ]
253+
254+ if len (rsf_inputs ) > 0 :
255+ cmd_parts .append (f"< { rsf_inputs [0 ]} " )
256+ if len (rsf_inputs ) > 1 :
257+ cmd_parts .extend (rsf_inputs [1 :])
258+
259+ if pipe_char :
260+ cmd_parts .append (pipe_char )
261+ cmd_parts .append (rest )
262+
263+ cmd_parts .append (f"> { out_file } " )
264+
265+ full_cmd = " " .join (cmd_parts )
266+
267+ result = subprocess .run (full_cmd , shell = True , stderr = subprocess .PIPE , stdout = subprocess .PIPE , universal_newlines = True )
268+
269+ if result .returncode != 0 :
270+ raise RuntimeError (f"{ full_cmd } \n { result .stderr } " )
271+
272+ if self .plot :
273+ return m8r .Vplot (out_file , temp = True )
274+
275+ res = rsf_to_xarray (out_file )
276+ clean .extend ([str (out_file ), str (out_file )+ '@' ])
277+ return res
278+
279+ finally :
280+ for f in clean :
281+ if os .path .exists (f ):
282+ try : os .unlink (f )
283+ except : pass
284+
# Install the xarray-aware wrappers on the real ``Filter`` class. That class lives
# on ``m8r.wrapped`` when m8r is itself a wrapper module.
#
# - The originals are stashed as ``_original_apply_func`` and
#   ``_original_setcommand_func``.
# - The ``hasattr`` guards make repeated imports a no-op.
# - Nothing is installed when xarray is unavailable, so plain m8r usage keeps
#   its original behaviour instead of failing on ``xr.DataArray``.
real_m8r_module = getattr(m8r, 'wrapped', m8r)

if xr is not None:
    _mx_filter_cls = real_m8r_module.Filter

    if not hasattr(_mx_filter_cls, '_original_apply_func'):
        _mx_filter_cls._original_apply_func = _mx_filter_cls.apply
        _mx_filter_cls.apply = _patched_apply

    if not hasattr(_mx_filter_cls, '_original_setcommand_func'):
        _mx_filter_cls._original_setcommand_func = _mx_filter_cls.setcommand
        _mx_filter_cls.setcommand = _patched_setcommand
0 commit comments