Skip to content

Commit 8d00585

Browse files
authored
Introduce reset_local_changes() (#189)
1 parent 6b5dc91 commit 8d00585

File tree

3 files changed

+344
-67
lines changed

3 files changed

+344
-67
lines changed

mergin/client.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import ssl
1515
from enum import Enum, auto
1616
import re
17+
import typing
1718
import warnings
1819

1920
from .common import ClientError, LoginError, InvalidProject
@@ -22,6 +23,8 @@
2223
download_file_finalize,
2324
download_project_async,
2425
download_file_async,
26+
download_files_async,
27+
download_files_finalize,
2528
download_diffs_async,
2629
download_project_finalize,
2730
download_project_wait,
@@ -1127,3 +1130,63 @@ def has_writing_permissions(self, project_path):
11271130
"""
11281131
info = self.project_info(project_path)
11291132
return info["permissions"]["upload"]
1133+
1134+
def reset_local_changes(self, directory: str, files_to_reset: typing.List[str] = None) -> None:
1135+
"""
1136+
Reset local changes to either all files or only listed files.
1137+
Added files are removed, removed files are brought back and updates are discarded.
1138+
1139+
:param directory: Project's directory
1140+
:type directory: String
1141+
:param files_to_reset List of files to reset, relative paths of file
1142+
:type files_to_reset: List of strings, default None
1143+
"""
1144+
all_files = files_to_reset is None
1145+
1146+
mp = MerginProject(directory)
1147+
1148+
current_version = mp.version()
1149+
1150+
push_changes = mp.get_push_changes()
1151+
1152+
files_download = []
1153+
1154+
# remove all added files
1155+
for file in push_changes["added"]:
1156+
if all_files or file["path"] in files_to_reset:
1157+
os.remove(mp.fpath(file["path"]))
1158+
1159+
# update files get override with previous version
1160+
for file in push_changes["updated"]:
1161+
if all_files or file["path"] in files_to_reset:
1162+
if mp.is_versioned_file(file["path"]):
1163+
mp.geodiff.make_copy_sqlite(mp.fpath_meta(file["path"]), mp.fpath(file["path"]))
1164+
else:
1165+
files_download.append(file["path"])
1166+
1167+
# removed files are redownloaded
1168+
for file in push_changes["removed"]:
1169+
if all_files or file["path"] in files_to_reset:
1170+
files_download.append(file["path"])
1171+
1172+
if files_download:
1173+
self.download_files(directory, files_download, version=current_version)
1174+
1175+
def download_files(
1176+
self, project_dir: str, file_paths: typing.List[str], output_paths: typing.List[str] = None, version: str = None
1177+
):
1178+
"""
1179+
Download project files at specified version. Get the latest if no version specified.
1180+
1181+
:param project_dir: project local directory
1182+
:type project_dir: String
1183+
:param file_path: List of relative paths of files to download in the project directory
1184+
:type file_path: List[String]
1185+
:param output_paths: List of paths for files to download to. Should be same length of as file_path. Default is `None` which means that files are downloaded into MerginProject at project_dir.
1186+
:type output_paths: List[String]
1187+
:param version: optional version tag for downloaded file
1188+
:type version: String
1189+
"""
1190+
job = download_files_async(self, project_dir, file_paths, output_paths, version=version)
1191+
pull_project_wait(job)
1192+
download_files_finalize(job)

mergin/client_pull.py

Lines changed: 103 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pprint
1616
import shutil
1717
import tempfile
18+
import typing
1819

1920
import concurrent.futures
2021

@@ -621,77 +622,14 @@ def download_file_async(mc, project_dir, file_path, output_file, version):
621622
Starts background download project file at specified version.
622623
Returns handle to the pending download.
623624
"""
624-
mp = MerginProject(project_dir)
625-
project_path = mp.project_full_name()
626-
ver_info = f"at version {version}" if version is not None else "at latest version"
627-
mp.log.info(f"Getting {file_path} {ver_info}")
628-
latest_proj_info = mc.project_info(project_path)
629-
if version:
630-
project_info = mc.project_info(project_path, version=version)
631-
else:
632-
project_info = latest_proj_info
633-
mp.log.info(f"Got project info. version {project_info['version']}")
634-
635-
# set temporary directory for download
636-
temp_dir = tempfile.mkdtemp(prefix="mergin-py-client-")
637-
638-
download_list = []
639-
update_tasks = []
640-
total_size = 0
641-
# None can not be used to indicate latest version of the file, so
642-
# it is necessary to pass actual version.
643-
if version is None:
644-
version = latest_proj_info["version"]
645-
for file in project_info["files"]:
646-
if file["path"] == file_path:
647-
file["version"] = version
648-
items = _download_items(file, temp_dir)
649-
is_latest_version = version == latest_proj_info["version"]
650-
task = UpdateTask(file["path"], items, output_file, latest_version=is_latest_version)
651-
download_list.extend(task.download_queue_items)
652-
for item in task.download_queue_items:
653-
total_size += item.size
654-
update_tasks.append(task)
655-
break
656-
if not download_list:
657-
warn = f"No {file_path} exists at version {version}"
658-
mp.log.warning(warn)
659-
shutil.rmtree(temp_dir)
660-
raise ClientError(warn)
661-
662-
mp.log.info(f"will download file {file_path} in {len(download_list)} chunks, total size {total_size}")
663-
job = DownloadJob(project_path, total_size, version, update_tasks, download_list, temp_dir, mp, project_info)
664-
job.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
665-
job.futures = []
666-
for item in download_list:
667-
future = job.executor.submit(_do_download, item, mc, mp, project_path, job)
668-
job.futures.append(future)
669-
670-
return job
625+
return download_files_async(mc, project_dir, [file_path], [output_file], version)
671626

672627

673628
def download_file_finalize(job):
674629
"""
675630
To be called when download_file_async is finished
676631
"""
677-
job.executor.shutdown(wait=True)
678-
679-
# make sure any exceptions from threads are not lost
680-
for future in job.futures:
681-
if future.exception() is not None:
682-
raise future.exception()
683-
684-
job.mp.log.info("--- download finished")
685-
686-
temp_dir = None
687-
for task in job.update_tasks:
688-
task.apply(job.directory, job.mp)
689-
if task.download_queue_items:
690-
temp_dir = os.path.dirname(task.download_queue_items[0].download_file_path)
691-
692-
# Remove temporary download directory
693-
if temp_dir is not None:
694-
shutil.rmtree(temp_dir)
632+
download_files_finalize(job)
695633

696634

697635
def download_diffs_async(mc, project_directory, file_path, versions):
@@ -804,3 +742,103 @@ def download_diffs_finalize(job):
804742

805743
job.mp.log.info("--- diffs pull finished")
806744
return diffs
745+
746+
747+
def download_files_async(
748+
mc, project_dir: str, file_paths: typing.List[str], output_paths: typing.List[str], version: str
749+
):
750+
"""
751+
Starts background download project files at specified version.
752+
Returns handle to the pending download.
753+
"""
754+
mp = MerginProject(project_dir)
755+
project_path = mp.project_full_name()
756+
ver_info = f"at version {version}" if version is not None else "at latest version"
757+
mp.log.info(f"Getting [{', '.join(file_paths)}] {ver_info}")
758+
latest_proj_info = mc.project_info(project_path)
759+
if version:
760+
project_info = mc.project_info(project_path, version=version)
761+
else:
762+
project_info = latest_proj_info
763+
mp.log.info(f"Got project info. version {project_info['version']}")
764+
765+
# set temporary directory for download
766+
temp_dir = tempfile.mkdtemp(prefix="mergin-py-client-")
767+
768+
if output_paths is None:
769+
output_paths = []
770+
for file in file_paths:
771+
output_paths.append(mp.fpath(file))
772+
773+
if len(output_paths) != len(file_paths):
774+
warn = "Output file paths are not of the same length as file paths. Cannot store required files."
775+
mp.log.warning(warn)
776+
shutil.rmtree(temp_dir)
777+
raise ClientError(warn)
778+
779+
download_list = []
780+
update_tasks = []
781+
total_size = 0
782+
# None can not be used to indicate latest version of the file, so
783+
# it is necessary to pass actual version.
784+
if version is None:
785+
version = latest_proj_info["version"]
786+
for file in project_info["files"]:
787+
if file["path"] in file_paths:
788+
index = file_paths.index(file["path"])
789+
file["version"] = version
790+
items = _download_items(file, temp_dir)
791+
is_latest_version = version == latest_proj_info["version"]
792+
task = UpdateTask(file["path"], items, output_paths[index], latest_version=is_latest_version)
793+
download_list.extend(task.download_queue_items)
794+
for item in task.download_queue_items:
795+
total_size += item.size
796+
update_tasks.append(task)
797+
798+
missing_files = []
799+
files_to_download = []
800+
project_file_paths = [file["path"] for file in project_info["files"]]
801+
for file in file_paths:
802+
if file not in project_file_paths:
803+
missing_files.append(file)
804+
else:
805+
files_to_download.append(file)
806+
807+
if not download_list or missing_files:
808+
warn = f"No [{', '.join(missing_files)}] exists at version {version}"
809+
mp.log.warning(warn)
810+
shutil.rmtree(temp_dir)
811+
raise ClientError(warn)
812+
813+
mp.log.info(
814+
f"will download files [{', '.join(files_to_download)}] in {len(download_list)} chunks, total size {total_size}"
815+
)
816+
job = DownloadJob(project_path, total_size, version, update_tasks, download_list, temp_dir, mp, project_info)
817+
job.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
818+
job.futures = []
819+
for item in download_list:
820+
future = job.executor.submit(_do_download, item, mc, mp, project_path, job)
821+
job.futures.append(future)
822+
823+
return job
824+
825+
826+
def download_files_finalize(job):
827+
"""
828+
To be called when download_file_async is finished
829+
"""
830+
job.executor.shutdown(wait=True)
831+
832+
# make sure any exceptions from threads are not lost
833+
for future in job.futures:
834+
if future.exception() is not None:
835+
raise future.exception()
836+
837+
job.mp.log.info("--- download finished")
838+
839+
for task in job.update_tasks:
840+
task.apply(job.directory, job.mp)
841+
842+
# Remove temporary download directory
843+
if job.directory is not None and os.path.exists(job.directory):
844+
shutil.rmtree(job.directory)

0 commit comments

Comments
 (0)