Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions gcsfs/concurrency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import asyncio
from contextlib import asynccontextmanager


@asynccontextmanager
async def parallel_tasks_first_completed(coros):
"""
Starts coroutines in parallel and enters the context as soon as
at least one task has completed. Automatically cancels pending tasks
when exiting the context.
"""
tasks = [asyncio.create_task(c) for c in coros]
try:
# Suspend until the first task finishes for maximum responsiveness
done, pending = await asyncio.wait(
set(tasks), return_when=asyncio.FIRST_COMPLETED
)
yield tasks, done, pending
finally:
# Ensure 'losing' tasks are cancelled immediately
for t in tasks:
if not t.done():
t.cancel()
26 changes: 17 additions & 9 deletions gcsfs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from . import __version__ as version
from .checkers import get_consistency_checker
from .concurrency import parallel_tasks_first_completed
from .credentials import GoogleCredentials
from .inventory_report import InventoryReport
from .retry import errs, retry_request, validate_response
Expand Down Expand Up @@ -1080,15 +1081,22 @@ async def _info(self, path, generation=None, **kwargs):
"storageClass": "DIRECTORY",
"type": "directory",
}
# Check exact file path
try:
exact = await self._get_object(path)
# this condition finds a "placeholder" - still need to check if it's a directory
if not _is_directory_marker(exact):
return exact
except FileNotFoundError:
pass
return await self._get_directory_info(path, bucket, key, generation)

async with parallel_tasks_first_completed(
[
self._get_object(path),
self._get_directory_info(path, bucket, key, generation),
]
) as (tasks, done, pending):
get_object_task, get_directory_info_task = tasks

try:
get_object_res = await get_object_task
if not _is_directory_marker(get_object_res):
return get_object_res
except FileNotFoundError:
pass
return await get_directory_info_task

async def _get_directory_info(self, path, bucket, key, generation):
"""
Expand Down
78 changes: 78 additions & 0 deletions gcsfs/tests/test_concurrency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import asyncio

import pytest

from gcsfs.concurrency import parallel_tasks_first_completed


@pytest.mark.asyncio
async def test_parallel_tasks_first_completed_basic():
async def slow_task():
await asyncio.sleep(1)
return "slow"

async def fast_task():
await asyncio.sleep(0.1)
return "fast"

async with parallel_tasks_first_completed([slow_task(), fast_task()]) as (
tasks,
done,
pending,
):
assert len(done) == 1
assert len(pending) == 1
completed_task = done.pop()
assert completed_task.result() == "fast"
assert len(tasks) == 2


@pytest.mark.asyncio
async def test_parallel_tasks_first_completed_cancellation():
task_cancelled = False

async def slow_task():
nonlocal task_cancelled
try:
await asyncio.sleep(1)
except asyncio.CancelledError:
task_cancelled = True
raise

async def fast_task():
await asyncio.sleep(0.1)
return "fast"

async with parallel_tasks_first_completed([slow_task(), fast_task()]) as (
tasks,
done,
pending,
):
assert len(done) == 1
completed_task = done.pop()
assert completed_task.result() == "fast"

# After exiting context, slow_task should be cancelled
await asyncio.sleep(0.1) # Give it a moment to run cancellation cleanup
assert task_cancelled


@pytest.mark.asyncio
async def test_parallel_tasks_first_completed_exception():
async def error_task():
await asyncio.sleep(0.1)
raise ValueError("error")

async def slow_task():
await asyncio.sleep(1)
return "slow"

async with parallel_tasks_first_completed([error_task(), slow_task()]) as (
tasks,
done,
pending,
):
assert len(done) == 1
completed_task = done.pop()
with pytest.raises(ValueError, match="error"):
completed_task.result()
114 changes: 114 additions & 0 deletions gcsfs/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2367,3 +2367,117 @@ def test_walk(gcs):
exp_dirs, exp_files = expected_structure[root]
assert set(d_list) == exp_dirs
assert set(f_list) == exp_files


@pytest.mark.asyncio
@pytest.mark.parametrize(
"object_behavior, dir_behavior, expected",
[
(
{"return": {"name": TEST_BUCKET + "/file", "type": "file", "size": 100}},
{"exception": FileNotFoundError},
{"return": {"type": "file"}},
),
(
{"exception": FileNotFoundError},
{"return": {"name": TEST_BUCKET + "/file", "type": "directory", "size": 0}},
{"return": {"type": "directory"}},
),
(
{
"return": {
"name": TEST_BUCKET + "/file/",
"type": "directory",
"size": 0,
}
},
{
"return": {
"name": TEST_BUCKET + "/file",
"type": "directory",
"size": 0,
"extra": "info",
}
},
{"return": {"type": "directory", "extra": "info"}},
),
(
{"exception": Exception("Generic error")},
{"exception": FileNotFoundError},
{"exception": Exception, "match": "Generic error"},
),
(
{"exception": FileNotFoundError},
{"exception": Exception("Directory error")},
{"exception": Exception, "match": "Directory error"},
),
(
{"exception": FileNotFoundError},
{"exception": FileNotFoundError},
{"exception": FileNotFoundError},
),
],
)
async def test_info_parallel(gcs, object_behavior, dir_behavior, expected):
path = TEST_BUCKET + "/file"

with (
mock.patch.object(
gcs, "_get_object", new_callable=mock.AsyncMock
) as mock_get_object,
mock.patch.object(
gcs, "_get_directory_info", new_callable=mock.AsyncMock
) as mock_get_dir,
):

if "return" in object_behavior:
mock_get_object.return_value = object_behavior["return"]
elif "exception" in object_behavior:
mock_get_object.side_effect = object_behavior["exception"]

if "return" in dir_behavior:
mock_get_dir.return_value = dir_behavior["return"]
elif "exception" in dir_behavior:
mock_get_dir.side_effect = dir_behavior["exception"]

if "exception" in expected:
with pytest.raises(expected["exception"], match=expected.get("match")):
await gcs._info(path)
else:
res = await gcs._info(path)
for k, v in expected["return"].items():
assert res[k] == v

assert mock_get_object.call_count == 1
assert mock_get_dir.call_count == 1


@pytest.mark.asyncio
async def test_info_parallel_dir_first(gcs):
import asyncio

path = TEST_BUCKET + "/dir"

with (
mock.patch.object(
gcs, "_get_object", new_callable=mock.AsyncMock
) as mock_get_object,
mock.patch.object(
gcs, "_get_directory_info", new_callable=mock.AsyncMock
) as mock_get_dir,
):

# Make _get_object slower than _get_directory_info
async def slow_get_object(*args, **kwargs):
await asyncio.sleep(0.1)
return {"name": path, "type": "file", "size": 100}

mock_get_object.side_effect = slow_get_object
# Directory check finishes immediately and succeeds
mock_get_dir.return_value = {"name": path, "type": "directory", "size": 0}

res = await gcs._info(path)
assert res["type"] == "file"

assert mock_get_object.call_count == 1
assert mock_get_dir.call_count == 1
Loading