1313"""
1414
1515import copy as copy
16+ import functools
17+ import hashlib
18+ import importlib .metadata
1619import itertools as it
17- import os
20+ import subprocess
21+ import tomllib
1822import warnings
23+ from pathlib import Path
1924from typing import Any
2025
2126import numpy as np
2227import yaml
2328
2429from spras .config .container_schema import ProcessedContainerSettings
25- from spras .config .schema import RawConfig
26- from spras .util import NpHashEncoder , hash_params_sha1_base32
30+ from spras .config .schema import DatasetSchema , RawConfig
31+ from spras .util import LoosePathLike , NpHashEncoder , hash_params_sha1_base32
2732
2833config = None
2934
@functools.cache
def spras_revision() -> str:
    """
    Get the revision identifier of the current SPRAS installation.

    This function is meant to be user-friendly and to warn about bad SPRAS installs.
    Resolution order:
    1. If this file is inside the correct `.git` repository, use the short git revision
       hash. This covers development in SPRAS as well as installs via a cloned git repository.
    2. If SPRAS was installed via a PyPA-compliant package manager, use a truncated hash of
       the RECORD file (https://packaging.python.org/en/latest/specifications/recording-installed-packages/#the-record-file),
       which contains the hashes of all files installed by the package.

    Returns:
        str: an 8-character revision identifier.

    Raises:
        RuntimeError: if neither strategy identifies a valid SPRAS install; the
            message points at the installation instructions.
    """
    clone_tip = "Make sure SPRAS is installed through the installation instructions: https://spras.readthedocs.io/en/latest/install.html."

    # Check if we're inside the right git repository
    try:
        project_directory = subprocess.check_output(
            ["git", "rev-parse", "--show-toplevel"],
            encoding='utf-8',
            # In case the CWD is not inside the actual SPRAS directory
            cwd=Path(__file__).parent.resolve()
        ).strip()

        # We check the pyproject.toml name attribute to confirm that this is the SPRAS project. This is susceptible
        # to false negatives, but we use this as a preliminary check against bad SPRAS installs.
        pyproject_path = Path(project_directory, 'pyproject.toml')
        try:
            pyproject_toml = tomllib.loads(pyproject_path.read_text())
            if "project" not in pyproject_toml or "name" not in pyproject_toml["project"]:
                raise RuntimeError(f"The git top-level `{pyproject_path}` does not have the expected attributes. {clone_tip}")
            if pyproject_toml["project"]["name"] != "spras":
                raise RuntimeError(f"The git top-level `{pyproject_path}` is not the SPRAS pyproject.toml. {clone_tip}")

            return subprocess.check_output(
                ["git", "rev-parse", "--short", "HEAD"],
                encoding='utf-8',
                cwd=project_directory
            ).strip()
        except FileNotFoundError as err:
            # pyproject.toml wasn't found during the `read_text` call
            raise RuntimeError(f"The git top-level {pyproject_path} wasn't found. {clone_tip}") from err
        except tomllib.TOMLDecodeError as err:
            raise RuntimeError(f"The git top-level {pyproject_path} is malformed. {clone_tip}") from err
    # FileNotFoundError here covers the case where the `git` executable itself is not
    # installed: check_output raises it before any git process runs. Previously only a
    # failing `git` call fell through to the package-metadata fallback, so a machine
    # without git crashed instead of using the RECORD hash. (The pyproject.toml
    # FileNotFoundError cannot reach this handler; it is converted to RuntimeError above.)
    except (subprocess.CalledProcessError, FileNotFoundError):
        try:
            # `git` failed: use the truncated hash of the RECORD file in .dist-info instead.
            record_path = str(importlib.metadata.distribution('spras').locate_file(f"spras-{importlib.metadata.version('spras')}.dist-info/RECORD"))
            with open(record_path, 'rb', buffering=0) as f:
                # Truncated to the magic value 8, the length of the short git revision.
                return hashlib.file_digest(f, 'sha256').hexdigest()[:8]
        # FileNotFoundError added: a distribution whose RECORD file is missing is just as
        # much a broken install as a missing distribution, and previously escaped uncaught.
        except (importlib.metadata.PackageNotFoundError, FileNotFoundError) as err:
            # The metadata.distribution call failed or RECORD is absent.
            raise RuntimeError(f"The spras package wasn't found: {clone_tip}") from err
84+
def attach_spras_revision(label: str) -> str:
    """Return *label* suffixed with the current SPRAS revision, joined by an underscore."""
    revision = spras_revision()
    return "_".join((label, revision))
87+
def init_global(config_dict):
    """Instantiate the module-level Config singleton from an already-parsed config mapping.

    This is called in the Snakefile with the raw configuration dictionary.
    """
    global config
    config = Config(config_dict)
3492
def init_from_file(filepath):
    """Initialize the module-level Config singleton from a YAML configuration file.

    File reading, YAML parsing, and error reporting are delegated to
    ``Config.from_file``.
    """
    global config
    config = Config.from_file(filepath)
5096
5197
5298class Config :
@@ -64,7 +110,7 @@ def __init__(self, raw_config: dict[str, Any]):
64110 # Directory used for storing output
65111 self .out_dir = parsed_raw_config .reconstruction_settings .locations .reconstruction_dir
66112 # A dictionary to store configured datasets against which SPRAS will be run
67- self .datasets = None
113+ self .datasets : dict [ str , DatasetSchema ] = {}
68114 # A dictionary to store configured gold standard data against output of SPRAS runs
69115 self .gold_standards = None
70116 # The hash length SPRAS will use to identify parameter combinations.
@@ -103,6 +149,20 @@ def __init__(self, raw_config: dict[str, Any]):
103149
104150 self .process_config (parsed_raw_config )
105151
152+ @classmethod
153+ def from_file (cls , filepath : LoosePathLike ):
154+ # Handle opening the file and parsing the yaml
155+ filepath = Path (filepath ).absolute ()
156+ try :
157+ with open (filepath , 'r' ) as yaml_file :
158+ config_dict = yaml .safe_load (yaml_file )
159+ except FileNotFoundError as e :
160+ raise RuntimeError (f"Error: The specified config '{ filepath } ' could not be found." ) from e
161+ except yaml .YAMLError as e :
162+ raise RuntimeError (f"Error: Failed to parse config '{ filepath } '" ) from e
163+
164+ return cls (config_dict )
165+
106166 def process_datasets (self , raw_config : RawConfig ):
107167 """
108168 Parse dataset information
@@ -115,12 +175,17 @@ def process_datasets(self, raw_config: RawConfig):
115175 # Currently assumes all datasets have a label and the labels are unique
116176 # When Snakemake parses the config file it loads the datasets as OrderedDicts not dicts
117177 # Convert to dicts to simplify the yaml logging
118- self .datasets = {}
178+
179+ for dataset in raw_config .datasets :
180+ dataset .label = attach_spras_revision (dataset .label )
181+ for gold_standard in raw_config .gold_standards :
182+ gold_standard .label = attach_spras_revision (gold_standard .label )
183+
119184 for dataset in raw_config .datasets :
120185 label = dataset .label
121186 if label .lower () in [key .lower () for key in self .datasets .keys ()]:
122187 raise ValueError (f"Datasets must have unique case-insensitive labels, but the label { label } appears at least twice." )
123- self .datasets [label ] = dict ( dataset )
188+ self .datasets [label ] = dataset
124189
125190 # parse gold standard information
126191 self .gold_standards = {gold_standard .label : dict (gold_standard ) for gold_standard in raw_config .gold_standards }
@@ -129,8 +194,11 @@ def process_datasets(self, raw_config: RawConfig):
129194 dataset_labels = set (self .datasets .keys ())
130195 gold_standard_dataset_labels = {dataset_label for value in self .gold_standards .values () for dataset_label in value ['dataset_labels' ]}
131196 for label in gold_standard_dataset_labels :
132- if label not in dataset_labels :
197+ if attach_spras_revision ( label ) not in dataset_labels :
133198 raise ValueError (f"Dataset label '{ label } ' provided in gold standards does not exist in the existing dataset labels." )
199+ # We attach the SPRAS revision to the individual dataset labels afterwards for a cleaner error message above.
200+ for key , gold_standard in self .gold_standards .items ():
201+ self .gold_standards [key ]["dataset_labels" ] = map (attach_spras_revision , gold_standard ["dataset_labels" ])
134202
135203 # Code snipped from Snakefile that may be useful for assigning default labels
136204 # dataset_labels = [dataset.get('label', f'dataset{index}') for index, dataset in enumerate(datasets)]
@@ -186,7 +254,10 @@ def process_algorithms(self, raw_config: RawConfig):
186254 run_dict [param ] = float (value )
187255 if isinstance (value , np .ndarray ):
188256 run_dict [param ] = value .tolist ()
189- params_hash = hash_params_sha1_base32 (run_dict , self .hash_length , cls = NpHashEncoder )
257+ # Incorporates the `spras_revision` into the hash
258+ hash_run_dict = copy .deepcopy (run_dict )
259+ hash_run_dict ["_spras_rev" ] = spras_revision ()
260+ params_hash = hash_params_sha1_base32 (hash_run_dict , self .hash_length , cls = NpHashEncoder )
190261 if params_hash in prior_params_hashes :
191262 raise ValueError (f'Parameter hash collision detected. Increase the hash_length in the config file '
192263 f'(current length { self .hash_length } ).' )
0 commit comments