
Commit fb4c0fe

Merge branch 'main' into ml-separate
2 parents: 3254623 + 18f2cf8

Note: large commits have some content hidden by default, so only a subset of the changed files appears below.

48 files changed: +373 additions, -258 deletions

Snakefile

Lines changed: 6 additions & 6 deletions
@@ -64,10 +64,10 @@ def write_parameter_log(algorithm, param_label, logfile):
 def write_dataset_log(dataset, logfile):
     dataset_contents = get_dataset(_config.config.datasets,dataset)

-    # safe_dump gives RepresenterError for an OrderedDict
-    # config file has to convert the dataset from OrderedDict to dict to avoid this
-    with open(logfile,'w') as f:
-        yaml.safe_dump(dataset_contents,f)
+    # safe_dump gives RepresenterError for a DatasetSchema
+    # config file has to convert the dataset to a dict to avoid this
+    with open(logfile, 'w') as f:
+        yaml.safe_dump(dict(dataset_contents), f)

 # Choose the final files expected according to the config file options.
 def make_final_input(wildcards):
@@ -179,9 +179,9 @@ rule log_datasets:
 # Input preparation needs to be rerun if these files are modified
 def get_dataset_dependencies(wildcards):
     dataset = _config.config.datasets[wildcards.dataset]
-    all_files = dataset["node_files"] + dataset["edge_files"] + dataset["other_files"]
+    all_files = dataset.node_files + dataset.edge_files + dataset.other_files
     # Add the relative file path
-    all_files = [dataset["data_dir"] + SEP + data_file for data_file in all_files]
+    all_files = [dataset.data_dir + SEP + data_file for data_file in all_files]

     return all_files
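
The dict(dataset_contents) call works because pydantic models iterate as
(field, value) pairs. A minimal standalone sketch of the logging fix,
assuming DatasetSchema's path fields validate to plain strings (the label
and file names here are made up):

    import yaml

    from spras.config.dataset import DatasetSchema

    dataset = DatasetSchema(label="egfr", node_files=["nodes.txt"],
                            edge_files=["edges.txt"], other_files=[],
                            data_dir="input")
    with open("dataset-egfr.yaml", "w") as f:
        # yaml.safe_dump(dataset, f) would raise RepresenterError because
        # safe_dump only handles plain Python types, not pydantic models;
        # dict() turns the model into a representable mapping
        yaml.safe_dump(dict(dataset), f)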

docs/htcondor.rst

Lines changed: 3 additions & 2 deletions
@@ -54,10 +54,11 @@ might look like:

 .. code:: bash

-    apptainer build spras-v0.6.0.sif docker://reedcompbio/spras:v0.6.0
+    apptainer build spras-v0.6.0.sif docker://reedcompbio/spras:0.6.0

 After running this command, a new file called ``spras-v0.6.0.sif`` will
-exist in the directory where the command was run.
+exist in the directory where the command was run. Note that the Docker
+image does not use a "v" in the tag.

 Submitting All Jobs to a Single EP
 ----------------------------------

spras/allpairs.py

Lines changed: 0 additions & 2 deletions
@@ -35,8 +35,6 @@ def generate_inputs(data: Dataset, filename_map):
     # Get sources and targets for node input file
     # Borrowed code from pathlinker.py
     sources_targets = data.get_node_columns(["sources", "targets"])
-    if sources_targets is None:
-        raise ValueError("All Pairs Shortest Paths requires sources and targets")

     both_series = sources_targets.sources & sources_targets.targets
     for _index, row in sources_targets[both_series].iterrows():
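
Dropping the None check only makes sense if data.get_node_columns now fails
loudly on its own when the requested columns are missing. That contract is
an assumption here; a hypothetical sketch of what it could look like:

    import pandas as pd

    def get_node_columns(node_table: pd.DataFrame, columns: list[str]) -> pd.DataFrame:
        # hypothetical stand-in for Dataset.get_node_columns: raise instead
        # of returning None so callers no longer need their own None checks
        missing = [col for col in columns if col not in node_table.columns]
        if missing:
            raise ValueError(f"Node files are missing columns: {missing}")
        return node_table[["NODEID", *columns]]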

spras/analysis/ml.py

Lines changed: 4 additions & 1 deletion
@@ -459,8 +459,11 @@ def jaccard_similarity_eval(summary_df: pd.DataFrame, output_file: str | PathLike
     ax.set_yticklabels(algorithms)
     plt.colorbar(cax, ax=ax)
     # annotate each cell with the corresponding similarity value
+    # where we set the precision to be lower as the number of algorithms increases
+    n = 2
+    if len(algorithms) > 10: n = 1
     for i in range(len(algorithms)):
         for j in range(len(algorithms)):
-            ax.text(j, i, f'{jaccard_matrix.values[i, j]:.2f}', ha='center', va='center', color='white')
+            ax.text(j, i, f'{jaccard_matrix.values[i, j]:.{n}f}', ha='center', va='center', color='white')
     plt.savefig(output_png, bbox_inches="tight", dpi=DPI)
     plt.close()
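
The adaptive precision keeps the heatmap annotations legible as the number
of algorithms grows: two decimal places up to 10 algorithms, one above that.
A quick standalone illustration with a made-up similarity value:

    algorithms = [f"alg{i}" for i in range(12)]
    n = 2
    if len(algorithms) > 10:
        n = 1  # denser heatmaps get shorter cell labels
    print(f"{0.8375:.{n}f}")  # '0.8' here; '0.84' with 10 or fewer algorithms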

spras/btb.py

Lines changed: 2 additions & 13 deletions
@@ -44,19 +44,8 @@ def generate_inputs(data, filename_map):

     # Get sources and write to file, repeat for targets
     # Does not check whether a node is a source and a target
-    for node_type in ['sources', 'targets']:
-        nodes = data.get_node_columns([node_type])
-        if nodes is None:
-            raise ValueError(f'No {node_type} found in the node files')
-
-        # TODO test whether this selection is needed, what values could the column contain that we would want to
-        # include or exclude?
-        nodes = nodes.loc[nodes[node_type]]
-        if node_type == "sources":
-            nodes.to_csv(filename_map["sources"], sep= '\t', index=False, columns=['NODEID'], header=False)
-        elif node_type == "targets":
-            nodes.to_csv(filename_map["targets"], sep= '\t', index=False, columns=['NODEID'], header=False)
-
+    for node_type, nodes in data.get_node_columns_separate(['sources', 'targets']).items():
+        nodes.to_csv(filename_map[node_type], sep='\t', index=False, columns=['NODEID'], header=False)

     # Create network file
     edges = data.get_interactome()
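
The per-type branching collapses because get_node_columns_separate returns
one DataFrame per requested node type, keyed by the type name. The method's
internals are not shown in this commit, so this is a hypothetical sketch of
the assumed shape:

    import pandas as pd

    def get_node_columns_separate(node_table: pd.DataFrame,
                                  columns: list[str]) -> dict[str, pd.DataFrame]:
        # hypothetical stand-in: one DataFrame per boolean column, keeping
        # only the nodes flagged for that type
        return {col: node_table.loc[node_table[col], ["NODEID"]] for col in columns}

    nodes = pd.DataFrame({"NODEID": ["A", "B", "C"],
                          "sources": [True, False, False],
                          "targets": [False, True, True]})
    for node_type, subset in get_node_columns_separate(nodes, ["sources", "targets"]).items():
        print(node_type, subset["NODEID"].tolist())  # sources ['A'] / targets ['B', 'C']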

spras/config/config.py

Lines changed: 20 additions & 19 deletions
@@ -14,16 +14,16 @@

 import copy as copy
 import itertools as it
-import os
 import warnings
+from pathlib import Path
 from typing import Any

 import numpy as np
 import yaml

 from spras.config.container_schema import ProcessedContainerSettings
-from spras.config.schema import RawConfig
-from spras.util import NpHashEncoder, hash_params_sha1_base32
+from spras.config.schema import DatasetSchema, RawConfig
+from spras.util import LoosePathLike, NpHashEncoder, hash_params_sha1_base32

 config = None

@@ -34,19 +34,7 @@ def init_global(config_dict):

 def init_from_file(filepath):
     global config
-
-    # Handle opening the file and parsing the yaml
-    filepath = os.path.abspath(filepath)
-    try:
-        with open(filepath, 'r') as yaml_file:
-            config_dict = yaml.safe_load(yaml_file)
-    except FileNotFoundError as e:
-        raise RuntimeError(f"Error: The specified config '{filepath}' could not be found.") from e
-    except yaml.YAMLError as e:
-        raise RuntimeError(f"Error: Failed to parse config '{filepath}'") from e
-
-    # And finally, initialize
-    config = Config(config_dict)
+    config = Config.from_file(filepath)


 class Config:
@@ -64,7 +52,7 @@ def __init__(self, raw_config: dict[str, Any]):
         # Directory used for storing output
         self.out_dir = parsed_raw_config.reconstruction_settings.locations.reconstruction_dir
         # A dictionary to store configured datasets against which SPRAS will be run
-        self.datasets = None
+        self.datasets: dict[str, DatasetSchema] = {}
         # A dictionary to store configured gold standard data against output of SPRAS runs
         self.gold_standards = None
         # The hash length SPRAS will use to identify parameter combinations.
@@ -81,6 +69,20 @@ def __init__(self, raw_config: dict[str, Any]):

         self.process_config(parsed_raw_config)

+    @classmethod
+    def from_file(cls, filepath: LoosePathLike):
+        # Handle opening the file and parsing the yaml
+        filepath = Path(filepath).absolute()
+        try:
+            with open(filepath, 'r') as yaml_file:
+                config_dict = yaml.safe_load(yaml_file)
+        except FileNotFoundError as e:
+            raise RuntimeError(f"Error: The specified config '{filepath}' could not be found.") from e
+        except yaml.YAMLError as e:
+            raise RuntimeError(f"Error: Failed to parse config '{filepath}'") from e
+
+        return cls(config_dict)
+
     def process_datasets(self, raw_config: RawConfig):
         """
         Parse dataset information
@@ -93,12 +95,11 @@ def process_datasets(self, raw_config: RawConfig):
         # Currently assumes all datasets have a label and the labels are unique
         # When Snakemake parses the config file it loads the datasets as OrderedDicts not dicts
         # Convert to dicts to simplify the yaml logging
-        self.datasets = {}
         for dataset in raw_config.datasets:
             label = dataset.label
             if label.lower() in [key.lower() for key in self.datasets.keys()]:
                 raise ValueError(f"Datasets must have unique case-insensitive labels, but the label {label} appears at least twice.")
-            self.datasets[label] = dict(dataset)
+            self.datasets[label] = dataset

         # parse gold standard information
         self.gold_standards = {gold_standard.label: dict(gold_standard) for gold_standard in raw_config.gold_standards}
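
With parsing moved into the classmethod, a Config can be built directly
without going through the module-level global. A usage sketch (the config
path here is hypothetical):

    from spras.config.config import Config

    config = Config.from_file("config/config.yaml")
    print(config.out_dir)
    for label, dataset in config.datasets.items():
        # datasets are now DatasetSchema models, so fields are attributes
        print(label, dataset.data_dir)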

spras/config/dataset.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+from typing import Annotated
+
+from pydantic import AfterValidator, BaseModel, ConfigDict
+
+from spras.config.util import label_validator
+from spras.util import LoosePathLike
+
+
+class DatasetSchema(BaseModel):
+    """
+    Collection of information related to `Dataset` objects in the configuration.
+    """
+
+    # We prefer AfterValidator here to allow pydantic to run its own
+    # validation & coercion logic before we check it against our own
+    # requirements
+    label: Annotated[str, AfterValidator(label_validator("Dataset"))]
+    node_files: list[LoosePathLike]
+    edge_files: list[LoosePathLike]
+    other_files: list[LoosePathLike]
+    data_dir: LoosePathLike
+
+    model_config = ConfigDict(extra='forbid')
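
A short demonstration of the schema's validation, with made-up values: the
label validator rejects anything other than letters, numbers, and
underscores, and extra='forbid' rejects unknown keys:

    from pydantic import ValidationError

    from spras.config.dataset import DatasetSchema

    DatasetSchema(label="egfr_1", node_files=["nodes.txt"],
                  edge_files=["edges.txt"], other_files=[], data_dir="input")

    try:
        DatasetSchema(label="egfr 1", node_files=[], edge_files=[],
                      other_files=[], data_dir="input")
    except ValidationError as e:
        print(e)  # Dataset label 'egfr 1' contains invalid values. ...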

spras/config/schema.py

Lines changed: 3 additions & 27 deletions
@@ -10,15 +10,15 @@
 - `CaseInsensitiveEnum` (see ./util.py)
 """

-import re
 import warnings
 from typing import Annotated

 from pydantic import AfterValidator, BaseModel, ConfigDict, model_validator

 from spras.config.algorithms import AlgorithmUnion
 from spras.config.container_schema import ContainerSettings
-from spras.config.util import CaseInsensitiveEnum
+from spras.config.dataset import DatasetSchema
+from spras.config.util import CaseInsensitiveEnum, label_validator

 # Most options here have an `include` property,
 # which is meant to make disabling parts of the configuration easier.
@@ -112,30 +112,6 @@ class Analysis(BaseModel):
 # The default length of the truncated hash used to identify parameter combinations
 DEFAULT_HASH_LENGTH = 7

-def label_validator(name: str):
-    """
-    A validator takes in a label
-    and ensures that it contains only letters, numbers, or underscores.
-    """
-    label_pattern = r'^\w+$'
-    def validate(label: str):
-        if not bool(re.match(label_pattern, label)):
-            raise ValueError(f"{name} label '{label}' contains invalid values. {name} labels can only contain letters, numbers, or underscores.")
-        return label
-    return validate
-
-class Dataset(BaseModel):
-    # We prefer AfterValidator here to allow pydantic to run its own
-    # validation & coercion logic before we check it against our own
-    # requirements
-    label: Annotated[str, AfterValidator(label_validator("Dataset"))]
-    node_files: list[str]
-    edge_files: list[str]
-    other_files: list[str]
-    data_dir: str
-
-    model_config = ConfigDict(extra='forbid')
-
 class GoldStandard(BaseModel):
     label: Annotated[str, AfterValidator(label_validator("Gold Standard"))]
     node_files: list[str] = []
@@ -164,7 +140,7 @@ class RawConfig(BaseModel):

     # See algorithms.py for more information about AlgorithmUnion
     algorithms: list[AlgorithmUnion] # type: ignore - pydantic allows this.
-    datasets: list[Dataset]
+    datasets: list[DatasetSchema]
     gold_standards: list[GoldStandard] = []
     analysis: Analysis = Analysis()
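
The Annotated/AfterValidator pattern that DatasetSchema and GoldStandard
share works for any labeled model; a minimal sketch with a hypothetical
model:

    from typing import Annotated

    from pydantic import AfterValidator, BaseModel

    from spras.config.util import label_validator

    class Example(BaseModel):
        # AfterValidator runs our check after pydantic's own str validation
        label: Annotated[str, AfterValidator(label_validator("Example"))]

    Example(label="run_1")    # passes
    # Example(label="run 1")  # raises pydantic.ValidationError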

spras/config/util.py

Lines changed: 13 additions & 0 deletions
@@ -4,13 +4,26 @@
 only import this config file.
 """

+import re
 from enum import Enum
 from typing import Any

 import yaml
 from pydantic import BaseModel, ConfigDict


+def label_validator(name: str):
+    """
+    A validator takes in a label
+    and ensures that it contains only letters, numbers, or underscores.
+    """
+    label_pattern = r'^\w+$'
+    def validate(label: str):
+        if not bool(re.match(label_pattern, label)):
+            raise ValueError(f"{name} label '{label}' contains invalid values. {name} labels can only contain letters, numbers, or underscores.")
+        return label
+    return validate
+
 # https://stackoverflow.com/a/76883868/7589775
 class CaseInsensitiveEnum(str, Enum):
     """
