|
5 | 5 | import time |
6 | 6 | import zipfile |
7 | 7 | from collections.abc import Callable, Coroutine |
| 8 | +from contextlib import AsyncExitStack |
8 | 9 | from io import IOBase |
9 | 10 | from pathlib import Path |
10 | 11 | from typing import Any, Final, TypedDict, cast |
@@ -163,31 +164,50 @@ async def pull_file_from_remote( |
163 | 164 | storage_kwargs: S3FsSettingsDict | dict[str, Any] = {} |
164 | 165 | if s3_settings and src_url.scheme in S3_FILE_SYSTEM_SCHEMES: |
165 | 166 | storage_kwargs = _s3fs_settings_from_s3_settings(s3_settings) |
166 | | - await _copy_file( |
167 | | - src_url, |
168 | | - TypeAdapter(FileUrl).validate_python(dst_path.as_uri()), |
169 | | - src_storage_cfg=cast(dict[str, Any], storage_kwargs), |
170 | | - log_publishing_cb=log_publishing_cb, |
171 | | - text_prefix=f"Downloading '{src_url.path.strip('/')}':", |
172 | | - ) |
173 | 167 |
|
174 | | - await log_publishing_cb( |
175 | | - f"Download of '{src_url}' into local file '{dst_path}' complete.", |
176 | | - logging.INFO, |
| 168 | + need_extraction = (src_mime_type == _ZIP_MIME_TYPE) and ( |
| 169 | + target_mime_type != _ZIP_MIME_TYPE |
177 | 170 | ) |
178 | | - |
179 | | - if src_mime_type == _ZIP_MIME_TYPE and target_mime_type != _ZIP_MIME_TYPE: |
180 | | - await log_publishing_cb(f"Uncompressing '{dst_path.name}'...", logging.INFO) |
181 | | - logger.debug("%s is a zip file and will be now uncompressed", dst_path) |
182 | | - with repro_zipfile.ReproducibleZipFile(dst_path, "r") as zip_obj: |
183 | | - await asyncio.get_event_loop().run_in_executor( |
184 | | - None, zip_obj.extractall, dst_path.parents[0] |
| 171 | + async with AsyncExitStack() as exit_stack: |
| 172 | + if need_extraction: |
| 173 | + # we need to extract the file, so we create a temporary directory |
| 174 | + # where the file will be downloaded and extracted |
| 175 | + tmp_dir = await exit_stack.enter_async_context( |
| 176 | + aiofiles.tempfile.TemporaryDirectory() |
185 | 177 | ) |
186 | | - # finally remove the zip archive |
| 178 | + download_dst_path = Path(f"{tmp_dir}") / Path(src_url.path).name |
| 179 | + else: |
| 180 | + # no extraction needed, so we can use the provided dst_path directly |
| 181 | + download_dst_path = dst_path |
| 182 | + |
| 183 | + await _copy_file( |
| 184 | + src_url, |
| 185 | + TypeAdapter(FileUrl).validate_python(f"{download_dst_path.as_uri()}"), |
| 186 | + src_storage_cfg=cast(dict[str, Any], storage_kwargs), |
| 187 | + log_publishing_cb=log_publishing_cb, |
| 188 | + text_prefix=f"Downloading '{src_url.path.strip('/')}':", |
| 189 | + ) |
| 190 | + |
187 | 191 | await log_publishing_cb( |
188 | | - f"Uncompressing '{dst_path.name}' complete.", logging.INFO |
| 192 | + f"Download of '{src_url}' into local file '{download_dst_path}' complete.", |
| 193 | + logging.INFO, |
189 | 194 | ) |
190 | | - dst_path.unlink() |
| 195 | + |
| 196 | + if need_extraction: |
| 197 | + await log_publishing_cb( |
| 198 | + f"Uncompressing '{download_dst_path.name}'...", logging.INFO |
| 199 | + ) |
| 200 | + logger.debug( |
| 201 | + "%s is a zip file and will be now uncompressed", download_dst_path |
| 202 | + ) |
| 203 | + with repro_zipfile.ReproducibleZipFile(download_dst_path, "r") as zip_obj: |
| 204 | + await asyncio.get_event_loop().run_in_executor( |
| 205 | + None, zip_obj.extractall, dst_path.parents[0] |
| 206 | + ) |
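| 207 | + # no explicit removal here: the zip archive was downloaded into the temporary directory, which is expected to be cleaned up when the exit stack closes |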
| 208 | + await log_publishing_cb( |
| 209 | + f"Uncompressing '{download_dst_path.name}' complete.", logging.INFO |
| 210 | + ) |
191 | 211 |
|
192 | 212 |
|
193 | 213 | async def _push_file_to_http_link( |
|
0 commit comments