Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 31 additions & 27 deletions zstash/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
import os.path
import sqlite3
import sys
from typing import Any, List, Tuple
from typing import Any, List, Optional, Tuple

from six.moves.urllib.parse import urlparse

from .globus import globus_activate, globus_finalize
from .hpss import hpss_put
from .hpss_utils import add_files
from .settings import DEFAULT_CACHE, config, get_db_filename, logger
from .transfer_tracking import GlobusTransferCollection, HPSSTransferCollection
from .utils import (
create_tars_table,
get_files_to_archive,
Expand Down Expand Up @@ -52,12 +53,13 @@ def create():
logger.error(input_path_error_str)
raise NotADirectoryError(input_path_error_str)

gtc: Optional[GlobusTransferCollection] = None
if hpss != "none":
url = urlparse(hpss)
if url.scheme == "globus":
# identify globus endpoints
logger.debug(f"{ts_utc()}:Calling globus_activate(hpss)")
globus_activate(hpss)
logger.debug(f"{ts_utc()}:Calling globus_activate()")
gtc = globus_activate(hpss)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This uses a new object (a `GlobusTransferCollection`, returned by `globus_activate()`) to hold state data, rather than relying on global variables.

else:
# config.hpss is not "none", so we need to
# create target HPSS directory
Expand Down Expand Up @@ -88,14 +90,23 @@ def create():

# Create and set up the database
logger.debug(f"{ts_utc()}: Calling create_database()")
failures: List[str] = create_database(cache, args)
htc: HPSSTransferCollection = HPSSTransferCollection()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same idea as above: use an object to store state data. Note that this is an `HPSSTransferCollection`; don't confuse it with `GlobusTransferCollection`.

failures: List[str] = create_database(cache, args, gtc=gtc, htc=htc)

# Transfer to HPSS. Always keep a local copy.
logger.debug(f"{ts_utc()}: calling hpss_put() for {get_db_filename(cache)}")
hpss_put(hpss, get_db_filename(cache), cache, keep=args.keep, is_index=True)
hpss_put(
hpss,
get_db_filename(cache),
cache,
keep=args.keep,
is_index=True,
gtc=gtc,
# htc=htc, # Don't track index.db for deletion
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to track index.db because we never delete it.

)

logger.debug(f"{ts_utc()}: calling globus_finalize()")
globus_finalize(non_blocking=args.non_blocking)
globus_finalize(gtc, htc, non_blocking=args.non_blocking)

if len(failures) > 0:
# List the failures
Expand Down Expand Up @@ -204,7 +215,12 @@ def setup_create() -> Tuple[str, argparse.Namespace]:
return cache, args


def create_database(cache: str, args: argparse.Namespace) -> List[str]:
def create_database(
cache: str,
args: argparse.Namespace,
gtc: Optional[GlobusTransferCollection],
htc: Optional[HPSSTransferCollection],
) -> List[str]:
# Create new database
logger.debug(f"{ts_utc()}:Creating index database")
if os.path.exists(get_db_filename(cache)):
Expand Down Expand Up @@ -263,26 +279,7 @@ def create_database(cache: str, args: argparse.Namespace) -> List[str]:
files: List[str] = get_files_to_archive(cache, args.include, args.exclude)

failures: List[str]
if args.follow_symlinks:
try:
# Add files to archive
failures = add_files(
cur,
con,
-1,
files,
cache,
args.keep,
args.follow_symlinks,
skip_tars_md5=args.no_tars_md5,
non_blocking=args.non_blocking,
error_on_duplicate_tar=args.error_on_duplicate_tar,
overwrite_duplicate_tars=args.overwrite_duplicate_tars,
force_database_corruption=args.for_developers_force_database_corruption,
)
except FileNotFoundError:
raise Exception("Archive creation failed due to broken symlink.")
else:
try:
# Add files to archive
failures = add_files(
cur,
Expand All @@ -297,7 +294,14 @@ def create_database(cache: str, args: argparse.Namespace) -> List[str]:
error_on_duplicate_tar=args.error_on_duplicate_tar,
overwrite_duplicate_tars=args.overwrite_duplicate_tars,
force_database_corruption=args.for_developers_force_database_corruption,
gtc=gtc,
htc=htc,
)
except FileNotFoundError as e:
if args.follow_symlinks:
raise Exception("Archive creation failed due to broken symlink.")
Comment on lines +301 to +302
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This just makes the code easier to read and avoids writing out an identical function call twice (once in each branch of the if/else block).

else:
raise e

# Close database
con.commit()
Expand Down
Loading
Loading