
Commit 1388aba

Add feature importance analysis and unit tests for filter script
1 parent 5934444 commit 1388aba

File tree

- README.md
- Wiki.md
- bin/filter_feature_importance.py
- bin/test_feature_importance.py

4 files changed: +387 -0 lines changed

README.md

Lines changed: 18 additions & 0 deletions
@@ -227,6 +227,24 @@ sbatch -A [account] -p [partition] -c 1 --mem=4g \

**Output:** `results/FILTER/filter.gff3`

### Feature Importance Test

Reviewers often ask for an ablation study of the semi-supervised filter. After a
filter run completes (which produces `FILTER/data.tsv`), launch the automated
leave-one-feature-out test:

```bash
python bin/filter_feature_importance.py FILTER/data.tsv results/busco/full_table.tsv \
    --output-table FILTER/feature_importance.tsv
```

The script reuses `Filter.semiSupRandomForest`, trains a baseline model with all
features, and then retrains while removing each feature individually. The final
out-of-bag error deltas are written to `FILTER/feature_importance.tsv` (and
`FILTER/feature_importance.json`). Use `--features` to restrict the analysis to a
subset of columns or `--ignore` to drop metadata columns that should never be
used as predictors.
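
As a quick way to read the resulting table, here is a minimal sketch (not part of
the pipeline; it only assumes the column names written by
`bin/filter_feature_importance.py`) that ranks features by their OOB-error delta:

```python
import pandas as pd

# Rank features by how much their removal increased the final out-of-bag error.
# A larger positive oob_delta suggests the feature carries useful signal.
imp = pd.read_csv("FILTER/feature_importance.tsv", sep="\t")
ranked = (
    imp[imp["feature_removed"] != "(none)"]  # drop the baseline row
    .sort_values("oob_delta", ascending=False)
)
print(ranked[["feature_removed", "oob_delta", "final_oob_error"]])
```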

## Configuration

Sylvan uses two separate configuration files:

Wiki.md

Lines changed: 39 additions & 0 deletions
@@ -560,3 +560,42 @@ python bin/generate_cluster_from_config.py --config toydata/config/config_anno
```
chmod 775 bin/generate_cluster_from_config.py
```

## Feature Importance Analysis

After finishing the filter phase you will have `FILTER/data.tsv` (the feature
matrix used by `Filter.py`) and a BUSCO run directory such as
`results/busco/eudicots_odb10`. Reviewers often ask for a feature ablation
study, so we provide an automated helper:

```bash
python bin/filter_feature_importance.py FILTER/data.tsv results/busco/<lineage>/full_table.tsv \
    --output-table FILTER/feature_importance.tsv
```

- **What is the BUSCO full table?** Every BUSCO run writes a `full_table.tsv`
  inside its lineage-specific run folder. Each non-Missing BUSCO row lists the
  BUSCO ID, status (Complete/Duplicated/Fragmented), and the transcript/gene ID
  it matched. The feature-importance script reuses this file to count how many
  BUSCOs remain in the “keep” set during each iteration; no new BUSCO analysis
  is required.
- **Outputs**: `FILTER/feature_importance.tsv` (table) plus
  `FILTER/feature_importance.json` (machine-readable). Both include the baseline
  run (all features) and each leave-one-feature-out run, along with final
  out-of-bag (OOB) error, BUSCO counts, and iteration counts.
- **Optional flags**:
  - `--features TPM COVERAGE PFAM ...` restricts the analysis to specific
    columns from `FILTER/data.tsv`.
  - `--ignore TPM_missing singleExon` removes metadata columns so the script
    automatically uses every other feature column.

Workflow summary:

1. Run `Filter.py` as usual to create `FILTER/data.tsv`.
2. Identify the BUSCO `full_table.tsv` path you already used for filter
   monitoring (e.g., `results/busco/eudicots_odb10/full_table.tsv`).
3. Execute the command above. Inspect `FILTER/feature_importance.tsv` to see how
   dropping each feature affects OOB error (positive delta ⇒ feature is
   important); a plotting sketch follows below.
4. Incorporate the results (table/plot) into your manuscript or reviewer
   response.
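
A minimal plotting sketch for steps 3-4 (it assumes `matplotlib` is installed in
your environment; the column names are the ones written by
`bin/filter_feature_importance.py`):

```python
import pandas as pd
import matplotlib.pyplot as plt

# Horizontal bar chart of OOB-error deltas: bars to the right of zero mark
# features whose removal hurt the classifier, i.e. informative evidence tracks.
imp = pd.read_csv("FILTER/feature_importance.tsv", sep="\t")
ablation = imp[imp["feature_removed"] != "(none)"].sort_values("oob_delta")

plt.figure(figsize=(6, 0.4 * len(ablation) + 1))
plt.barh(ablation["feature_removed"], ablation["oob_delta"])
plt.axvline(0, color="black", linewidth=0.8)
plt.xlabel("Change in final OOB error when the feature is removed")
plt.tight_layout()
plt.savefig("FILTER/feature_importance.png", dpi=300)
```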

bin/filter_feature_importance.py

Lines changed: 232 additions & 0 deletions
@@ -0,0 +1,232 @@
#!/usr/bin/env python3
"""Leave-one-feature-out analysis for the semi-supervised filter.

This script reuses the `Filter.semiSupRandomForest` training loop and reruns it
multiple times while removing one feature at a time. The difference in the
final out-of-bag (OOB) error provides an intuitive importance score: if dropping
an evidence track increases the OOB error, that feature contributes useful
signal to the classifier.

Example:
    python bin/filter_feature_importance.py FILTER/data.tsv results/busco/run.tsv \
        --output-table FILTER/feature_importance.tsv
"""
from __future__ import annotations

import argparse
import json
import math
import os
from typing import Dict, List

import pandas as pd

from Filter import semiSupRandomForest


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Leave-one-feature-out analysis for the Sylvan filter"
    )
    parser.add_argument(
        "data",
        help="Path to the TSV created by Filter.filter_genes (e.g. FILTER/data.tsv)",
    )
    parser.add_argument(
        "busco",
        help=(
            "Path to the BUSCO table used for monitoring (same input passed to "
            "Filter.py)."
        ),
    )
    parser.add_argument(
        "--features",
        nargs="*",
        default=None,
        help=(
            "Explicit list of feature columns to evaluate. The default uses all "
            "columns except transcript_id/label and anything listed via --ignore."
        ),
    )
    parser.add_argument(
        "--ignore",
        nargs="*",
        default=[],
        help="Columns in the data TSV that should never be used as model features.",
    )
    parser.add_argument(
        "--trees",
        type=int,
        default=100,
        help="Number of trees per random forest run (default: 100)",
    )
    parser.add_argument(
        "--predictors",
        type=int,
        default=6,
        help="max_features hyperparameter for RandomForestClassifier (default: 6)",
    )
    parser.add_argument(
        "--max-iter",
        type=int,
        default=5,
        help=(
            "Maximum number of recycling iterations used by the semi-supervised "
            "training loop (default: 5)"
        ),
    )
    parser.add_argument(
        "--recycle",
        type=float,
        default=0.95,
        help=(
            "Prediction probability threshold required to recycle unlabeled "
            "examples (default: 0.95)"
        ),
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=123,
        help="Random seed passed to RandomForestClassifier (default: 123)",
    )
    parser.add_argument(
        "--output-table",
        default=None,
        help=(
            "Output TSV summarizing the baseline run and each leave-one-feature-out "
            "experiment (default: <data_dir>/feature_importance.tsv)"
        ),
    )
    parser.add_argument(
        "--output-json",
        default=None,
        help=(
            "Optional JSON file capturing the same summary (default: "
            "<data_dir>/feature_importance.json)"
        ),
    )
    return parser.parse_args()


def resolve_feature_list(df: pd.DataFrame, include: List[str] | None, ignore: List[str]) -> List[str]:
    """Return the ordered feature list used for training/ablation."""
    metadata_cols = {"transcript_id", "label"}
    metadata_cols.update(ignore or [])
    default_features = [c for c in df.columns if c not in metadata_cols]

    if include:
        missing = sorted(set(include) - set(default_features))
        if missing:
            raise ValueError(
                f"Requested feature(s) not found in data columns: {', '.join(missing)}"
            )
        return include

    return default_features


def summarize_process(process: Dict[str, List[float]]) -> Dict[str, float]:
    """Extract final iteration statistics from the training process log."""
    def last(seq: List[float]) -> float:
        if not seq:
            return float("nan")
        return seq[-1]

    return {
        "iterations": len(process.get("kept", [])),
        "final_kept": last(process.get("kept", [])),
        "final_discarded": last(process.get("discarded", [])),
        "final_kept_buscos": last(process.get("kept_buscos", [])),
        "final_discarded_buscos": last(process.get("discarded_buscos", [])),
        "final_oob_error": last(process.get("OOB", [])),
    }


def run_filter(
    data: pd.DataFrame,
    features: List[str],
    busco_path: str,
    args: argparse.Namespace,
) -> Dict[str, float]:
    subset_cols = ["transcript_id", "label"] + features
    subset = data.loc[:, subset_cols].copy()
    _, process = semiSupRandomForest(
        subset,
        args.predictors,
        busco_path,
        args.trees,
        seed=args.seed,
        recycle_prob=args.recycle,
        maxiter=args.max_iter,
    )
    return summarize_process(process)


def format_delta(value: float) -> str:
    if value is None or math.isnan(value):
        return "nan"
    return f"{value:+.4f}"


def main() -> None:
    args = parse_args()
    data = pd.read_csv(args.data, sep="\t")

    feature_list = resolve_feature_list(data, args.features, args.ignore)
    if not feature_list:
        raise ValueError("No usable features detected in data TSV.")

    out_dir = os.path.dirname(os.path.abspath(args.data))
    table_path = args.output_table or os.path.join(out_dir, "feature_importance.tsv")
    json_path = args.output_json or os.path.join(out_dir, "feature_importance.json")

    print(f"Running baseline model with {len(feature_list)} features...")
    baseline = run_filter(data, feature_list, args.busco, args)
    baseline_row = {
        "feature_removed": "(none)",
        "num_features": len(feature_list),
        "oob_delta": 0.0,
        **baseline,
    }

    results = [baseline_row]
    for feature in feature_list:
        reduced = [f for f in feature_list if f != feature]
        if not reduced:
            continue
        print(f"Dropping '{feature}' ({len(reduced)} features remaining)...")
        summary = run_filter(data, reduced, args.busco, args)
        summary_row = {
            "feature_removed": feature,
            "num_features": len(reduced),
            "oob_delta": summary["final_oob_error"] - baseline["final_oob_error"],
            **summary,
        }
        results.append(summary_row)
        delta_str = format_delta(summary_row["oob_delta"])
        print(
            f"  -> final OOB error: {summary_row['final_oob_error']:.4f} "
            f"(delta {delta_str})"
        )

    df = pd.DataFrame(results)
    df.to_csv(table_path, sep="\t", index=False)
    print(f"\nSummary written to {table_path}")

    if json_path:
        json_ready = []
        for row in results:
            json_ready.append(
                {
                    k: (None if isinstance(v, float) and math.isnan(v) else v)
                    for k, v in row.items()
                }
            )
        with open(json_path, "w", encoding="utf-8") as fh:
            json.dump({"runs": json_ready}, fh, indent=2)
        print(f"JSON summary written to {json_path}")


if __name__ == "__main__":
    main()
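
A note on the interface this script relies on: it assumes `Filter.semiSupRandomForest`
returns a pair whose second element is a per-iteration `process` dict keyed by
`kept`, `discarded`, `kept_buscos`, `discarded_buscos`, and `OOB`. A minimal
sketch of how `summarize_process` digests such a log (illustrative numbers only,
and assuming `bin/` and its `Filter` dependency are on the Python path):

```python
from filter_feature_importance import summarize_process

# Hypothetical process log from two recycling iterations (values are made up).
process = {
    "kept": [1200, 1150],
    "discarded": [300, 350],
    "kept_buscos": [980, 975],
    "discarded_buscos": [5, 10],
    "OOB": [0.081, 0.074],
}

print(summarize_process(process))
# {'iterations': 2, 'final_kept': 1150, 'final_discarded': 350,
#  'final_kept_buscos': 975, 'final_discarded_buscos': 10,
#  'final_oob_error': 0.074}
```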

bin/test_feature_importance.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
#!/usr/bin/env python3
"""Unit test for the filter_feature_importance.py script.

This test simulates the inputs to filter_feature_importance.py to verify that it
runs and produces outputs in the expected format. It does not check the
statistical validity of the results, but rather the script's execution and
output structure.
"""
import unittest
import tempfile
import shutil
import os
import subprocess
import pandas as pd
import json

class TestFeatureImportance(unittest.TestCase):
    def setUp(self):
        """Set up a temporary directory and dummy input files."""
        self.temp_dir = tempfile.mkdtemp()
        self.data_path = os.path.join(self.temp_dir, "data.tsv")
        self.busco_path = os.path.join(self.temp_dir, "full_table.tsv")
        self.output_table_path = os.path.join(self.temp_dir, "feature_importance.tsv")
        self.output_json_path = os.path.join(self.temp_dir, "feature_importance.json")

        # Create dummy data.tsv
        data = {
            'transcript_id': [f'tx{i}' for i in range(10)],
            'label': ['TE', 'Prot', 'BG', 'TE', 'Prot', 'BG', 'TE', 'Prot', 'BG', 'TE'],
            'feature1': [0.1, 0.9, 0.2, 0.15, 0.85, 0.25, 0.11, 0.92, 0.22, 0.13],
            'feature2': [0.8, 0.2, 0.7, 0.85, 0.25, 0.75, 0.81, 0.22, 0.72, 0.83],
            'feature3': [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
        }
        pd.DataFrame(data).to_csv(self.data_path, sep='\t', index=False)

        # Create dummy full_table.tsv for BUSCO
        busco_data = {
            '# Busco id': [f'busco{i}' for i in range(5)],
            'Status': ['Complete'] * 5,
            'Sequence': ['tx1', 'tx4', 'tx7', 'tx0', 'tx3'],
            'Score': [0.9] * 5,
            'Length': [100] * 5
        }
        with open(self.busco_path, "w") as f:
            f.write("# Some header lines\n")
            pd.DataFrame(busco_data).to_csv(f, sep='\t', index=False)

    def tearDown(self):
        """Clean up the temporary directory."""
        shutil.rmtree(self.temp_dir)

    def test_script_runs_and_creates_output(self):
        """Test if the script runs and creates the expected output files."""
        script_path = os.path.join(os.path.dirname(__file__), 'filter_feature_importance.py')

        # The script we are testing imports 'Filter', which lives in the same
        # directory, so make sure Python can find it.
        env = os.environ.copy()
        env['PYTHONPATH'] = os.path.dirname(__file__) + os.pathsep + env.get('PYTHONPATH', '')

        cmd = [
            'python', script_path,
            self.data_path,
            self.busco_path,
            '--output-table', self.output_table_path,
            '--output-json', self.output_json_path,
            '--max-iter', '2',  # Keep it fast
        ]

        result = subprocess.run(cmd, capture_output=True, text=True, env=env)

        self.assertEqual(
            result.returncode, 0,
            f"Script failed with exit code {result.returncode}\n"
            f"stdout:\n{result.stdout}\nstderr:\n{result.stderr}"
        )

        # Check if output files were created
        self.assertTrue(os.path.exists(self.output_table_path), "Output table file was not created.")
        self.assertTrue(os.path.exists(self.output_json_path), "Output json file was not created.")

        # Check the content of the TSV output
        output_df = pd.read_csv(self.output_table_path, sep='\t')
        self.assertEqual(len(output_df), 4)  # baseline + 3 features
        expected_columns = [
            'feature_removed', 'num_features', 'oob_delta', 'iterations',
            'final_kept', 'final_discarded', 'final_kept_buscos',
            'final_discarded_buscos', 'final_oob_error'
        ]
        self.assertListEqual(list(output_df.columns), expected_columns)
        self.assertEqual(output_df.iloc[0]['feature_removed'], '(none)')

        # Check the content of the JSON output
        with open(self.output_json_path, 'r') as f:
            json_data = json.load(f)
        self.assertIn('runs', json_data)
        self.assertEqual(len(json_data['runs']), 4)
        self.assertEqual(json_data['runs'][0]['feature_removed'], '(none)')


if __name__ == '__main__':
    unittest.main()
