1111from torch .utils .data import Dataset
1212from torchtext .vocab import vocab as build_vocab_from_dict
1313from transformers import AutoTokenizer
14+ from torch .utils .data import IterableDataset
1415
1516
1617class KmerTokenizer (object ):
@@ -85,6 +86,83 @@ def __call__(self, dna_sequence, offset=0) -> tuple[list, list]:
8586 return tokens , att_mask
8687
8788
class LazyDNADataset(IterableDataset):
    """Stream DNA barcode samples from a CSV/TSV file without loading it all into memory.

    Each yielded sample is a tuple ``(tokens, label, att_mask)`` where
    ``tokens`` and ``att_mask`` come from the configured tokenizer and
    ``label`` is a scalar ``torch.int64`` tensor.
    """

    def __init__(
        self,
        file_path,
        k_mer=4,
        stride=None,
        max_len=256,
        randomize_offset=False,
        tokenizer="kmer",
        bpe_path=None,
        tokenize_n_nucleotide=False,
        dataset_format="CANADA-1.5M",
        chunksize=256,
    ):
        """
        Parameters
        ----------
        file_path : str
            Path to a ``.csv`` or ``.tsv`` file containing a ``nucleotides``
            column (plus a label column depending on ``dataset_format``).
        k_mer : int, default=4
            k-mer length for the ``"kmer"`` tokenizer.
        stride : int, optional
            Stride between successive k-mers; defaults to ``k_mer``
            (non-overlapping k-mers).
        max_len : int, default=256
            Maximum tokenized sequence length passed to the tokenizer.
        randomize_offset : bool, default=False
            If True, tokenization starts at a random offset in ``[0, k_mer)``.
        tokenizer : {"kmer", "bpe"}, default="kmer"
            Which tokenizer to construct.
        bpe_path : str, optional
            Checkpoint path for the BPE tokenizer (only used when
            ``tokenizer == "bpe"``).
        tokenize_n_nucleotide : bool, default=False
            Include ``"N"`` in the k-mer alphabet; otherwise k-mers
            containing ``N`` map to ``[UNK]`` via the vocab default index.
        dataset_format : {"CANADA-1.5M", "BIOSCAN-5M", "DNABERT-2"}
            Controls which column provides the label (see ``parse_row``).
        chunksize : int, default=256
            Number of rows parsed per ``pd.read_csv`` chunk while streaming.
            Larger chunks amortize the CSV-parser overhead; samples are
            yielded in the same order regardless of this value.

        Raises
        ------
        NotImplementedError
            If ``dataset_format`` is not one of the supported formats.
        ValueError
            If ``tokenizer`` is not ``"kmer"`` or ``"bpe"``.
        """
        self.file_path = file_path
        self.k_mer = k_mer
        self.stride = k_mer if stride is None else stride
        self.max_len = max_len
        self.randomize_offset = randomize_offset
        self.dataset_format = dataset_format
        self.chunksize = chunksize

        if dataset_format not in ["CANADA-1.5M", "BIOSCAN-5M", "DNABERT-2"]:
            raise NotImplementedError(f"Dataset {dataset_format} not supported.")

        if tokenizer == "kmer":
            base_pairs = "ACGT"
            self.special_tokens = ["[MASK]", "[UNK]"]
            UNK_TOKEN = "[UNK]"

            if tokenize_n_nucleotide:
                # Grow the alphabet so N-containing k-mers get real ids
                # instead of falling through to [UNK].
                base_pairs += "N"
            kmers = ["".join(kmer) for kmer in product(base_pairs, repeat=self.k_mer)]

            if tokenize_n_nucleotide:
                # Keep N-free k-mers first so their indices form a contiguous
                # block ahead of the N-containing k-mers.
                prediction_kmers = [k for k in kmers if "N" not in k]
                other_kmers = [k for k in kmers if "N" in k]
                kmers = prediction_kmers + other_kmers

            kmer_dict = dict.fromkeys(kmers, 1)
            self.vocab = build_vocab_from_dict(kmer_dict, specials=self.special_tokens)
            self.vocab.set_default_index(self.vocab[UNK_TOKEN])
            self.vocab_size = len(self.vocab)

            self.tokenizer = KmerTokenizer(
                self.k_mer, self.vocab, stride=self.stride, padding=True, max_len=self.max_len
            )
        elif tokenizer == "bpe":
            self.tokenizer = BPETokenizer(padding=True, max_tokenized_len=self.max_len, bpe_path=bpe_path)
            self.vocab_size = self.tokenizer.bpe.vocab_size
        else:
            raise ValueError(f'Tokenizer "{tokenizer}" not recognized.')

    def parse_row(self, row):
        """Tokenize one record and return ``(tokens, label_tensor, att_mask)``."""
        dna_seq = row["nucleotides"]
        if self.dataset_format == "CANADA-1.5M":
            # NOTE(review): the label is cast with int() below, so this path
            # expects species_name to already hold a numeric id — unlike
            # DNADataset, which factorizes the raw names. Confirm the input
            # file matches this expectation.
            label = row["species_name"]
        elif self.dataset_format == "BIOSCAN-5M":
            label = row["species_index"]
        elif self.dataset_format == "DNABERT-2":
            label = 0  # dummy label
        else:
            raise NotImplementedError

        offset = torch.randint(self.k_mer, (1,)).item() if self.randomize_offset else 0
        tokens, att_mask = self.tokenizer(dna_seq, offset=offset)
        return tokens, torch.tensor(int(label), dtype=torch.int64), att_mask

    def __iter__(self):
        # Stream the file in chunks of rows. chunksize=1 would pay the full
        # pd.read_csv dispatch overhead for every single row; larger chunks
        # yield identical samples in identical order, just faster. A
        # multi-char sep would be treated as a regex by pandas, so the TSV
        # separator must be exactly "\t".
        # NOTE(review): with a multi-worker DataLoader every worker iterates
        # the whole file — shard by worker id if duplicates matter.
        df_iter = pd.read_csv(
            self.file_path,
            sep="\t" if self.file_path.endswith(".tsv") else ",",
            chunksize=self.chunksize,
            keep_default_na=False,
        )
        for chunk in df_iter:
            for _, row in chunk.iterrows():
                yield self.parse_row(row)
164+
165+
88166class DNADataset (Dataset ):
89167 def __init__ (
90168 self ,
@@ -104,7 +182,7 @@ def __init__(
104182 self .randomize_offset = randomize_offset
105183
106184 # Check that the dataframe contains a valid format
107- if dataset_format not in ["CANADA-1.5M" , "BIOSCAN-5M" ]:
185+ if dataset_format not in ["CANADA-1.5M" , "BIOSCAN-5M" , "DNABERT-2" ]:
108186 raise NotImplementedError (f"Dataset { dataset_format } not supported." )
109187
110188 if tokenizer == "kmer" :
@@ -148,10 +226,14 @@ def __init__(
148226 if dataset_format == "CANADA-1.5M" :
149227 self .labels , self .label_set = pd .factorize (df ["species_name" ], sort = True )
150228 self .num_labels = len (self .label_set )
151- else :
229+ elif dataset_format == "BIOSCAN-5M" :
152230 self .label_names = df ["species_name" ].to_list ()
153231 self .labels = df ["species_index" ].to_list ()
154232 self .num_labels = 22_622
233+ elif dataset_format == "DNABERT-2" :
234+ # this is just dummy labels for the DNABERT-S_2M dataset
235+ self .labels = np .zeros (len (self .barcodes ), dtype = np .int64 )
236+ self .num_labels = 1
155237
156238 def __len__ (self ):
157239 return len (self .barcodes )
0 commit comments