diff --git a/dvc/data_cloud.py b/dvc/data_cloud.py index 12d40317a7..9eeead819b 100644 --- a/dvc/data_cloud.py +++ b/dvc/data_cloud.py @@ -1,10 +1,12 @@ """Manages dvc remotes that user can use with push/pull/status commands.""" +import json from collections.abc import Iterable -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, ClassVar, Optional from dvc.config import NoRemoteError, RemoteConfigError from dvc.log import logger +from dvc.utils import bytes_hash from dvc.utils.objects import cached_property from dvc_data.hashfile.db import get_index from dvc_data.hashfile.transfer import TransferResult @@ -19,15 +21,69 @@ class Remote: + _CACHE: ClassVar[dict[tuple[str, ...], "FileSystem"]] = {} + _NAME_TO_KEY: ClassVar[dict[str, tuple[str, ...]]] = {} + def __init__(self, name: str, path: str, fs: "FileSystem", *, index=None, **config): self.path = path self.fs = fs self.name = name self.index = index - self.worktree: bool = config.pop("worktree", False) self.config = config + self._ensure_cached_fs() + + @classmethod + def _fs_cache_key( + cls, + remote_name: str, + fs_cls, + config: dict, + fs_path: str, + ) -> tuple[str, ...]: + serialized_config = json.dumps(config, sort_keys=True, default=str) + config_hash = bytes_hash(serialized_config.encode("utf-8"), "sha256") + return ( + remote_name.lower(), + fs_cls.__module__ or "", + getattr(fs_cls, "__qualname__", getattr(fs_cls, "__name__", "")), + fs_path, + config_hash, + ) + + def _close_fs(self) -> None: + close = getattr(self.fs, "close", None) + if callable(close): + close() + + def _ensure_cached_fs(self) -> None: + cls = type(self) + cache_key = cls._fs_cache_key( + self.name, + type(self.fs), + self.config, + self.path, + ) + + cached_fs = cls._CACHE.get(cache_key) + if cached_fs is not None and cached_fs is not self.fs: + self._close_fs() + self.fs = cached_fs + else: + cls._CACHE[cache_key] = self.fs + + name_key = self.name.lower() + prev_key = cls._NAME_TO_KEY.get(name_key) + if prev_key and prev_key != cache_key: + prev_fs = cls._CACHE.pop(prev_key, None) + if prev_fs is not None and prev_fs is not self.fs: + close = getattr(prev_fs, "close", None) + if callable(close): + close() + + cls._NAME_TO_KEY[name_key] = cache_key + @cached_property def odb(self) -> "HashFileDB": from dvc.cachemgr import CacheManager @@ -100,13 +156,20 @@ def get_remote( if version_aware is None: config["version_aware"] = True - fs = cls(**config) - config["tmp_dir"] = self.repo.site_cache_dir + fs_config = dict(config) + fs = cls(**fs_config) + runtime_config = {**fs_config, "tmp_dir": self.repo.site_cache_dir} if self.repo.data_index is not None: index = self.repo.data_index.view(("remote", name)) else: index = None - return Remote(name, fs_path, fs, index=index, **config) + return Remote( + name, + fs_path, + fs, + index=index, + **runtime_config, + ) if bool(self.repo.config["remote"]): error_msg = (