Skip to content

Commit 866a5fa

Browse files
authored
Reduce the time cost of fetching tileable data (#2594)
1 parent 49ce0c1 commit 866a5fa

File tree

6 files changed

+126
-24
lines changed

6 files changed

+126
-24
lines changed

mars/core/entity/executable.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ def register(self, tileable: TileableType, session: SessionType):
8080
def _get_session(executable: "_ExecutableMixin", session: SessionType = None):
8181
from ...deploy.oscar.session import get_default_session
8282

83-
if session is None and len(executable._executed_sessions) > 0:
84-
session = executable._executed_sessions[-1]
83+
# if session is not specified, use default session
8584
if session is None:
8685
session = get_default_session()
8786

@@ -151,6 +150,9 @@ def _execute_and_fetch(self, session: SessionType = None, **kw):
151150

152151
session = _get_session(self, session)
153152
fetch_kwargs = kw.pop("fetch_kwargs", dict())
153+
if session in self._executed_sessions:
154+
# if it has already been executed in this session, fetch directly.
155+
return self.fetch(session=session, **fetch_kwargs)
154156
ret = self.execute(session=session, **kw)
155157
if isinstance(ret, ExecutionInfo):
156158
# wait=False

mars/dataframe/core.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
on_serialize_numpy_type,
6262
ceildiv,
6363
tokenize,
64+
estimate_pandas_size,
6465
)
6566
from .utils import fetch_corner_data, ReprSeries, parse_index, merge_index_value
6667
from ..tensor import statistics
@@ -565,28 +566,52 @@ def to_pandas(self, session=None, **kw):
565566
class _BatchedFetcher:
566567
__slots__ = ()
567568

568-
def _iter(self, batch_size=1000, session=None, **kw):
569+
def _iter(self, batch_size=None, session=None, **kw):
569570
from .indexing.iloc import iloc
570571

571-
size = self.shape[0]
572-
n_batch = ceildiv(size, batch_size)
572+
if batch_size is not None:
573+
size = self.shape[0]
574+
n_batch = ceildiv(size, batch_size)
573575

574-
if n_batch > 1:
575-
for i in range(n_batch):
576-
batch_data = iloc(self)[batch_size * i : batch_size * (i + 1)]
577-
yield batch_data._fetch(session=session, **kw)
576+
if n_batch > 1:
577+
for i in range(n_batch):
578+
batch_data = iloc(self)[batch_size * i : batch_size * (i + 1)]
579+
yield batch_data._fetch(session=session, **kw)
580+
else:
581+
yield self._fetch(session=session, **kw)
578582
else:
579-
yield self._fetch(session=session, **kw)
583+
# if batch_size is not specified, fetch a first batch and use its
584+
# estimated size to derive batch_size.
585+
default_batch_bytes = 50 * 1024 ** 2
586+
first_batch = 1000
587+
size = self.shape[0]
588+
589+
if size >= first_batch:
590+
batch_data = iloc(self)[:first_batch]
591+
first_batch_data = batch_data._fetch(session=session, **kw)
592+
yield first_batch_data
593+
data_size = estimate_pandas_size(first_batch_data)
594+
batch_size = int(default_batch_bytes / data_size * first_batch)
595+
n_batch = ceildiv(size - 1000, batch_size)
596+
for i in range(n_batch):
597+
batch_data = iloc(self)[
598+
first_batch
599+
+ batch_size * i : first_batch
600+
+ batch_size * (i + 1)
601+
]
602+
yield batch_data._fetch(session=session, **kw)
603+
else:
604+
yield self._fetch(session=session, **kw)
580605

581-
def iterbatch(self, batch_size=1000, session=None, **kw):
606+
def iterbatch(self, batch_size=None, session=None, **kw):
582607
# trigger execution
583608
self.execute(session=session, **kw)
584609
return self._iter(batch_size=batch_size, session=session)
585610

586611
def fetch(self, session=None, **kw):
587612
from .indexing.iloc import DataFrameIlocGetItem, SeriesIlocGetItem
588613

589-
batch_size = kw.pop("batch_size", 1000)
614+
batch_size = kw.pop("batch_size", None)
590615
if isinstance(self.op, (DataFrameIlocGetItem, SeriesIlocGetItem)):
591616
# see GH#1871
592617
# already iloc, do not trigger batch fetch

mars/learn/tests/test_wrappers.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,9 @@
2323
from ..wrappers import ParallelPostFit
2424

2525

26-
raw_x, raw_y = make_classification(n_samples=1000)
27-
X, y = mt.tensor(raw_x, chunk_size=100), mt.tensor(raw_y, chunk_size=100)
28-
29-
3026
def test_parallel_post_fit_basic(setup):
27+
raw_x, raw_y = make_classification(n_samples=1000)
28+
X, y = mt.tensor(raw_x, chunk_size=100), mt.tensor(raw_y, chunk_size=100)
3129
clf = ParallelPostFit(GradientBoostingClassifier())
3230
clf.fit(X, y)
3331

@@ -47,6 +45,8 @@ def test_parallel_post_fit_basic(setup):
4745

4846

4947
def test_parallel_post_fit_predict(setup):
48+
raw_x, raw_y = make_classification(n_samples=1000)
49+
X, y = mt.tensor(raw_x, chunk_size=100), mt.tensor(raw_y, chunk_size=100)
5050
base = LogisticRegression(random_state=0, n_jobs=1, solver="lbfgs")
5151
wrap = ParallelPostFit(LogisticRegression(random_state=0, n_jobs=1, solver="lbfgs"))
5252

@@ -67,6 +67,8 @@ def test_parallel_post_fit_predict(setup):
6767

6868

6969
def test_parallel_post_fit_transform(setup):
70+
raw_x, raw_y = make_classification(n_samples=1000)
71+
X, y = mt.tensor(raw_x, chunk_size=100), mt.tensor(raw_y, chunk_size=100)
7072
base = PCA(random_state=0)
7173
wrap = ParallelPostFit(PCA(random_state=0))
7274

@@ -79,6 +81,8 @@ def test_parallel_post_fit_transform(setup):
7981

8082

8183
def test_parallel_post_fit_multiclass(setup):
84+
raw_x, raw_y = make_classification(n_samples=1000)
85+
X, y = mt.tensor(raw_x, chunk_size=100), mt.tensor(raw_y, chunk_size=100)
8286
raw_x, raw_y = make_classification(n_classes=3, n_informative=4)
8387
X, y = mt.tensor(raw_x, chunk_size=50), mt.tensor(raw_y, chunk_size=50)
8488

mars/services/meta/api/web.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,16 @@ async def get_chunk_meta(self, session_id: str, data_key: str):
4848
result = await oscar_api.get_chunk_meta(data_key, fields=fields, error=error)
4949
self.write(serialize_serializable(result))
5050

51+
@web_api("", method="post")
52+
async def get_chunks_meta(self, session_id: str):
53+
body_args = deserialize_serializable(self.request.body)
54+
oscar_api = await self._get_oscar_meta_api(session_id)
55+
get_metas = []
56+
for data_key, fields, error in body_args:
57+
get_metas.append(oscar_api.get_chunk_meta.delay(data_key, fields, error))
58+
results = await oscar_api.get_chunk_meta.batch(*get_metas)
59+
self.write(serialize_serializable(results))
60+
5161

5262
web_handlers = {MetaWebAPIHandler.get_root_pattern(): MetaWebAPIHandler}
5363

@@ -67,3 +77,16 @@ async def get_chunk_meta(
6777
params["fields"] = ",".join(fields)
6878
res = await self._request_url("GET", req_addr, params=params)
6979
return deserialize_serializable(res.body)
80+
81+
@get_chunk_meta.batch
82+
async def get_chunks_meta(self, args_list, kwargs_list):
83+
get_chunk_metas = []
84+
for args, kwargs in zip(args_list, kwargs_list):
85+
object_id, fields, error = self.get_chunk_meta.bind(*args, **kwargs)
86+
get_chunk_metas.append([object_id, fields, error])
87+
88+
req_addr = f"{self._address}/api/session/{self._session_id}/meta"
89+
res = await self._request_url(
90+
"POST", req_addr, data=serialize_serializable(get_chunk_metas)
91+
)
92+
return deserialize_serializable(res.body)

mars/services/storage/api/web.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from collections import defaultdict
1516
from typing import Any, List
1617

1718
from .... import oscar as mo
@@ -55,6 +56,25 @@ async def get_data(self, session_id: str, data_key: str):
5556
result = await oscar_api.get(data_key)
5657
self.write(serialize_serializable(result))
5758

59+
@web_api("batch/get", method="post")
60+
async def get_batch_data(self, session_id: str):
61+
body_args = deserialize_serializable(self.request.body)
62+
storage_api_to_gets = defaultdict(list)
63+
storage_api_to_idx = defaultdict(list)
64+
results = [None] * len(body_args)
65+
for i, (data_key, conditions, error) in enumerate(body_args):
66+
oscar_api = await self._get_storage_api_by_object_id(session_id, data_key)
67+
storage_api_to_idx[oscar_api].append(i)
68+
storage_api_to_gets[oscar_api].append(
69+
oscar_api.get.delay(data_key, conditions=conditions, error=error)
70+
)
71+
for api, fetches in storage_api_to_gets.items():
72+
data_list = await api.get.batch(*fetches)
73+
for idx, data in zip(storage_api_to_idx[api], data_list):
74+
results[idx] = data
75+
res_data = serialize_serializable(results)
76+
self.write(res_data)
77+
5878
@web_api("(?P<data_key>[^/]+)", method="post")
5979
async def get_data_by_post(self, session_id: str, data_key: str):
6080
body_args = (
@@ -110,6 +130,21 @@ async def get(
110130
)
111131
return deserialize_serializable(res.body)
112132

133+
@get.batch
134+
async def get_batch(self, args_list, kwargs_list):
135+
get_chunks = []
136+
for args, kwargs in zip(args_list, kwargs_list):
137+
data_key, conditions, error = self.get.bind(*args, **kwargs)
138+
get_chunks.append([data_key, conditions, error])
139+
140+
path = f"{self._address}/api/session/{self._session_id}/storage/batch/get"
141+
res = await self._request_url(
142+
path=path,
143+
method="POST",
144+
data=serialize_serializable(get_chunks),
145+
)
146+
return deserialize_serializable(res.body)
147+
113148
@mo.extensible
114149
async def put(
115150
self, data_key: str, obj: object, level: StorageLevel = StorageLevel.MEMORY

mars/utils.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,27 +1085,40 @@ def arrow_array_to_objects(
10851085
return obj
10861086

10871087

1088+
_enter_counter = 0
1089+
_initial_session = None
1090+
1091+
10881092
def enter_current_session(func: Callable):
10891093
@functools.wraps(func)
10901094
def wrapped(cls, ctx, op):
10911095
from .deploy.oscar.session import AbstractSession, get_default_session
10921096

1097+
global _enter_counter, _initial_session
10931098
# skip in some test cases
10941099
if not hasattr(ctx, "get_current_session"):
10951100
return func(cls, ctx, op)
10961101

1097-
session = ctx.get_current_session()
1098-
prev_default_session = get_default_session()
1099-
session.as_default()
1102+
with AbstractSession._lock:
1103+
if _enter_counter == 0:
1104+
# to handle nested calls, only record the initial session
1105+
# on the first (outermost) call
1106+
session = ctx.get_current_session()
1107+
_initial_session = get_default_session()
1108+
session.as_default()
1109+
_enter_counter += 1
11001110

11011111
try:
11021112
result = func(cls, ctx, op)
11031113
finally:
1104-
if prev_default_session:
1105-
prev_default_session.as_default()
1106-
else:
1107-
AbstractSession.reset_default()
1108-
1114+
with AbstractSession._lock:
1115+
_enter_counter -= 1
1116+
if _enter_counter == 0:
1117+
# restore the previous default session once the counter is back to 0
1118+
if _initial_session:
1119+
_initial_session.as_default()
1120+
else:
1121+
AbstractSession.reset_default()
11091122
return result
11101123

11111124
return wrapped

0 commit comments

Comments
 (0)