Commit c21bc98

ENH: lazy loading is added.
1 parent 290e39f commit c21bc98

2 files changed: +118 −16 lines


barcodebert/datasets.py

Lines changed: 84 additions & 2 deletions
@@ -11,6 +11,7 @@
 from torch.utils.data import Dataset
 from torchtext.vocab import vocab as build_vocab_from_dict
 from transformers import AutoTokenizer
+from torch.utils.data import IterableDataset


 class KmerTokenizer(object):
@@ -85,6 +86,83 @@ def __call__(self, dna_sequence, offset=0) -> tuple[list, list]:
         return tokens, att_mask


+class LazyDNADataset(IterableDataset):
+    def __init__(
+        self,
+        file_path,
+        k_mer=4,
+        stride=None,
+        max_len=256,
+        randomize_offset=False,
+        tokenizer="kmer",
+        bpe_path=None,
+        tokenize_n_nucleotide=False,
+        dataset_format="CANADA-1.5M",
+    ):
+        self.file_path = file_path
+        self.k_mer = k_mer
+        self.stride = k_mer if stride is None else stride
+        self.max_len = max_len
+        self.randomize_offset = randomize_offset
+        self.dataset_format = dataset_format
+
+        if dataset_format not in ["CANADA-1.5M", "BIOSCAN-5M", "DNABERT-2"]:
+            raise NotImplementedError(f"Dataset {dataset_format} not supported.")
+
+        if tokenizer == "kmer":
+            base_pairs = "ACGT"
+            self.special_tokens = ["[MASK]", "[UNK]"]
+            UNK_TOKEN = "[UNK]"
+
+            if tokenize_n_nucleotide:
+                base_pairs += "N"
+            kmers = ["".join(kmer) for kmer in product(base_pairs, repeat=self.k_mer)]
+
+            if tokenize_n_nucleotide:
+                prediction_kmers = [k for k in kmers if "N" not in k]
+                other_kmers = [k for k in kmers if "N" in k]
+                kmers = prediction_kmers + other_kmers
+
+            kmer_dict = dict.fromkeys(kmers, 1)
+            self.vocab = build_vocab_from_dict(kmer_dict, specials=self.special_tokens)
+            self.vocab.set_default_index(self.vocab[UNK_TOKEN])
+            self.vocab_size = len(self.vocab)
+
+            self.tokenizer = KmerTokenizer(
+                self.k_mer, self.vocab, stride=self.stride, padding=True, max_len=self.max_len
+            )
+        elif tokenizer == "bpe":
+            self.tokenizer = BPETokenizer(padding=True, max_tokenized_len=self.max_len, bpe_path=bpe_path)
+            self.vocab_size = self.tokenizer.bpe.vocab_size
+        else:
+            raise ValueError(f'Tokenizer "{tokenizer}" not recognized.')
+
+    def parse_row(self, row):
+        dna_seq = row["nucleotides"]
+        if self.dataset_format == "CANADA-1.5M":
+            label = row["species_name"]
+        elif self.dataset_format == "BIOSCAN-5M":
+            label = row["species_index"]
+        elif self.dataset_format == "DNABERT-2":
+            label = 0  # dummy label
+        else:
+            raise NotImplementedError
+
+        offset = torch.randint(self.k_mer, (1,)).item() if self.randomize_offset else 0
+        tokens, att_mask = self.tokenizer(dna_seq, offset=offset)
+        return tokens, torch.tensor(int(label), dtype=torch.int64), att_mask
+
+    def __iter__(self):
+        df_iter = pd.read_csv(
+            self.file_path,
+            sep="\t" if self.file_path.endswith(".tsv") else ",",
+            chunksize=1,
+            keep_default_na=False,
+        )
+        for chunk in df_iter:
+            yield self.parse_row(chunk.iloc[0])
+
+
 class DNADataset(Dataset):
     def __init__(
         self,
@@ -104,7 +182,7 @@ def __init__(
         self.randomize_offset = randomize_offset

         # Check that the dataframe contains a valid format
-        if dataset_format not in ["CANADA-1.5M", "BIOSCAN-5M"]:
+        if dataset_format not in ["CANADA-1.5M", "BIOSCAN-5M", "DNABERT-2"]:
             raise NotImplementedError(f"Dataset {dataset_format} not supported.")

         if tokenizer == "kmer":
@@ -148,10 +226,14 @@ def __init__(
         if dataset_format == "CANADA-1.5M":
             self.labels, self.label_set = pd.factorize(df["species_name"], sort=True)
             self.num_labels = len(self.label_set)
-        else:
+        elif dataset_format == "BIOSCAN-5M":
             self.label_names = df["species_name"].to_list()
             self.labels = df["species_index"].to_list()
             self.num_labels = 22_622
+        elif dataset_format == "DNABERT-2":
+            # this is just dummy labels for the DNABERT-S_2M dataset
+            self.labels = np.zeros(len(self.barcodes), dtype=np.int64)
+            self.num_labels = 1

     def __len__(self):
         return len(self.barcodes)
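
As background on the pattern LazyDNADataset relies on, here is a minimal, self-contained sketch of streaming rows from a CSV with pandas chunks inside a torch IterableDataset. The class name and file name below are illustrative placeholders, not part of this commit.

import pandas as pd
from torch.utils.data import DataLoader, IterableDataset


class StreamingCSVDataset(IterableDataset):
    """Yields one parsed row at a time instead of loading the whole file."""

    def __init__(self, file_path):
        self.file_path = file_path

    def __iter__(self):
        # chunksize=1 turns read_csv into an iterator of single-row DataFrames,
        # so memory use stays roughly constant no matter how large the file is.
        for chunk in pd.read_csv(self.file_path, chunksize=1, keep_default_na=False):
            yield chunk.iloc[0]["nucleotides"]  # tokenization would happen here


# Usage sketch: an IterableDataset has no __len__ or random access, so the
# DataLoader must be built without shuffle=True or a custom sampler.
# loader = DataLoader(StreamingCSVDataset("pre_training.csv"), batch_size=32)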

barcodebert/pretraining.py

Lines changed: 34 additions & 14 deletions
@@ -17,7 +17,7 @@
 from transformers import BertConfig, BertForTokenClassification

 from barcodebert import levenshtein, utils
-from barcodebert.datasets import DNADataset
+from barcodebert.datasets import DNADataset, LazyDNADataset
 from barcodebert.io import safe_save_model

 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
@@ -138,7 +138,7 @@ def print_pass(*args, **kwargs):

     # DATASET =================================================================

-    if config.dataset_name not in ["CANADA-1.5M", "BIOSCAN-5M"]:
+    if config.dataset_name not in ["CANADA-1.5M", "BIOSCAN-5M", "DNABERT-2"]:
         raise NotImplementedError(f"Dataset {config.dataset_name} not supported.")

     # Handle default stride dynamically set to equal k-mer size
@@ -155,16 +155,30 @@ def print_pass(*args, **kwargs):
         "dataset_format": config.dataset_name,
     }

-    dataset_train = DNADataset(
-        file_path=os.path.join(config.data_dir, "pre_training.csv"),
-        randomize_offset=True,
-        **dataset_args,
-    )
-    dataset_val = DNADataset(
-        file_path=os.path.join(config.data_dir, "supervised_train.csv"),
-        randomize_offset=False,
-        **dataset_args,
-    )
+    if config.lazy_load:
+        dataset_train = LazyDNADataset(
+            file_path=os.path.join(config.data_dir, "pre_training.csv"),
+            randomize_offset=True,
+            **dataset_args,
+        )
+        dataset_val = LazyDNADataset(
+            file_path=os.path.join(config.data_dir, "supervised_train.csv"),
+            randomize_offset=False,
+            **dataset_args,
+        )
+
+    else:
+        dataset_train = DNADataset(
+            file_path=os.path.join(config.data_dir, "pre_training.csv"),
+            randomize_offset=True,
+            **dataset_args,
+        )
+        dataset_val = DNADataset(
+            file_path=os.path.join(config.data_dir, "supervised_train.csv"),
+            randomize_offset=False,
+            **dataset_args,
+        )
+
     eval_set = "Val"

     # Dataloader --------------------------------------------------------------
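
A practical caveat, offered as a hedged sketch rather than anything in this commit: when config.lazy_load routes training through the IterableDataset path, the DataLoader cannot shuffle, and with num_workers > 0 each worker would replay the entire file unless the stream is sharded. One common remedy, assuming the LazyDNADataset interface above, is to skip rows by worker id inside __iter__:

import itertools

from torch.utils.data import get_worker_info


def _shard_for_worker(row_iter):
    # Hypothetical helper: give each DataLoader worker every num_workers-th row
    # so samples are not duplicated across workers; with num_workers=0
    # (worker_info is None) the full stream is kept.
    worker_info = get_worker_info()
    if worker_info is None:
        return row_iter
    return itertools.islice(row_iter, worker_info.id, None, worker_info.num_workers)

LazyDNADataset.__iter__ could wrap its pandas chunk iterator with _shard_for_worker before yielding parsed rows; the helper name here is illustrative only.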
@@ -214,7 +228,13 @@ def print_pass(*args, **kwargs):
         base_pairs += "N"

     if config.tokenizer == "kmer":
-        max_position_embeddings = max(512, math.ceil(1536 / config.stride))
+        if config.dataset_name == "DNABERT-2":
+            # DNABERT-2 uses a fixed 512-length sequence
+            max_position_embeddings = max(512, math.ceil(1000 / config.stride))
+        else:
+            # for barcodes
+            max_position_embeddings = max(512, math.ceil(1536 / config.stride))
+
         n_output_tokens = len(base_pairs) ** config.k_mer
         n_special_tokens = len(dataset_train.special_tokens)
         n_all_tokens = n_output_tokens + n_special_tokens
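
As a quick worked example of the formula above (assuming the default stride equal to the k-mer size of 4): the DNABERT-2 branch gives max(512, ceil(1000 / 4)) = max(512, 250) = 512, and the barcode branch gives max(512, ceil(1536 / 4)) = max(512, 384) = 512. The second argument only dominates at small strides: the 1536 term exceeds 512 for strides of 1 or 2, and the 1000 term only for a stride of 1.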
@@ -781,7 +801,7 @@ def train_one_epoch(
         # Perform the forward pass through the model
         print(config.arch)
         if config.arch == "maelm":
-            print("MAELM is implemented")
+            # print("MAELM is implemented")
             out = model(masked_input, att_mask, masked_unseen_tokens, config.maelm_version)
         elif config.arch == "transformer":
             out = model(masked_input, attention_mask=att_mask)
