Skip to content

Commit 54636e6

Browse files
authored
Fix progress always being 0 or 100% (#2591)
1 parent 3bde5a4 commit 54636e6

File tree

3 files changed

+45
-10
lines changed

3 files changed

+45
-10
lines changed

mars/deploy/oscar/session.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -814,16 +814,20 @@ async def _run_in_background(
814814
while True:
815815
try:
816816
if not cancelled:
817-
task_result: TaskResult = await self._task_api.wait_task(
818-
task_id, timeout=0.5
819-
)
820-
if task_result is None:
821-
# not finished, set progress
822-
progress.value = await self._task_api.get_task_progress(
823-
task_id
817+
task_result: Union[TaskResult, None] = None
818+
try:
819+
task_result = await self._task_api.wait_task(
820+
task_id, timeout=0.5
824821
)
825-
else:
826-
progress.value = 1.0
822+
finally:
823+
if task_result is None:
824+
# not finished, set progress
825+
progress.value = await self._task_api.get_task_progress(
826+
task_id
827+
)
828+
else:
829+
progress.value = 1.0
830+
if task_result is not None:
827831
break
828832
else:
829833
# wait for task to finish

mars/deploy/oscar/tests/test_local.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from .... import tensor as mt
3333
from .... import remote as mr
3434
from ....config import option_context
35+
from ....core.context import get_context
3536
from ....lib.aio import new_isolation
3637
from ....storage import StorageLevel
3738
from ....services.storage import StorageAPI
@@ -378,6 +379,34 @@ def test_no_default_session():
378379
assert get_default_async_session() is None
379380

380381

382+
@pytest.mark.asyncio
383+
async def test_session_progress(create_cluster):
384+
session = get_default_async_session()
385+
assert session.address is not None
386+
assert session.session_id is not None
387+
388+
def f1(interval: float, count: int):
389+
for idx in range(count):
390+
time.sleep(interval)
391+
get_context().set_progress((1 + idx) * 1.0 / count)
392+
393+
r = mr.spawn(f1, args=(0.5, 10))
394+
395+
info = await session.execute(r)
396+
397+
for _ in range(20):
398+
if 0 < info.progress() < 1:
399+
break
400+
await asyncio.sleep(0.1)
401+
else:
402+
raise Exception("progress test failed.")
403+
404+
await info
405+
assert info.result() is None
406+
assert info.exception() is None
407+
assert info.progress() == 1
408+
409+
381410
@pytest.fixture
382411
def setup_session():
383412
session = new_session(n_cpu=2, use_uvloop=False)

mars/services/task/api/web.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,9 @@ async def get_last_idle_time(self) -> Union[float, None]:
240240

241241
async def wait_task(self, task_id: str, timeout: float = None):
242242
path = f"{self._address}/api/session/{self._session_id}/task/{task_id}"
243-
params = {"action": "wait", "timeout": str(timeout or "")}
243+
# client timeout should be longer than server timeout.
244+
server_timeout = "" if timeout is None else str(max(timeout / 2.0, timeout - 1))
245+
params = {"action": "wait", "timeout": server_timeout}
244246
res = await self._request_url(
245247
"GET", path, params=params, request_timeout=timeout or 0
246248
)

0 commit comments

Comments
 (0)