Skip to content

Commit 263e080

Browse files
authored
Merge pull request #19 from bioscan-ml/test_ft_baselines
Camera Ready Version for Bioinformatics
2 parents 0f0ddf1 + bf07bb8 commit 263e080

16 files changed

+1061
-712
lines changed

README.md

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,13 @@ features = output.mean(1)
4040
pip install -e .
4141
```
4242

43-
1. Download the [data](https://vault.cs.uwaterloo.ca/s/x7gXQKnmRX3GAZm)
43+
1. Download the data from our Hugging Face Dataset [repository](https://huggingface.co/datasets/bioscan-ml/CanadianInvertebrates-ML)
44+
```shell
45+
cd data/
46+
python download_HF_CanInv.py
47+
```
48+
49+
**Optional**: You can also download the first version of the [data](https://vault.cs.uwaterloo.ca/s/x7gXQKnmRX3GAZm)
4450
```shell
4551
wget https://vault.cs.uwaterloo.ca/s/x7gXQKnmRX3GAZm/download -O data.zip
4652
unzip data.zip
@@ -49,22 +55,28 @@ rm -r new_data
4955
rm data.zip
5056
```
5157

52-
3. Pretrain BarcodeBERT
53-
54-
```bash
55-
python barcodebert/pretraining.py --dataset=CANADA-1.5M --k_mer=4 --n_layers=4 --n_heads=4 --data_dir=data/ --checkpoint=model_checkpoints/CANADA-1.5M/4_4_4/checkpoint_pretraining.pt
56-
```
57-
58-
4. Baseline model pipelines: The desired backbone can be selected using one of the following keywords:
58+
2. DNA foundation model baselines: The desired backbone can be selected using one of the following keywords:
5959
`BarcodeBERT, NT, Hyena_DNA, DNABERT, DNABERT-2, DNABERT-S`
6060
```bash
6161
python baselines/knn_probing.py --backbone=<DESIRED-BACKBONE> --data-dir=data/
6262
python baselines/linear_probing.py --backbone=<DESIRED-BACKBONE> --data-dir=data/
6363
python baselines/finetuning.py --backbone=<DESIRED-BACKBONE> --data-dir=data/ --batch_size=32
64+
python baselines/zsc.py --backbone=<DESIRED-BACKBONE> --data-dir=data/
6465
```
65-
**Note**: HyenaDNA has to be downloaded using `git-lfs`. If that is not available to you, you may download the `/hyenadna-tiny-1k-seqlen/` checkpoint directly from [Hugging face](https://huggingface.co/LongSafari/hyenadna-tiny-1k-seqlen/tree/main). The keyword `BarcodeBERT` is also available as a baseline but this will download the publicly available model as presented in our workshop paper.
66+
**Note**: The DNABERT model has to be downloaded manually following the instructions in the paper's [repo](https://github.com/jerryji1993/DNABERT) and placed in the `pretrained-models` folder.
6667

67-
5. BLAST
68+
3. Supervised CNN
69+
70+
```bash
71+
python baselines/cnn/1D_CNN_supervised.py
72+
python baselines/cnn/1D_CNN_KNN.py
73+
python baselines/cnn/1D_CNN_Linear_probing.py
74+
python baselines/cnn/1D_CNN_ZSC.py
75+
76+
```
77+
**Note**: Train the CNN backbone with `1D_CNN_supervised.py` before evaluating it on any downstream task.
78+
79+
4. BLAST
6880
```shell
6981
cd data/
7082
python to_fasta.py --input_file=supervised_train.csv &&
@@ -75,7 +87,17 @@ makeblastdb -in supervised_train.fas -title train -dbtype nucl -out train.fas
7587
blastn -query supervised_test.fas -db train.fas -out results_supervised_test.tsv -outfmt 6 -num_threads 16
7688
blastn -query unseen.fas -db train.fas -out results_unseen.tsv -outfmt 6 -num_threads 16
7789
```
78-
90+
### Pretrain BarcodeBERT
91+
To pretrain the model you can run the following command:
92+
```bash
93+
python barcodebert/pretraining.py \
94+
--dataset=CANADA-1.5M \
95+
--k_mer=4 \
96+
--n_layers=4 \
97+
--n_heads=4 \
98+
--data_dir=data/ \
99+
--checkpoint=model_checkpoints/CANADA-1.5M/4_4_4/checkpoint_pretraining.pt
100+
```
79101

80102
## Citation
81103

baselines/cnn/1D_CNN_KNN.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import argparse
2+
import sys
3+
4+
import numpy as np
5+
import pandas as pd
6+
import sklearn
7+
import sklearn.metrics
8+
import torch
9+
import wandb
10+
from sklearn.neighbors import KNeighborsClassifier
11+
12+
sys.path.append(".")
13+
from baselines.cnn.cnn_utils import CNNModel, data_from_df
14+
15+
16+
def run(config):
    """Evaluate a pretrained 1D-CNN as a feature extractor with 1-NN probing.

    Loads the supervised training CSV and the unseen test CSV, extracts
    embeddings from a saved CNN checkpoint, fits a 1-nearest-neighbour
    classifier (cosine metric) on the training embeddings, and reports
    accuracy/F1 metrics to stdout and wandb.

    Parameters
    ----------
    config : argparse.Namespace
        Must provide ``data_dir`` (folder containing the CSV splits) and
        ``target_level`` (``'genus'`` or ``'species'``).
    """
    data_folder = config.data_dir
    train = pd.read_csv(f"{data_folder}/supervised_train.csv")
    test = pd.read_csv(f"{data_folder}/unseen.csv")

    target_level = config.target_level + "_name"  # e.g. "species_name"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build the label pipeline from the training split. A dict gives O(1)
    # lookups instead of O(n) list.index() per label; labels absent from
    # the training split raise KeyError (previously ValueError).
    labels = train[target_level].to_list()
    label_set = sorted(set(labels))
    label_to_index = {label: i for i, label in enumerate(label_set)}

    def label_pipeline(x):
        return label_to_index[x]

    X, y_train = data_from_df(train, target_level, label_pipeline)
    X_test, y_test = data_from_df(test, target_level, label_pipeline)

    num_classes = max(y_train) + 1
    print(f"[INFO]: There are {num_classes} taxonomic groups")

    model = CNNModel(1, 1653).to(device)

    model_path = "model_checkpoints/CANADA1.5M_CNN.pth"
    print(f"Getting the model from: {model_path}")

    try:
        # map_location lets a GPU-saved checkpoint load on CPU-only machines.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)
        model.eval()
    except Exception as err:
        # Report the actual failure instead of swallowing it silently.
        print(f"There was a problem loading the model: {err}")
        return

    # USE MODEL AS FEATURE EXTRACTOR =================================================================
    # The model is assumed to return a tuple whose second element is the
    # embedding — TODO confirm against CNNModel.forward in cnn_utils.
    # Inputs are viewed as (batch, 1, 660, 5); presumably one-hot encoded
    # 660-bp barcodes — verify against data_from_df.
    dna_embeddings = []

    with torch.no_grad():
        for i in range(X_test.shape[0]):
            inputs = torch.tensor(X_test[i]).view(-1, 1, 660, 5).to(device)
            dna_embeddings.extend(model(inputs)[1].cpu().numpy())

    train_embeddings = []

    with torch.no_grad():
        for i in range(X.shape[0]):
            inputs = torch.tensor(X[i]).view(-1, 1, 660, 5).to(device)
            train_embeddings.extend(model(inputs)[1].cpu().numpy())

    # Embedding dimension is 500 for this checkpoint.
    X_test = np.array(dna_embeddings).reshape(-1, 500)
    print(X_test.shape)

    X = np.array(train_embeddings).reshape(-1, 500)

    # 1-NN with cosine distance on the extracted embeddings.
    neigh = KNeighborsClassifier(n_neighbors=1, metric="cosine")
    neigh.fit(X, y_train)
    print("Accuracy:", neigh.score(X_test, y_test))
    y_pred = neigh.predict(X_test)

    # Create results dictionary
    results = {}
    results["count"] = len(y_test)
    # Note that these evaluation metrics have all been converted to percentages
    results["accuracy"] = 100.0 * sklearn.metrics.accuracy_score(y_test, y_pred)
    results["accuracy-balanced"] = 100.0 * sklearn.metrics.balanced_accuracy_score(y_test, y_pred)
    results["f1-micro"] = 100.0 * sklearn.metrics.f1_score(y_test, y_pred, average="micro")
    results["f1-macro"] = 100.0 * sklearn.metrics.f1_score(y_test, y_pred, average="macro")
    results["f1-support"] = 100.0 * sklearn.metrics.f1_score(y_test, y_pred, average="weighted")

    wandb.log({f"eval/{k}": v for k, v in results.items()})

    print("Evaluation results:")
    for k, v in results.items():
        if k == "count":
            print(f"  {k + ' ':.<21s}{v:7d}")
        elif k in ["max_ram_mb", "peak_vram_mb"]:
            print(f"  {k + ' ':.<24s} {v:6.2f} MB")
        else:
            print(f"  {k + ' ':.<24s} {v:6.2f} %")
95+
96+
97+
if __name__ == "__main__":
    # Command-line interface for the CNN feature-extractor + kNN evaluation.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--data_dir",
        default="./data",
        help="Path to the folder containing the data in the desired CSV format",
    )
    arg_parser.add_argument(
        "--target_level",
        default="genus",
        help="Desired taxonomic rank, either 'genus' or 'species'",
    )
    config = arg_parser.parse_args()

    # Start experiment tracking before kicking off the evaluation run.
    wandb.init(project="BarcodeBERT", name="knn_CNN_CANADA-1.5M", config=vars(config))
    wandb.config.update(vars(config))  # record the CLI arguments
    run(config)

0 commit comments

Comments
 (0)