33import asyncio
44from base64 import b64decode , b64encode
55from collections import defaultdict
6- from collections .abc import Awaitable
6+ from collections .abc import AsyncGenerator , Awaitable
7+ from contextlib import asynccontextmanager
78from copy import deepcopy
89from datetime import timedelta
910from functools import cached_property
1213import logging
1314from pathlib import Path
1415import tarfile
16+ from tarfile import TarFile
1517from tempfile import TemporaryDirectory
1618import time
1719from typing import Any , Self
5658from ..utils import remove_folder
5759from ..utils .dt import parse_datetime , utcnow
5860from ..utils .json import json_bytes
61+ from ..utils .sentinel import DEFAULT
5962from .const import BUF_SIZE , LOCATION_CLOUD_BACKUP , BackupType
6063from .utils import key_to_iv , password_to_key
6164from .validate import SCHEMA_BACKUP
@@ -86,7 +89,6 @@ def __init__(
8689 self ._data : dict [str , Any ] = data or {ATTR_SLUG : slug }
8790 self ._tmp = None
8891 self ._outer_secure_tarfile : SecureTarFile | None = None
89- self ._outer_secure_tarfile_tarfile : tarfile .TarFile | None = None
9092 self ._key : bytes | None = None
9193 self ._aes : Cipher | None = None
9294 self ._locations : dict [str | None , Path ] = {location : tar_file }
@@ -375,59 +377,68 @@ def _load_file():
375377
376378 return True
377379
378- async def __aenter__ (self ):
379- """Async context to open a backup."""
380+ @asynccontextmanager
381+ async def create (self ) -> AsyncGenerator [None ]:
382+ """Create new backup file."""
383+ if self .tarfile .is_file ():
384+ raise BackupError (
385+ f"Cannot make new backup at { self .tarfile .as_posix ()} , file already exists!" ,
386+ _LOGGER .error ,
387+ )
380388
381- # create a backup
382- if not self .tarfile .is_file ():
383- self ._outer_secure_tarfile = SecureTarFile (
384- self .tarfile ,
385- "w" ,
386- gzip = False ,
387- bufsize = BUF_SIZE ,
389+ self ._outer_secure_tarfile = SecureTarFile (
390+ self .tarfile ,
391+ "w" ,
392+ gzip = False ,
393+ bufsize = BUF_SIZE ,
394+ )
395+ try :
396+ with self ._outer_secure_tarfile as outer_tarfile :
397+ yield
398+ await self ._create_cleanup (outer_tarfile )
399+ finally :
400+ self ._outer_secure_tarfile = None
401+
402+ @asynccontextmanager
403+ async def open (self , location : str | None | type [DEFAULT ]) -> AsyncGenerator [None ]:
404+ """Open backup for restore."""
405+ if location != DEFAULT and location not in self .all_locations :
406+ raise BackupError (
407+ f"Backup { self .slug } does not exist in location { location } " ,
408+ _LOGGER .error ,
409+ )
410+
411+ backup_tarfile = (
412+ self .tarfile if location == DEFAULT else self .all_locations [location ]
413+ )
414+ if not backup_tarfile .is_file ():
415+ raise BackupError (
416+ f"Cannot open backup at { backup_tarfile .as_posix ()} , file does not exist!" ,
417+ _LOGGER .error ,
388418 )
389- self ._outer_secure_tarfile_tarfile = self ._outer_secure_tarfile .__enter__ ()
390- return
391419
392420 # extract an existing backup
393- self ._tmp = TemporaryDirectory (dir = str (self . tarfile .parent ))
421+ self ._tmp = TemporaryDirectory (dir = str (backup_tarfile .parent ))
394422
395423 def _extract_backup ():
396424 """Extract a backup."""
397- with tarfile .open (self . tarfile , "r:" ) as tar :
425+ with tarfile .open (backup_tarfile , "r:" ) as tar :
398426 tar .extractall (
399427 path = self ._tmp .name ,
400428 members = secure_path (tar ),
401429 filter = "fully_trusted" ,
402430 )
403431
404- await self .sys_run_in_executor (_extract_backup )
405-
406- async def __aexit__ (self , exception_type , exception_value , traceback ):
407- """Async context to close a backup."""
408- # exists backup or exception on build
409- try :
410- await self ._aexit (exception_type , exception_value , traceback )
411- finally :
412- if self ._tmp :
413- self ._tmp .cleanup ()
414- if self ._outer_secure_tarfile :
415- self ._outer_secure_tarfile .__exit__ (
416- exception_type , exception_value , traceback
417- )
418- self ._outer_secure_tarfile = None
419- self ._outer_secure_tarfile_tarfile = None
432+ with self ._tmp :
433+ await self .sys_run_in_executor (_extract_backup )
434+ yield
420435
421- async def _aexit (self , exception_type , exception_value , traceback ) :
436+ async def _create_cleanup (self , outer_tarfile : TarFile ) -> None :
422437 """Cleanup after backup creation.
423438
424- This is a separate method to allow it to be called from __aexit__ to ensure
439+ Separate method to be called from create to ensure
425440 that cleanup is always performed, even if an exception is raised.
426441 """
427- # If we're not creating a new backup, or if an exception was raised, we're done
428- if not self ._outer_secure_tarfile or exception_type is not None :
429- return
430-
431442 # validate data
432443 try :
433444 self ._data = SCHEMA_BACKUP (self ._data )
@@ -445,7 +456,7 @@ def _add_backup_json():
445456 tar_info = tarfile .TarInfo (name = "./backup.json" )
446457 tar_info .size = len (raw_bytes )
447458 tar_info .mtime = int (time .time ())
448- self . _outer_secure_tarfile_tarfile .addfile (tar_info , fileobj = fileobj )
459+ outer_tarfile .addfile (tar_info , fileobj = fileobj )
449460
450461 try :
451462 await self .sys_run_in_executor (_add_backup_json )
0 commit comments