
Commit 339d915

Merge branch 'hash' into lazy-stats
2 parents 79cf748 + 6b941e0 commit 339d915

78 files changed, +567 -177282 lines changed


Snakefile

Lines changed: 7 additions & 8 deletions
@@ -35,7 +35,6 @@ def get_dataset(_datasets, label):
 algorithms = list(algorithm_params)
 algorithms_with_params = [f'{algorithm}-params-{params_hash}' for algorithm, param_combos in algorithm_params.items() for params_hash in param_combos.keys()]
 dataset_labels = list(_config.config.datasets.keys())
-
 dataset_gold_standard_node_pairs = [f"{dataset}-{gs['label']}" for gs in _config.config.gold_standards.values() if gs['node_files'] for dataset in gs['dataset_labels']]
 dataset_gold_standard_edge_pairs = [f"{dataset}-{gs['label']}" for gs in _config.config.gold_standards.values() if gs['edge_files'] for dataset in gs['dataset_labels']]

@@ -62,10 +61,10 @@ def write_parameter_log(algorithm, param_label, logfile):
 def write_dataset_log(dataset, logfile):
     dataset_contents = get_dataset(_config.config.datasets,dataset)

-    # safe_dump gives RepresenterError for an OrderedDict
-    # config file has to convert the dataset from OrderedDict to dict to avoid this
-    with open(logfile,'w') as f:
-        yaml.safe_dump(dataset_contents,f)
+    # safe_dump gives RepresenterError for a DatasetSchema
+    # config file has to convert the dataset to a dict to avoid this
+    with open(logfile, 'w') as f:
+        yaml.safe_dump(dict(dataset_contents), f)

 # Choose the final files expected according to the config file options.
 def make_final_input(wildcards):

@@ -156,9 +155,9 @@ rule log_datasets:
 # Input preparation needs to be rerun if these files are modified
 def get_dataset_dependencies(wildcards):
     dataset = _config.config.datasets[wildcards.dataset]
-    all_files = dataset["node_files"] + dataset["edge_files"] + dataset["other_files"]
+    all_files = dataset.node_files + dataset.edge_files + dataset.other_files
     # Add the relative file path
-    all_files = [dataset["data_dir"] + SEP + data_file for data_file in all_files]
+    all_files = [dataset.data_dir + SEP + data_file for data_file in all_files]

     return all_files

@@ -283,7 +282,7 @@ rule reconstruct:
 # Original pathway reconstruction output to universal output
 # Use PRRunner as a wrapper to call the algorithm-specific parse_output
 rule parse_output:
-    input:
+    input:
         raw_file = SEP.join([out_dir, '{dataset}-{algorithm}-{params}', 'raw-pathway.txt']),
         dataset_file = SEP.join([out_dir, 'dataset-{dataset}-merged.pickle'])
     output: standardized_file = SEP.join([out_dir, '{dataset}-{algorithm}-{params}', 'pathway.txt'])
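The updated comment in `write_dataset_log` is the crux of the second hunk: `yaml.safe_dump` only knows how to represent plain Python types, so passing a model object directly raises a `RepresenterError`. A minimal sketch of that behavior, using a hypothetical `ExampleDataset` pydantic model in place of the real `DatasetSchema`:

    # Sketch only: assumes pydantic and PyYAML, as used elsewhere in this commit.
    import yaml
    from pydantic import BaseModel

    class ExampleDataset(BaseModel):  # hypothetical stand-in for DatasetSchema
        label: str
        data_dir: str

    ds = ExampleDataset(label="egfr", data_dir="input")

    try:
        yaml.safe_dump(ds)  # safe_dump cannot represent arbitrary objects
    except yaml.representer.RepresenterError as err:
        print(f"RepresenterError: {err}")

    # Converting the model to a dict first is what the updated write_dataset_log does.
    print(yaml.safe_dump(dict(ds)))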

docs/htcondor.rst

Lines changed: 3 additions & 2 deletions
@@ -54,10 +54,11 @@ might look like:
 
 .. code:: bash
 
-   apptainer build spras-v0.6.0.sif docker://reedcompbio/spras:v0.6.0
+   apptainer build spras-v0.6.0.sif docker://reedcompbio/spras:0.6.0
 
 After running this command, a new file called ``spras-v0.6.0.sif`` will
-exist in the directory where the command was run.
+exist in the directory where the command was run. Note that the Docker
+image does not use a "v" in the tag.
 
 Submitting All Jobs to a Single EP
 ----------------------------------

spras/allpairs.py

Lines changed: 0 additions & 2 deletions
@@ -35,8 +35,6 @@ def generate_inputs(data: Dataset, filename_map):
     # Get sources and targets for node input file
     # Borrowed code from pathlinker.py
     sources_targets = data.get_node_columns(["sources", "targets"])
-    if sources_targets is None:
-        raise ValueError("All Pairs Shortest Paths requires sources and targets")
 
     both_series = sources_targets.sources & sources_targets.targets
     for _index, row in sources_targets[both_series].iterrows():
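For context on the code this hunk keeps: the boolean AND of the `sources` and `targets` columns selects nodes flagged as both, which is what the loop then iterates over. A small pandas sketch with made-up data:

    import pandas as pd

    # Toy node table shaped like the sources/targets columns used above
    sources_targets = pd.DataFrame({
        "NODEID": ["A", "B", "C"],
        "sources": [True, True, False],
        "targets": [True, False, True],
    })

    # Element-wise AND keeps only the nodes that are both a source and a target
    both_series = sources_targets.sources & sources_targets.targets
    for _index, row in sources_targets[both_series].iterrows():
        print(row["NODEID"])  # prints only "A"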

spras/analysis/ml.py

Lines changed: 4 additions & 1 deletion
@@ -459,8 +459,11 @@ def jaccard_similarity_eval(summary_df: pd.DataFrame, output_file: str | PathLik
     ax.set_yticklabels(algorithms)
     plt.colorbar(cax, ax=ax)
     # annotate each cell with the corresponding similarity value
+    # where we set the precision to be lower as the number of algorithms increases
+    n = 2
+    if len(algorithms) > 10: n = 1
     for i in range(len(algorithms)):
         for j in range(len(algorithms)):
-            ax.text(j, i, f'{jaccard_matrix.values[i, j]:.2f}', ha='center', va='center', color='white')
+            ax.text(j, i, f'{jaccard_matrix.values[i, j]:.{n}f}', ha='center', va='center', color='white')
     plt.savefig(output_png, bbox_inches="tight", dpi=DPI)
     plt.close()
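The new annotation code uses a nested format specification, where the precision itself is interpolated into the f-string. A quick illustration with an arbitrary value:

    value = 0.4567
    for n in (2, 1):
        # the inner {n} sets the number of decimal places
        print(f"{value:.{n}f}")  # prints 0.46, then 0.5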

spras/analysis/summary.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
 
 
 def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, algo_params: dict[str, dict],
-                        algo_with_params: list, statistics_files: list) -> pd.DataFrame:
+                        algo_with_params: list[str], statistics_files: list) -> pd.DataFrame:
     """
     Generate a table that aggregates summary information about networks in file_paths, including which nodes are present
     in node_table columns. Network directionality is ignored and all edges are treated as undirected. The order of the

spras/btb.py

Lines changed: 2 additions & 13 deletions
@@ -44,19 +44,8 @@ def generate_inputs(data, filename_map):
 
     # Get sources and write to file, repeat for targets
     # Does not check whether a node is a source and a target
-    for node_type in ['sources', 'targets']:
-        nodes = data.get_node_columns([node_type])
-        if nodes is None:
-            raise ValueError(f'No {node_type} found in the node files')
-
-        # TODO test whether this selection is needed, what values could the column contain that we would want to
-        # include or exclude?
-        nodes = nodes.loc[nodes[node_type]]
-        if node_type == "sources":
-            nodes.to_csv(filename_map["sources"], sep= '\t', index=False, columns=['NODEID'], header=False)
-        elif node_type == "targets":
-            nodes.to_csv(filename_map["targets"], sep= '\t', index=False, columns=['NODEID'], header=False)
-
+    for node_type, nodes in data.get_node_columns_separate(['sources', 'targets']).items():
+        nodes.to_csv(filename_map[node_type], sep='\t', index=False, columns=['NODEID'], header=False)
 
     # Create network file
     edges = data.get_interactome()
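A rough sketch of what the rewritten loop writes, assuming (not verified here) that `get_node_columns_separate` returns a mapping from each requested column name to a DataFrame of the nodes flagged in that column; the toy DataFrames and output paths below are hypothetical:

    import pandas as pd

    # Stand-in for data.get_node_columns_separate(['sources', 'targets'])
    node_columns = {
        "sources": pd.DataFrame({"NODEID": ["A", "B"]}),
        "targets": pd.DataFrame({"NODEID": ["C"]}),
    }
    filename_map = {"sources": "sources.txt", "targets": "targets.txt"}  # hypothetical paths

    # Same file layout as the old code: one NODEID per line, no header
    for node_type, nodes in node_columns.items():
        nodes.to_csv(filename_map[node_type], sep='\t', index=False, columns=['NODEID'], header=False)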

spras/config/config.py

Lines changed: 92 additions & 21 deletions
@@ -13,40 +13,86 @@
 """
 
 import copy as copy
+import functools
+import hashlib
+import importlib.metadata
 import itertools as it
-import os
+import subprocess
+import tomllib
 import warnings
+from pathlib import Path
 from typing import Any
 
 import numpy as np
 import yaml
 
 from spras.config.container_schema import ProcessedContainerSettings
-from spras.config.schema import RawConfig
-from spras.util import NpHashEncoder, hash_params_sha1_base32
+from spras.config.schema import DatasetSchema, RawConfig
+from spras.util import LoosePathLike, NpHashEncoder, hash_params_sha1_base32
 
 config = None
 
+@functools.cache
+def spras_revision() -> str:
+    """
+    Gets the revision of the current SPRAS repository. This function is meant to be user-friendly to warn for bad SPRAS installs.
+    1. If this file is inside the correct `.git` repository, we use the revision hash. This is for development in SPRAS as well as SPRAS installs via a cloned git repository.
+    2. If SPRAS was installed via a PyPA-compliant package manager, we use the hash of the RECORD file (https://packaging.python.org/en/latest/specifications/recording-installed-packages/#the-record-file).
+       which contains the hashes of all installed files to the package.
+    """
+    clone_tip = "Make sure SPRAS is installed through the installation instructions: https://spras.readthedocs.io/en/latest/install.html."
+
+    # Check if we're inside the right git repository
+    try:
+        project_directory = subprocess.check_output(
+            ["git", "rev-parse", "--show-toplevel"],
+            encoding='utf-8',
+            # In case the CWD is not inside the actual SPRAS directory
+            cwd=Path(__file__).parent.resolve()
+        ).strip()
+
+        # We check the pyproject.toml name attribute to confirm that this is the SPRAS project. This is susceptible
+        # to false negatives, but we use this as a preliminary check against bad SPRAS installs.
+        pyproject_path = Path(project_directory, 'pyproject.toml')
+        try:
+            pyproject_toml = tomllib.loads(pyproject_path.read_text())
+            if "project" not in pyproject_toml or "name" not in pyproject_toml["project"]:
+                raise RuntimeError(f"The git top-level `{pyproject_path}` does not have the expected attributes. {clone_tip}")
+            if pyproject_toml["project"]["name"] != "spras":
+                raise RuntimeError(f"The git top-level `{pyproject_path}` is not the SPRAS pyproject.toml. {clone_tip}")
+
+            return subprocess.check_output(
+                ["git", "rev-parse", "--short", "HEAD"],
+                encoding='utf-8',
+                cwd=project_directory
+            ).strip()
+        except FileNotFoundError as err:
+            # pyproject.toml wasn't found during the `read_text` call
+            raise RuntimeError(f"The git top-level {pyproject_path} wasn't found. {clone_tip}") from err
+        except tomllib.TOMLDecodeError as err:
+            raise RuntimeError(f"The git top-level {pyproject_path} is malformed. {clone_tip}") from err
+    except subprocess.CalledProcessError:
+        try:
+            # `git` failed: use the truncated hash of the RECORD file in .dist-info instead.
+            record_path = str(importlib.metadata.distribution('spras').locate_file(f"spras-{importlib.metadata.version('spras')}.dist-info/RECORD"))
+            with open(record_path, 'rb', buffering=0) as f:
+                # Truncated to the magic value 8, the length of the short git revision.
+                return hashlib.file_digest(f, 'sha256').hexdigest()[:8]
+        except importlib.metadata.PackageNotFoundError as err:
+            # The metadata.distribution call failed.
+            raise RuntimeError(f"The spras package wasn't found: {clone_tip}") from err
+
+def attach_spras_revision(label: str) -> str:
+    return f"{label}_{spras_revision()}"
+
 # This will get called in the Snakefile, instantiating the singleton with the raw config
 def init_global(config_dict):
     global config
     config = Config(config_dict)
 
 def init_from_file(filepath):
     global config
-
-    # Handle opening the file and parsing the yaml
-    filepath = os.path.abspath(filepath)
-    try:
-        with open(filepath, 'r') as yaml_file:
-            config_dict = yaml.safe_load(yaml_file)
-    except FileNotFoundError as e:
-        raise RuntimeError(f"Error: The specified config '{filepath}' could not be found.") from e
-    except yaml.YAMLError as e:
-        raise RuntimeError(f"Error: Failed to parse config '{filepath}'") from e
-
-    # And finally, initialize
-    config = Config(config_dict)
+    config = Config.from_file(filepath)
 
 
 class Config:

@@ -64,7 +110,7 @@ def __init__(self, raw_config: dict[str, Any]):
         # Directory used for storing output
         self.out_dir = parsed_raw_config.reconstruction_settings.locations.reconstruction_dir
         # A dictionary to store configured datasets against which SPRAS will be run
-        self.datasets = None
+        self.datasets: dict[str, DatasetSchema] = {}
         # A dictionary to store configured gold standard data against output of SPRAS runs
         self.gold_standards = None
         # The hash length SPRAS will use to identify parameter combinations.

@@ -103,6 +149,20 @@ def __init__(self, raw_config: dict[str, Any]):
 
         self.process_config(parsed_raw_config)
 
+    @classmethod
+    def from_file(cls, filepath: LoosePathLike):
+        # Handle opening the file and parsing the yaml
+        filepath = Path(filepath).absolute()
+        try:
+            with open(filepath, 'r') as yaml_file:
+                config_dict = yaml.safe_load(yaml_file)
+        except FileNotFoundError as e:
+            raise RuntimeError(f"Error: The specified config '{filepath}' could not be found.") from e
+        except yaml.YAMLError as e:
+            raise RuntimeError(f"Error: Failed to parse config '{filepath}'") from e
+
+        return cls(config_dict)
+
     def process_datasets(self, raw_config: RawConfig):
         """
         Parse dataset information

@@ -115,12 +175,17 @@ def process_datasets(self, raw_config: RawConfig):
         # Currently assumes all datasets have a label and the labels are unique
         # When Snakemake parses the config file it loads the datasets as OrderedDicts not dicts
         # Convert to dicts to simplify the yaml logging
-        self.datasets = {}
+
+        for dataset in raw_config.datasets:
+            dataset.label = attach_spras_revision(dataset.label)
+        for gold_standard in raw_config.gold_standards:
+            gold_standard.label = attach_spras_revision(gold_standard.label)
+
         for dataset in raw_config.datasets:
             label = dataset.label
             if label.lower() in [key.lower() for key in self.datasets.keys()]:
                 raise ValueError(f"Datasets must have unique case-insensitive labels, but the label {label} appears at least twice.")
-            self.datasets[label] = dict(dataset)
+            self.datasets[label] = dataset
 
         # parse gold standard information
         self.gold_standards = {gold_standard.label: dict(gold_standard) for gold_standard in raw_config.gold_standards}

@@ -129,8 +194,11 @@ def process_datasets(self, raw_config: RawConfig):
         dataset_labels = set(self.datasets.keys())
         gold_standard_dataset_labels = {dataset_label for value in self.gold_standards.values() for dataset_label in value['dataset_labels']}
         for label in gold_standard_dataset_labels:
-            if label not in dataset_labels:
+            if attach_spras_revision(label) not in dataset_labels:
                 raise ValueError(f"Dataset label '{label}' provided in gold standards does not exist in the existing dataset labels.")
+        # We attach the SPRAS revision to the individual dataset labels afterwards for a cleaner error message above.
+        for key, gold_standard in self.gold_standards.items():
+            self.gold_standards[key]["dataset_labels"] = map(attach_spras_revision, gold_standard["dataset_labels"])
 
         # Code snipped from Snakefile that may be useful for assigning default labels
         # dataset_labels = [dataset.get('label', f'dataset{index}') for index, dataset in enumerate(datasets)]

@@ -186,7 +254,10 @@ def process_algorithms(self, raw_config: RawConfig):
                     run_dict[param] = float(value)
                 if isinstance(value, np.ndarray):
                     run_dict[param] = value.tolist()
-            params_hash = hash_params_sha1_base32(run_dict, self.hash_length, cls=NpHashEncoder)
+            # Incorporates the `spras_revision` into the hash
+            hash_run_dict = copy.deepcopy(run_dict)
+            hash_run_dict["_spras_rev"] = spras_revision()
+            params_hash = hash_params_sha1_base32(hash_run_dict, self.hash_length, cls=NpHashEncoder)
             if params_hash in prior_params_hashes:
                 raise ValueError(f'Parameter hash collision detected. Increase the hash_length in the config file '
                                  f'(current length {self.hash_length}).')
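The last hunk folds the SPRAS revision into the parameter hash so parameter combinations from different revisions get distinct identifiers. The real code calls `spras.util.hash_params_sha1_base32` with `NpHashEncoder`; the toy helper below only sketches the idea by hashing a sorted JSON dump:

    import copy
    import hashlib
    import json

    def toy_param_hash(params: dict, length: int = 7) -> str:
        # Stand-in for hash_params_sha1_base32: deterministic digest of the dict
        serialized = json.dumps(params, sort_keys=True).encode()
        return hashlib.sha1(serialized).hexdigest()[:length]

    run_dict = {"k": 10, "w": 0.5}  # made-up algorithm parameters

    hash_run_dict = copy.deepcopy(run_dict)
    hash_run_dict["_spras_rev"] = "79cf748"  # e.g. one of this merge's parent revisions

    print(toy_param_hash(run_dict))       # hash without the revision
    print(toy_param_hash(hash_run_dict))  # differs, so outputs are not reused across revisions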

spras/config/dataset.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+from typing import Annotated
+
+from pydantic import AfterValidator, BaseModel, ConfigDict
+
+from spras.config.util import label_validator
+from spras.util import LoosePathLike
+
+
+class DatasetSchema(BaseModel):
+    """
+    Collection of information related to `Dataset` objects in the configuration.
+    """
+
+    # We prefer AfterValidator here to allow pydantic to run its own
+    # validation & coercion logic before we check it against our own
+    # requirements
+    label: Annotated[str, AfterValidator(label_validator("Dataset"))]
+    node_files: list[LoosePathLike]
+    edge_files: list[LoosePathLike]
+    other_files: list[LoosePathLike]
+    data_dir: LoosePathLike
+
+    model_config = ConfigDict(extra='forbid')
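A hedged usage sketch of the new `DatasetSchema`, assuming a SPRAS install that includes this commit; the label and file names are made-up examples:

    from pydantic import ValidationError

    from spras.config.dataset import DatasetSchema

    dataset = DatasetSchema(
        label="egfr_small",
        node_files=["node-prizes.txt"],
        edge_files=["interactome.txt"],
        other_files=[],
        data_dir="input",
    )
    print(dataset.label)

    try:
        # extra='forbid' rejects unknown keys, and the label validator rejects
        # anything other than letters, numbers, and underscores
        DatasetSchema(label="bad label!", node_files=[], edge_files=[],
                      other_files=[], data_dir="input", typo_field=1)
    except ValidationError as err:
        print(err)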

spras/config/schema.py

Lines changed: 3 additions & 27 deletions
@@ -10,14 +10,14 @@
 - `CaseInsensitiveEnum` (see ./util.py)
 """
 
-import re
 from typing import Annotated
 
 from pydantic import AfterValidator, BaseModel, ConfigDict
 
 from spras.config.algorithms import AlgorithmUnion
 from spras.config.container_schema import ContainerSettings
-from spras.config.util import CaseInsensitiveEnum
+from spras.config.dataset import DatasetSchema
+from spras.config.util import CaseInsensitiveEnum, label_validator
 
 # Most options here have an `include` property,
 # which is meant to make disabling parts of the configuration easier.

@@ -79,30 +79,6 @@ class Analysis(BaseModel):
 # The default length of the truncated hash used to identify parameter combinations
 DEFAULT_HASH_LENGTH = 7
 
-def label_validator(name: str):
-    """
-    A validator takes in a label
-    and ensures that it contains only letters, numbers, or underscores.
-    """
-    label_pattern = r'^\w+$'
-    def validate(label: str):
-        if not bool(re.match(label_pattern, label)):
-            raise ValueError(f"{name} label '{label}' contains invalid values. {name} labels can only contain letters, numbers, or underscores.")
-        return label
-    return validate
-
-class Dataset(BaseModel):
-    # We prefer AfterValidator here to allow pydantic to run its own
-    # validation & coercion logic before we check it against our own
-    # requirements
-    label: Annotated[str, AfterValidator(label_validator("Dataset"))]
-    node_files: list[str]
-    edge_files: list[str]
-    other_files: list[str]
-    data_dir: str
-
-    model_config = ConfigDict(extra='forbid')
-
 class GoldStandard(BaseModel):
     label: Annotated[str, AfterValidator(label_validator("Gold Standard"))]
     node_files: list[str] = []

@@ -131,7 +107,7 @@ class RawConfig(BaseModel):
 
     # See algorithms.py for more information about AlgorithmUnion
     algorithms: list[AlgorithmUnion] # type: ignore - pydantic allows this.
-    datasets: list[Dataset]
+    datasets: list[DatasetSchema]
     gold_standards: list[GoldStandard] = []
     analysis: Analysis = Analysis()

spras/config/util.py

Lines changed: 13 additions & 0 deletions
@@ -4,13 +4,26 @@
 only import this config file.
 """
 
+import re
 from enum import Enum
 from typing import Any
 
 import yaml
 from pydantic import BaseModel, ConfigDict
 
 
+def label_validator(name: str):
+    """
+    A validator takes in a label
+    and ensures that it contains only letters, numbers, or underscores.
+    """
+    label_pattern = r'^\w+$'
+    def validate(label: str):
+        if not bool(re.match(label_pattern, label)):
+            raise ValueError(f"{name} label '{label}' contains invalid values. {name} labels can only contain letters, numbers, or underscores.")
+        return label
+    return validate
+
 # https://stackoverflow.com/a/76883868/7589775
 class CaseInsensitiveEnum(str, Enum):
     """
