
Commit 14b9569

Optimize tileable graph construction (#2583)
1 parent: 7ad7e03

29 files changed: 476 additions & 270 deletions

mars/_utils.pyx

Lines changed: 22 additions & 14 deletions
@@ -18,7 +18,6 @@ import pickle
 import pkgutil
 import types
 import uuid
-from collections import deque
 from datetime import date, datetime, timedelta, tzinfo
 from enum import Enum
 from functools import lru_cache, partial
@@ -142,7 +141,7 @@ cdef class TypeDispatcher:
 cdef inline build_canonical_bytes(tuple args, kwargs):
     if kwargs:
         args = args + (kwargs,)
-    return str([tokenize_handler(arg) for arg in args]).encode('utf-8')
+    return pickle.dumps(tokenize_handler(args))


 def tokenize(*args, **kwargs):
@@ -155,26 +154,31 @@ def tokenize_int(*args, **kwargs):

 cdef class Tokenizer(TypeDispatcher):
     def __call__(self, object obj, *args, **kwargs):
-        if hasattr(obj, '__mars_tokenize__') and not isinstance(obj, type):
-            return super().__call__(obj.__mars_tokenize__(), *args, **kwargs)
-        if callable(obj):
-            if PDTick is not None and not isinstance(obj, PDTick):
-                return tokenize_function(obj)
-
         try:
             return super().__call__(obj, *args, **kwargs)
         except KeyError:
+            if hasattr(obj, '__mars_tokenize__') and not isinstance(obj, type):
+                if len(args) == 0 and len(kwargs) == 0:
+                    return obj.__mars_tokenize__()
+                else:
+                    return super().__call__(obj.__mars_tokenize__(), *args, **kwargs)
+            if callable(obj):
+                if PDTick is not None and not isinstance(obj, PDTick):
+                    return tokenize_function(obj)
+
             try:
                 return cloudpickle.dumps(obj)
            except:
                 raise TypeError(f'Cannot generate token for {obj}, type: {type(obj)}') from None


 cdef inline list iterative_tokenize(object ob):
-    dq = deque(ob)
-    h_list = []
-    while dq:
-        x = dq.pop()
+    cdef list dq = [ob]
+    cdef int dq_pos = 0
+    cdef list h_list = []
+    while dq_pos < len(dq):
+        x = dq[dq_pos]
+        dq_pos += 1
         if isinstance(x, (list, tuple)):
             dq.extend(x)
         elif isinstance(x, set):
@@ -188,7 +192,6 @@ cdef inline list iterative_tokenize(object ob):

 cdef inline tuple tokenize_numpy(ob):
     cdef int offset
-    cdef str data

     if not ob.shape:
         return str(ob), ob.dtype
@@ -288,6 +291,11 @@ cdef list tokenize_sqlalchemy_selectable(ob):
     return iterative_tokenize([str(ob)])


+cdef list tokenize_enum(ob):
+    cls = type(ob)
+    return iterative_tokenize([id(cls), cls.__name__, ob.name])
+
+
 @lru_cache(500)
 def tokenize_function(ob):
     if isinstance(ob, partial):
@@ -342,8 +350,8 @@ tokenize_handler.register(np.ndarray, tokenize_numpy)
 tokenize_handler.register(dict, lambda ob: iterative_tokenize(sorted(ob.items())))
 tokenize_handler.register(set, lambda ob: iterative_tokenize(sorted(ob)))
 tokenize_handler.register(np.random.RandomState, lambda ob: iterative_tokenize(ob.get_state()))
-tokenize_handler.register(Enum, lambda ob: iterative_tokenize((type(ob), ob.name)))
 tokenize_handler.register(memoryview, lambda ob: mmh3_hash_from_buffer(ob))
+tokenize_handler.register(Enum, tokenize_enum)
 tokenize_handler.register(pd.Index, tokenize_pandas_index)
 tokenize_handler.register(pd.Series, tokenize_pandas_series)
 tokenize_handler.register(pd.DataFrame, tokenize_pandas_dataframe)

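The iterative_tokenize rewrite above replaces a collections.deque, which was popped item by item, with a plain list scanned by an advancing cursor, so nothing is ever removed from the container and Cython can type both the list and the index. A minimal pure-Python sketch of that flattening pattern follows; the flatten_tokens name and the bare leaf handling are illustrative, not the Cython implementation itself.

# Sketch of the list-plus-cursor traversal used by iterative_tokenize after
# this commit: the work list only grows, and a position index replaces
# deque.pop(), which is cheaper and easier for Cython to optimize.
def flatten_tokens(obj):
    queue = [obj]      # work list; nested containers are appended, never popped
    pos = 0            # read cursor over the work list
    tokens = []
    while pos < len(queue):
        item = queue[pos]
        pos += 1
        if isinstance(item, (list, tuple)):
            queue.extend(item)           # defer nested elements
        elif isinstance(item, set):
            queue.extend(sorted(item))   # deterministic order for sets
        else:
            tokens.append(item)          # leaf value contributes to the token
    return tokens


print(flatten_tokens([1, ("a", [2, 3]), {"x"}]))  # [1, 'a', 'x', 2, 3]
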
mars/core/base.py

Lines changed: 20 additions & 6 deletions
@@ -46,17 +46,31 @@ def _keys_(self):
         setattr(cls, member, slots)
         return slots

+    @property
+    def _copy_tags_(self):
+        cls = type(self)
+        member = f"__copy_tags_{cls.__name__}"
+        try:
+            return getattr(cls, member)
+        except AttributeError:
+            slots = sorted(
+                f.attr_name
+                for k, f in self._FIELDS.items()
+                if k not in self._no_copy_attrs_
+            )
+            setattr(cls, member, slots)
+            return slots
+
     @property
     def _values_(self):
-        return [
-            getattr(self, k, None) for k in self._keys_ if k not in self._no_copy_attrs_
-        ]
+        return [self._FIELD_VALUES.get(k) for k in self._copy_tags_]

     def __mars_tokenize__(self):
-        if hasattr(self, "_key"):
+        try:
+            return self._key
+        except AttributeError:  # pragma: no cover
+            self._update_key()
             return self._key
-        else:
-            return (type(self), *self._values_)

     def _obj_set(self, k, v):
         object.__setattr__(self, k, v)

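The new _copy_tags_ property computes the sorted list of copyable field names once per class and caches it directly on the class object, so every later instance only pays a getattr. A generic sketch of that compute-once, cache-on-the-class pattern follows; the Record class and its fields are hypothetical, not Mars internals.

# Sketch of the per-class memoization used by _copy_tags_: the first access
# computes the list and stores it as a class attribute keyed by the class
# name; subsequent accesses on any instance hit the cached value.
class Record:
    _no_copy_attrs_ = {"scheduling_hint"}                 # excluded from copies
    _FIELDS = {"a": "attr_a", "b": "attr_b", "scheduling_hint": "attr_hint"}

    @property
    def _copy_tags_(self):
        cls = type(self)
        member = f"__copy_tags_{cls.__name__}"            # one cache slot per class
        try:
            return getattr(cls, member)                   # fast path after first call
        except AttributeError:
            tags = sorted(
                attr
                for name, attr in self._FIELDS.items()
                if name not in self._no_copy_attrs_
            )
            setattr(cls, member, tags)                    # cache on the class, not the instance
            return tags


r1, r2 = Record(), Record()
print(r1._copy_tags_)                    # ['attr_a', 'attr_b'], computed once
print(r1._copy_tags_ is r2._copy_tags_)  # True: both read the class-level cache
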
mars/core/entity/tileables.py

Lines changed: 15 additions & 14 deletions
@@ -25,7 +25,7 @@
 from ...typing import OperandType, TileableType, ChunkType
 from ...utils import on_serialize_shape, on_deserialize_shape, on_serialize_nsplits
 from ..base import Base
-from ..mode import enter_mode, is_build_mode
+from ..mode import enter_mode
 from .chunks import Chunk
 from .core import EntityData, Entity
 from .executable import _ExecutableMixin
@@ -280,8 +280,12 @@ def __init__(self: TileableType, *args, **kwargs):

         super().__init__(*args, **kwargs)

-        if hasattr(self, "_chunks") and self._chunks:
-            self._chunks = sorted(self._chunks, key=attrgetter("index"))
+        try:
+            chunks = self._chunks
+            if chunks:
+                self._chunks = sorted(chunks, key=attrgetter("index"))
+        except AttributeError:  # pragma: no cover
+            pass

         self._entities = WeakSet()
         self._executed_sessions = []
@@ -331,11 +335,9 @@ def is_coarse(self):
             return True
         return False

-    @enter_mode(build=True)
     def attach(self, entity):
         self._entities.add(entity)

-    @enter_mode(build=True)
     def detach(self, entity):
         self._entities.discard(entity)

@@ -345,10 +347,11 @@ class Tileable(Entity):

     def __init__(self, data: TileableType = None, **kw):
         super().__init__(data=data, **kw)
-        if self._data is not None:
-            self._data.attach(self)
-            if self._data.op.create_view:
-                entity_view_handler.add_observer(self._data.inputs[0], self)
+        data = self._data
+        if data is not None:
+            data.attach(self)
+            if data.op.create_view:
+                entity_view_handler.add_observer(data.inputs[0], self)

     def __copy__(self):
         return self._view()
@@ -412,11 +415,9 @@ def ndim(self):

     def __len__(self):
         try:
-            return self.shape[0]
-        except IndexError:
-            if is_build_mode():
-                return 0
-            raise TypeError("len() of unsized object")
+            return int(self.shape[0])
+        except (IndexError, ValueError):  # pragma: no cover
+            return 0

     @property
     def shape(self):

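Several of the hot paths above (TileableData.__init__, attach/detach, __len__) swap hasattr checks and enter_mode decorators for a bare try/except AttributeError, since these run once per tileable created during graph construction. A small sketch contrasting the two styles on the common path where the attribute exists; names and timings are illustrative only.

# Sketch contrasting LBYL (hasattr) with the EAFP (try/except) form adopted
# above; when the attribute is usually present, EAFP avoids the extra lookup
# that hasattr performs before the real access.
import timeit


class Data:
    def __init__(self):
        self._chunks = [3, 1, 2]


def lbyl(obj):
    if hasattr(obj, "_chunks") and obj._chunks:
        return sorted(obj._chunks)


def eafp(obj):
    try:
        chunks = obj._chunks
    except AttributeError:
        return None
    if chunks:
        return sorted(chunks)


obj = Data()
print("hasattr:   ", timeit.timeit(lambda: lbyl(obj), number=500_000))
print("try/except:", timeit.timeit(lambda: eafp(obj), number=500_000))
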
mars/core/mode.py

Lines changed: 1 addition & 0 deletions
@@ -91,5 +91,6 @@ def enter_mode(kernel=None, build=None):
         "kernel": kernel,
         "build": build,
     }
+    mode_name_to_value = {k: v for k, v in mode_name_to_value.items() if v is not None}

     return _EnterModeFuncWrapper(mode_name_to_value)

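The one-line change to enter_mode filters out modes that were not passed (left as None) before the mapping reaches _EnterModeFuncWrapper, which suggests that only explicitly requested flags, such as build=True, are applied by the wrapper. A tiny sketch of the filtering step alone:

# Sketch of the added filtering: unspecified modes stay out of the mapping
# handed to the wrapper, so they cannot overwrite the caller's current state.
def collect_modes(kernel=None, build=None):
    mode_name_to_value = {"kernel": kernel, "build": build}
    return {k: v for k, v in mode_name_to_value.items() if v is not None}


print(collect_modes(build=True))   # {'build': True}; kernel is left untouched
print(collect_modes())             # {}
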
mars/core/operand/base.py

Lines changed: 1 addition & 0 deletions
@@ -127,6 +127,7 @@ class Operand(Base, metaclass=OperandMetaclass):
     attr_tag = "attr"
     _init_update_key_ = False
     _output_type_ = None
+    _no_copy_attrs_ = Base._no_copy_attrs_ | {"scheduling_hint"}

     sparse = BoolField("sparse", default=False)
     device = Int32Field("device", default=None)

mars/core/operand/core.py

Lines changed: 37 additions & 29 deletions
@@ -23,12 +23,11 @@
     UFuncTypeError = None

 from ...typing import TileableType, ChunkType, OperandType
-from ...utils import calc_data_size
+from ...utils import calc_data_size, tokenize
 from ..context import Context
 from ..mode import is_eager_mode
 from ..entity import (
     OutputType,
-    TILEABLE_TYPE,
     ExecutableTuple,
     get_chunk_types,
     get_tileable_types,
@@ -46,13 +45,11 @@ class TileableOperandMixin:
     def check_inputs(self, inputs: List[TileableType]):
         if not inputs:
             return
+
+        from ...dataframe.core import DATAFRAME_TYPE
+
         for inp in inputs:
-            if isinstance(inp, TILEABLE_TYPE):
-                i = inp.extra_params["_i"]
-                if not inp.op.output_types:
-                    continue
-                if inp.op.output_types[i] != OutputType.dataframe:
-                    continue
+            if isinstance(inp, DATAFRAME_TYPE):
                 dtypes = getattr(inp, "dtypes", None)
                 if dtypes is None:
                     raise ValueError(
@@ -62,24 +59,25 @@ def check_inputs(self, inputs: List[TileableType]):

     @classmethod
     def _check_if_gpu(cls, inputs: List[TileableType]):
-        if (
-            inputs is not None
-            and len(
-                [
-                    inp
-                    for inp in inputs
-                    if inp is not None and getattr(inp, "op", None) is not None
-                ]
-            )
-            > 0
-        ):
-            if all(inp.op.gpu is True for inp in inputs):
-                return True
-            elif all(inp.op.gpu is False for inp in inputs):
-                return False
+        if not inputs:
+            return None
+        true_num = 0
+        for inp in inputs:
+            op = getattr(inp, "op", None)
+            if op is None or op.gpu is None:
+                return None
+            true_num += int(op.gpu)
+        if true_num == len(inputs):
+            return True
+        elif true_num == 0:
+            return False
+        return None
+
+    def _tokenize_output(self, output_idx: int, **kw):
+        return tokenize(self._key, output_idx)

     def _create_chunk(self, output_idx: int, index: Tuple[int], **kw) -> ChunkType:
-        output_type = kw.pop("output_type", self._get_output_type(output_idx))
+        output_type = kw.pop("output_type", None) or self._get_output_type(output_idx)
         if not output_type:
             raise ValueError("output_type should be specified")

@@ -92,6 +90,11 @@ def _create_chunk(self, output_idx: int, index: Tuple[int], **kw) -> ChunkType:
         if output_type == OutputType.scalar:
             # tensor
             kw["order"] = "C_ORDER"
+
+        # key of output chunks may only contain keys for its output ids
+        if "_key" not in kw:
+            kw["_key"] = self._tokenize_output(output_idx, **kw)
+
         data = chunk_data_type(**kw)
         return chunk_type(data)

@@ -189,6 +192,7 @@ def _create_tileable(self, output_idx: int, **kw) -> TileableType:

         if isinstance(output_type, (list, tuple)):
             output_type = output_type[output_idx]
+
         tileable_type, tileable_data_type = get_tileable_types(output_type)
         kw["_i"] = output_idx
         kw["op"] = self
@@ -197,6 +201,11 @@ def _create_tileable(self, output_idx: int, **kw) -> TileableType:
             kw["order"] = "C_ORDER"

         kw = self._fill_nan_shape(kw)
+
+        # key of output chunks may only contain keys for its output ids
+        if "_key" not in kw:
+            kw["_key"] = self._tokenize_output(output_idx, **kw)
+
         data = tileable_data_type(**kw)
         return tileable_type(data)

@@ -207,12 +216,11 @@ def _new_tileables(
         if output_limit is None:
             output_limit = getattr(self, "output_limit")

-        self.check_inputs(inputs)
-        getattr(self, "_set_inputs")(inputs)
-        if getattr(self, "gpu", None) is None:
+        self._set_inputs(inputs)
+        if self.gpu is None:
             self.gpu = self._check_if_gpu(self._inputs)
         if getattr(self, "_key", None) is None:
-            getattr(self, "_update_key")()  # update key when inputs are set
+            self._update_key()  # update key when inputs are set

         tileables = []
         for j in range(output_limit):
@@ -222,7 +230,7 @@ def _new_tileables(
             tileable = self._create_tileable(j, **create_tensor_kw)
             tileables.append(tileable)

-        setattr(self, "outputs", tileables)
+        self.outputs = tileables
         if len(tileables) > 1:
             # for each output tileable, hold the reference to the other outputs
             # so that either no one or everyone are gc collected

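The key changes above are the single-pass _check_if_gpu and, more importantly for graph construction speed, _tokenize_output: output chunks and tileables now get their _key derived from the operand's already-computed key plus their output index instead of tokenizing all of their own metadata. A simplified sketch of the key-derivation idea; the Op class and the md5-based tokenize stand-in are illustrative, Mars uses its own tokenize.

# Sketch of deriving output keys from (operand key, output position), as the
# added _tokenize_output does; hashlib.md5 stands in for mars.utils.tokenize
# and the Op class is a hypothetical illustration.
import hashlib


def tokenize(*args) -> str:
    return hashlib.md5(repr(args).encode()).hexdigest()


class Op:
    def __init__(self, params):
        # the operand key is computed once from its own parameters
        self._key = tokenize(type(self).__name__, params)

    def _tokenize_output(self, output_idx: int) -> str:
        # each output only needs the op key and its position, so creating
        # chunks/tileables avoids re-hashing all of the output's metadata
        return tokenize(self._key, output_idx)


op = Op({"axis": 0})
print([op._tokenize_output(i) for i in range(2)])  # two distinct, stable keys
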
mars/dataframe/arithmetic/core.py

Lines changed: 4 additions & 1 deletion
@@ -434,7 +434,10 @@ def _calc_properties(cls, x1, x2=None, axis="columns"):
         ):
             x2_dtype = x2.dtype if hasattr(x2, "dtype") else type(x2)
             x2_dtype = get_dtype(x2_dtype)
-            dtype = infer_dtype(x1.dtype, x2_dtype, cls._operator)
+            if hasattr(cls, "return_dtype"):
+                dtype = cls.return_dtype
+            else:
+                dtype = infer_dtype(x1.dtype, x2_dtype, cls._operator)
         ret = {"shape": x1.shape, "dtype": dtype, "index_value": x1.index_value}
         if pd.api.types.is_scalar(x2) or (
             hasattr(x2, "ndim") and (x2.ndim == 0 or x2.ndim == 1)

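_calc_properties now lets an operand class declare its result dtype statically via a return_dtype attribute and only falls back to infer_dtype when it is absent; the comparison operands below (eq, gt) use this to skip inference entirely. A sketch of the dispatch with hypothetical classes and an infer_dtype stub, not the Mars operand hierarchy:

# Sketch of the return_dtype short-circuit: operators whose result dtype is
# known up front (comparisons always yield bool) skip dtype inference.
import numpy as np
import pandas as pd


def infer_dtype(left_dtype, right_dtype, operator):
    # stand-in for the real inference: apply the operator to empty series
    return operator(pd.Series(dtype=left_dtype), pd.Series(dtype=right_dtype)).dtype


class BinOp:
    @classmethod
    def result_dtype(cls, left_dtype, right_dtype):
        if hasattr(cls, "return_dtype"):      # fast path: statically known dtype
            return cls.return_dtype
        return infer_dtype(left_dtype, right_dtype, cls.operator)


class Add(BinOp):
    operator = staticmethod(lambda lhs, rhs: lhs + rhs)


class Equal(BinOp):
    operator = staticmethod(lambda lhs, rhs: lhs.eq(rhs))
    return_dtype = np.dtype(bool)             # comparisons always return bool


print(Add.result_dtype(np.dtype("int64"), np.dtype("float64")))    # float64
print(Equal.result_dtype(np.dtype("int64"), np.dtype("float64")))  # bool
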
mars/dataframe/arithmetic/equal.py

Lines changed: 4 additions & 0 deletions
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import numpy as np
+
 from ... import opcodes as OperandDef
 from ...utils import classproperty
 from .core import DataFrameBinopUfunc
@@ -24,6 +26,8 @@ class DataFrameEqual(DataFrameBinopUfunc):
     _func_name = "eq"
     _rfunc_name = "eq"

+    return_dtype = np.dtype(bool)
+
     @classproperty
     def _operator(self):
         return lambda lhs, rhs: lhs.eq(rhs)

mars/dataframe/arithmetic/greater.py

Lines changed: 4 additions & 0 deletions
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import numpy as np
+
 from ... import opcodes as OperandDef
 from ...utils import classproperty
 from .core import DataFrameBinopUfunc
@@ -24,6 +26,8 @@ class DataFrameGreater(DataFrameBinopUfunc):
     _func_name = "gt"
     _rfunc_name = "lt"

+    return_dtype = np.dtype(bool)
+
     @classproperty
     def _operator(self):
         return lambda lhs, rhs: lhs.gt(rhs)
