|
16 | 16 | import pandas as pd |
17 | 17 |
|
18 | 18 | from ... import opcodes |
19 | | -from ...core import recursive_tile |
| 19 | +from ...core import recursive_tile, get_output_types |
20 | 20 | from ...core.custom_log import redirect_custom_log |
21 | 21 | from ...serialization.serializables import ( |
22 | 22 | KeyField, |
@@ -90,67 +90,98 @@ def _set_inputs(self, inputs): |
90 | 90 | super()._set_inputs(inputs) |
91 | 91 | self._input = self._inputs[0] |
92 | 92 |
|
93 | | - def __call__(self, df_or_series, index=None, dtypes=None): |
| 93 | + def _infer_attrs_by_call(self, df_or_series): |
94 | 94 | test_obj = ( |
95 | 95 | build_df(df_or_series, size=2) |
96 | 96 | if df_or_series.ndim == 2 |
97 | 97 | else build_series(df_or_series, size=2, name=df_or_series.name) |
98 | 98 | ) |
99 | | - output_type = self._output_types[0] if self.output_types else None |
100 | | - |
101 | | - # try run to infer meta |
102 | | - try: |
103 | | - kwargs = self.kwargs or dict() |
104 | | - if self.with_chunk_index: |
105 | | - kwargs["chunk_index"] = (0,) * df_or_series.ndim |
106 | | - with np.errstate(all="ignore"), quiet_stdio(): |
107 | | - obj = self._func(test_obj, *self._args, **kwargs) |
108 | | - except: # noqa: E722 # nosec |
109 | | - if df_or_series.ndim == 1 or output_type == OutputType.series: |
110 | | - obj = pd.Series([], dtype=np.dtype(object)) |
111 | | - elif output_type == OutputType.dataframe and dtypes is not None: |
112 | | - obj = build_empty_df(dtypes) |
| 99 | + kwargs = self.kwargs or dict() |
| 100 | + if self.with_chunk_index: |
| 101 | + kwargs["chunk_index"] = (0,) * df_or_series.ndim |
| 102 | + with np.errstate(all="ignore"), quiet_stdio(): |
| 103 | + obj = self._func(test_obj, *self._args, **kwargs) |
| 104 | + |
| 105 | + if obj.ndim == 2: |
| 106 | + output_type = OutputType.dataframe |
| 107 | + dtypes = obj.dtypes |
| 108 | + if obj.shape == test_obj.shape: |
| 109 | + shape = (df_or_series.shape[0], len(dtypes)) |
| 110 | + else: # pragma: no cover |
| 111 | + shape = (np.nan, len(dtypes)) |
| 112 | + else: |
| 113 | + output_type = OutputType.series |
| 114 | + dtypes = pd.Series([obj.dtype], name=obj.name) |
| 115 | + if obj.shape == test_obj.shape: |
| 116 | + shape = df_or_series.shape |
113 | 117 | else: |
114 | | - raise TypeError( |
115 | | - "Cannot determine `output_type`, " |
116 | | - "you have to specify it as `dataframe` or `series`, " |
117 | | - "for dataframe, `dtypes` is required as well " |
118 | | - "if output_type='dataframe'" |
119 | | - ) |
| 118 | + shape = (np.nan,) |
120 | 119 |
|
121 | | - if getattr(obj, "ndim", 0) == 1 or output_type == OutputType.series: |
122 | | - shape = self._kwargs.pop("shape", None) |
123 | | - if shape is None: |
124 | | - # series |
125 | | - if obj.shape == test_obj.shape: |
126 | | - shape = df_or_series.shape |
127 | | - else: |
128 | | - shape = (np.nan,) |
129 | | - if index is None: |
130 | | - index = obj.index |
| 120 | + index_value = parse_index( |
| 121 | + obj.index, df_or_series, self._func, self._args, self._kwargs |
| 122 | + ) |
| 123 | + return { |
| 124 | + "output_type": output_type, |
| 125 | + "index_value": index_value, |
| 126 | + "shape": shape, |
| 127 | + "dtypes": dtypes, |
| 128 | + } |
| 129 | + |
| 130 | + def __call__(self, df_or_series, index=None, dtypes=None): |
| 131 | + output_type = ( |
| 132 | + self.output_types[0] |
| 133 | + if self.output_types |
| 134 | + else get_output_types(df_or_series)[0] |
| 135 | + ) |
| 136 | + shape = self._kwargs.pop("shape", None) |
| 137 | + |
| 138 | + if dtypes is not None: |
| 139 | + index = index if index is not None else pd.RangeIndex(-1) |
131 | 140 | index_value = parse_index( |
132 | 141 | index, df_or_series, self._func, self._args, self._kwargs |
133 | 142 | ) |
| 143 | + if shape is None: # pragma: no branch |
| 144 | + shape = ( |
| 145 | + (np.nan,) |
| 146 | + if output_type == OutputType.series |
| 147 | + else (np.nan, len(dtypes)) |
| 148 | + ) |
| 149 | + else: |
| 150 | + # try run to infer meta |
| 151 | + try: |
| 152 | + attrs = self._infer_attrs_by_call(df_or_series) |
| 153 | + output_type = attrs["output_type"] |
| 154 | + index_value = attrs["index_value"] |
| 155 | + shape = attrs["shape"] |
| 156 | + dtypes = attrs["dtypes"] |
| 157 | + except: # noqa: E722 # nosec |
| 158 | + if df_or_series.ndim == 1 or output_type == OutputType.series: |
| 159 | + output_type = OutputType.series |
| 160 | + index = index if index is not None else pd.RangeIndex(-1) |
| 161 | + index_value = parse_index( |
| 162 | + index, df_or_series, self._func, self._args, self._kwargs |
| 163 | + ) |
| 164 | + dtypes = pd.Series([np.dtype(object)]) |
| 165 | + shape = (np.nan,) |
| 166 | + else: |
| 167 | + raise TypeError( |
| 168 | + "Cannot determine `output_type`, " |
| 169 | + "you have to specify it as `dataframe` or `series`, " |
| 170 | + "for dataframe, `dtypes` is required as well " |
| 171 | + "if output_type='dataframe'" |
| 172 | + ) |
| 173 | + |
| 174 | + if output_type == OutputType.series: |
134 | 175 | return self.new_series( |
135 | 176 | [df_or_series], |
136 | | - dtype=obj.dtype, |
| 177 | + dtype=dtypes.iloc[0], |
137 | 178 | shape=shape, |
138 | 179 | index_value=index_value, |
139 | | - name=obj.name, |
| 180 | + name=dtypes.name, |
140 | 181 | ) |
141 | 182 | else: |
142 | | - dtypes = dtypes if dtypes is not None else obj.dtypes |
143 | 183 | # dataframe |
144 | | - if obj.shape == test_obj.shape: |
145 | | - shape = (df_or_series.shape[0], len(dtypes)) |
146 | | - else: |
147 | | - shape = (np.nan, len(dtypes)) |
148 | 184 | columns_value = parse_index(dtypes.index, store_data=True) |
149 | | - if index is None: |
150 | | - index = obj.index |
151 | | - index_value = parse_index( |
152 | | - index, df_or_series, self._func, self._args, self._kwargs |
153 | | - ) |
154 | 185 | return self.new_dataframe( |
155 | 186 | [df_or_series], |
156 | 187 | shape=shape, |
|
0 commit comments