Skip to content

Commit 578cf63

Browse files
authored
Stop calling user funcs when dtypes is specified (#2587)
1 parent 54636e6 commit 578cf63

File tree

2 files changed

+76
-45
lines changed

2 files changed

+76
-45
lines changed

mars/dataframe/base/map_chunk.py

Lines changed: 75 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import pandas as pd
1717

1818
from ... import opcodes
19-
from ...core import recursive_tile
19+
from ...core import recursive_tile, get_output_types
2020
from ...core.custom_log import redirect_custom_log
2121
from ...serialization.serializables import (
2222
KeyField,
@@ -90,67 +90,98 @@ def _set_inputs(self, inputs):
9090
super()._set_inputs(inputs)
9191
self._input = self._inputs[0]
9292

93-
def __call__(self, df_or_series, index=None, dtypes=None):
93+
def _infer_attrs_by_call(self, df_or_series):
9494
test_obj = (
9595
build_df(df_or_series, size=2)
9696
if df_or_series.ndim == 2
9797
else build_series(df_or_series, size=2, name=df_or_series.name)
9898
)
99-
output_type = self._output_types[0] if self.output_types else None
100-
101-
# try run to infer meta
102-
try:
103-
kwargs = self.kwargs or dict()
104-
if self.with_chunk_index:
105-
kwargs["chunk_index"] = (0,) * df_or_series.ndim
106-
with np.errstate(all="ignore"), quiet_stdio():
107-
obj = self._func(test_obj, *self._args, **kwargs)
108-
except: # noqa: E722 # nosec
109-
if df_or_series.ndim == 1 or output_type == OutputType.series:
110-
obj = pd.Series([], dtype=np.dtype(object))
111-
elif output_type == OutputType.dataframe and dtypes is not None:
112-
obj = build_empty_df(dtypes)
99+
kwargs = self.kwargs or dict()
100+
if self.with_chunk_index:
101+
kwargs["chunk_index"] = (0,) * df_or_series.ndim
102+
with np.errstate(all="ignore"), quiet_stdio():
103+
obj = self._func(test_obj, *self._args, **kwargs)
104+
105+
if obj.ndim == 2:
106+
output_type = OutputType.dataframe
107+
dtypes = obj.dtypes
108+
if obj.shape == test_obj.shape:
109+
shape = (df_or_series.shape[0], len(dtypes))
110+
else: # pragma: no cover
111+
shape = (np.nan, len(dtypes))
112+
else:
113+
output_type = OutputType.series
114+
dtypes = pd.Series([obj.dtype], name=obj.name)
115+
if obj.shape == test_obj.shape:
116+
shape = df_or_series.shape
113117
else:
114-
raise TypeError(
115-
"Cannot determine `output_type`, "
116-
"you have to specify it as `dataframe` or `series`, "
117-
"for dataframe, `dtypes` is required as well "
118-
"if output_type='dataframe'"
119-
)
118+
shape = (np.nan,)
120119

121-
if getattr(obj, "ndim", 0) == 1 or output_type == OutputType.series:
122-
shape = self._kwargs.pop("shape", None)
123-
if shape is None:
124-
# series
125-
if obj.shape == test_obj.shape:
126-
shape = df_or_series.shape
127-
else:
128-
shape = (np.nan,)
129-
if index is None:
130-
index = obj.index
120+
index_value = parse_index(
121+
obj.index, df_or_series, self._func, self._args, self._kwargs
122+
)
123+
return {
124+
"output_type": output_type,
125+
"index_value": index_value,
126+
"shape": shape,
127+
"dtypes": dtypes,
128+
}
129+
130+
def __call__(self, df_or_series, index=None, dtypes=None):
131+
output_type = (
132+
self.output_types[0]
133+
if self.output_types
134+
else get_output_types(df_or_series)[0]
135+
)
136+
shape = self._kwargs.pop("shape", None)
137+
138+
if dtypes is not None:
139+
index = index if index is not None else pd.RangeIndex(-1)
131140
index_value = parse_index(
132141
index, df_or_series, self._func, self._args, self._kwargs
133142
)
143+
if shape is None: # pragma: no branch
144+
shape = (
145+
(np.nan,)
146+
if output_type == OutputType.series
147+
else (np.nan, len(dtypes))
148+
)
149+
else:
150+
# try run to infer meta
151+
try:
152+
attrs = self._infer_attrs_by_call(df_or_series)
153+
output_type = attrs["output_type"]
154+
index_value = attrs["index_value"]
155+
shape = attrs["shape"]
156+
dtypes = attrs["dtypes"]
157+
except: # noqa: E722 # nosec
158+
if df_or_series.ndim == 1 or output_type == OutputType.series:
159+
output_type = OutputType.series
160+
index = index if index is not None else pd.RangeIndex(-1)
161+
index_value = parse_index(
162+
index, df_or_series, self._func, self._args, self._kwargs
163+
)
164+
dtypes = pd.Series([np.dtype(object)])
165+
shape = (np.nan,)
166+
else:
167+
raise TypeError(
168+
"Cannot determine `output_type`, "
169+
"you have to specify it as `dataframe` or `series`, "
170+
"for dataframe, `dtypes` is required as well "
171+
"if output_type='dataframe'"
172+
)
173+
174+
if output_type == OutputType.series:
134175
return self.new_series(
135176
[df_or_series],
136-
dtype=obj.dtype,
177+
dtype=dtypes.iloc[0],
137178
shape=shape,
138179
index_value=index_value,
139-
name=obj.name,
180+
name=dtypes.name,
140181
)
141182
else:
142-
dtypes = dtypes if dtypes is not None else obj.dtypes
143183
# dataframe
144-
if obj.shape == test_obj.shape:
145-
shape = (df_or_series.shape[0], len(dtypes))
146-
else:
147-
shape = (np.nan, len(dtypes))
148184
columns_value = parse_index(dtypes.index, store_data=True)
149-
if index is None:
150-
index = obj.index
151-
index_value = parse_index(
152-
index, df_or_series, self._func, self._args, self._kwargs
153-
)
154185
return self.new_dataframe(
155186
[df_or_series],
156187
shape=shape,

mars/dataframe/base/tests/test_base_execution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1791,7 +1791,7 @@ def f4(pdf):
17911791
r = df2.map_chunk(
17921792
lambda x: x["a"].apply(pd.Series), output_type="dataframe", dtypes=dtypes
17931793
)
1794-
assert r.shape == (2, 3)
1794+
assert r.shape == (np.nan, 3)
17951795
pd.testing.assert_series_equal(r.dtypes, dtypes)
17961796
result = r.execute().fetch()
17971797
expected = raw2.apply(lambda x: x["a"], axis=1, result_type="expand")

0 commit comments

Comments
 (0)