Skip to content

Commit a8cb40e

Browse files
authored
Support scipy special function with tuple output (#3139)
1 parent 3e3f4fd commit a8cb40e

File tree

7 files changed

+154
-5
lines changed

7 files changed

+154
-5
lines changed

docs/source/reference/tensor/special.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ Bessel functions
3535
mars.tensor.special.hankel2e
3636

3737

38-
Error function
38+
Error functions and fresnel integrals
3939
--------------
4040

4141
.. autosummary::
@@ -48,6 +48,7 @@ Error function
4848
mars.tensor.special.erfi
4949
mars.tensor.special.erfinv
5050
mars.tensor.special.erfcinv
51+
mars.tensor.special.fresnel
5152

5253

5354
Ellipsoidal harmonics

mars/tensor/special/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
TensorErfinv,
2929
erfcinv,
3030
TensorErfcinv,
31+
fresnel,
32+
TensorFresnel,
3133
)
3234
from .gamma_funcs import (
3335
gamma,

mars/tensor/special/core.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import scipy.special as spspecial
1516

17+
from ...core import ExecutableTuple
1618
from ... import opcodes
19+
from ..datasource import tensor as astensor
1720
from ..arithmetic.core import TensorUnaryOp, TensorBinOp, TensorMultiOp
1821
from ..array_utils import (
1922
np,
@@ -112,3 +115,59 @@ def execute(cls, ctx, op):
112115
if ret.dtype != op.dtype:
113116
ret = ret.astype(op.dtype)
114117
ctx[op.outputs[0].key] = ret
118+
119+
120+
class TensorTupleOp(TensorSpecialUnaryOp):
121+
@property
122+
def output_limit(self):
123+
return self._n_outputs
124+
125+
def __call__(self, x, out=None):
126+
x = astensor(x)
127+
128+
if out is not None:
129+
if not isinstance(out, ExecutableTuple):
130+
raise TypeError(
131+
f"out should be ExecutableTuple object, got {type(out)} instead"
132+
)
133+
if len(out) != self._n_outputs:
134+
raise TypeError(
135+
f"out should be an ExecutableTuple object with {self._n_outputs} elements, got {len(out)} instead"
136+
)
137+
138+
func = getattr(spspecial, self._func_name)
139+
res = func(np.ones(x.shape, dtype=x.dtype))
140+
res_tensors = self.new_tensors(
141+
[x],
142+
kws=[
143+
{
144+
"side": f"{self._func_name}[{i}]",
145+
"dtype": output.dtype,
146+
"shape": output.shape,
147+
}
148+
for i, output in enumerate(res)
149+
],
150+
)
151+
152+
if out is None:
153+
return ExecutableTuple(res_tensors)
154+
155+
for res_tensor, out_tensor in zip(res_tensors, out):
156+
out_tensor.data = res_tensor.data
157+
return out
158+
159+
@classmethod
160+
def execute(cls, ctx, op):
161+
inputs, device_id, xp = as_same_device(
162+
[ctx[c.key] for c in op.inputs], device=op.device, ret_extra=True
163+
)
164+
165+
with device(device_id):
166+
with np.errstate(**op.err):
167+
if op.is_gpu():
168+
ret = cls._execute_gpu(op, xp, inputs[0])
169+
else:
170+
ret = cls._execute_cpu(op, xp, inputs[0])
171+
172+
for output, ret_element in zip(op.outputs, ret):
173+
ctx[output.key] = ret_element

mars/tensor/special/err_fresnel.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616

1717
from ..arithmetic.utils import arithmetic_operand
1818
from ..utils import infer_dtype, implement_scipy
19-
from .core import TensorSpecialUnaryOp, _register_special_op
19+
from .core import (
20+
TensorSpecialUnaryOp,
21+
TensorTupleOp,
22+
_register_special_op,
23+
)
2024

2125

2226
@_register_special_op
@@ -55,6 +59,12 @@ class TensorErfcinv(TensorSpecialUnaryOp):
5559
_func_name = "erfcinv"
5660

5761

62+
@_register_special_op
63+
class TensorFresnel(TensorTupleOp):
64+
_func_name = "fresnel"
65+
_n_outputs = 2
66+
67+
5868
@implement_scipy(spspecial.erf)
5969
@infer_dtype(spspecial.erf)
6070
def erf(x, out=None, where=None, **kwargs):
@@ -140,3 +150,10 @@ def erfinv(x, out=None, where=None, **kwargs):
140150
def erfcinv(x, out=None, where=None, **kwargs):
141151
op = TensorErfcinv(**kwargs)
142152
return op(x, out=out, where=where)
153+
154+
155+
@implement_scipy(spspecial.fresnel)
156+
@infer_dtype(spspecial.fresnel, multi_outputs=True)
157+
def fresnel(x, out=None, **kwargs):
158+
op = TensorFresnel(**kwargs)
159+
return op(x, out=out)

mars/tensor/special/tests/test_special.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,12 @@
2828
ellipkinc as scipy_ellipkinc,
2929
ellipe as scipy_ellipe,
3030
ellipeinc as scipy_ellipeinc,
31+
fresnel as scipy_fresnel,
3132
betainc as scipy_betainc,
3233
)
3334

3435
from ....lib.version import parse as parse_version
35-
from ....core import tile
36+
from ....core import tile, ExecutableTuple
3637
from ... import tensor
3738
from ..err_fresnel import (
3839
erf,
@@ -47,6 +48,8 @@
4748
TensorErfinv,
4849
erfcinv,
4950
TensorErfcinv,
51+
fresnel,
52+
TensorFresnel,
5053
)
5154
from ..gamma_funcs import (
5255
gammaln,
@@ -276,6 +279,48 @@ def test_erfcinv():
276279
assert c.shape == c.inputs[0].shape
277280

278281

282+
def test_fresnel():
283+
raw = np.random.rand(10, 8, 5)
284+
t = tensor(raw, chunk_size=3)
285+
286+
r = fresnel(t)
287+
expect = scipy_fresnel(raw)
288+
289+
assert isinstance(r, ExecutableTuple)
290+
assert len(r) == 2
291+
292+
for i in range(len(r)):
293+
assert r[i].shape == expect[i].shape
294+
assert r[i].dtype == expect[i].dtype
295+
assert isinstance(r[i].op, TensorFresnel)
296+
297+
non_tuple_out = tensor(raw, chunk_size=3)
298+
with pytest.raises(TypeError):
299+
r = fresnel(t, non_tuple_out)
300+
301+
mismatch_size_tuple = ExecutableTuple([t])
302+
with pytest.raises(TypeError):
303+
r = fresnel(t, mismatch_size_tuple)
304+
305+
out = ExecutableTuple([t, t])
306+
r_out = fresnel(t, out=out)
307+
308+
assert isinstance(out, ExecutableTuple)
309+
assert isinstance(r_out, ExecutableTuple)
310+
311+
assert len(out) == 2
312+
assert len(r_out) == 2
313+
314+
for r_output, expected_output, out_output in zip(r, expect, out):
315+
assert r_output.shape == expected_output.shape
316+
assert r_output.dtype == expected_output.dtype
317+
assert isinstance(r_output.op, TensorFresnel)
318+
319+
assert out_output.shape == expected_output.shape
320+
assert out_output.dtype == expected_output.dtype
321+
assert isinstance(out_output.op, TensorFresnel)
322+
323+
279324
def test_beta_inc():
280325
raw1 = np.random.rand(4, 3, 2)
281326
raw2 = np.random.rand(4, 3, 2)

mars/tensor/special/tests/test_special_execution.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,3 +299,25 @@ def test_quintuple_execution(setup, func):
299299

300300
expected = sp_func(raw1.toarray(), raw2, raw3, raw4, raw5)
301301
np.testing.assert_array_equal(result.toarray(), expected)
302+
303+
304+
@pytest.mark.parametrize(
305+
"func",
306+
[
307+
"fresnel",
308+
],
309+
)
310+
def test_unary_tuple_execution(setup, func):
311+
sp_func = getattr(spspecial, func)
312+
mt_func = getattr(mt_special, func)
313+
314+
raw = np.random.rand(10, 8, 6)
315+
a = tensor(raw, chunk_size=3)
316+
317+
r = mt_func(a)
318+
319+
result = r.execute().fetch()
320+
expected = sp_func(raw)
321+
322+
for actual_output, expected_output in zip(result, expected):
323+
np.testing.assert_array_equal(actual_output, expected_output)

mars/tensor/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def call(*tensors, **kw):
229229
return inner
230230

231231

232-
def infer_dtype(np_func, empty=True, reverse=False, check=True):
232+
def infer_dtype(np_func, multi_outputs=False, empty=True, reverse=False, check=True):
233233
def make_arg(arg):
234234
if empty:
235235
return np.empty((1,) * max(1, arg.ndim), dtype=arg.dtype)
@@ -267,7 +267,10 @@ def h(*tensors, **kw):
267267
# that implements __tensor_ufunc__
268268
try:
269269
with np.errstate(all="ignore"):
270-
dtype = np_func(*args, **np_kw).dtype
270+
if multi_outputs:
271+
dtype = np_func(*args, **np_kw)[0].dtype
272+
else:
273+
dtype = np_func(*args, **np_kw).dtype
271274
except: # noqa: E722
272275
dtype = None
273276

0 commit comments

Comments (0)