11"""
2- Functions to augment data. All functions assume that the input is a numpy array containing an integer
3- encoded DNA sequence of shape (L,) or a numpy array containing a label of shape (T, L).
4- The augmented output will be in the same format.
2+ `grelu.data.augment` contains functions to augment genomic sequences or functional genomic data.
3+
4+ All functions assume that the input is either:
5+
6+ (1) a 1-D numpy array containing an integer encoded DNA sequence of shape (length,) or;
7+ (2) a 2-D numpy array containing a label of shape (tasks, length).
8+
9+ The augmented output must be returned in the same format. All augmentation functions also
10+ require an index (idx) which is an integer or boolean value.
11+
12+ This module also contains the `Augmenter` class which is responsible for applying multiple
13+ augmentations to a given DNA sequence or (sequence, label) pair.
514"""
615
716import warnings
1221from grelu .sequence .mutate import random_mutate
1322from grelu .sequence .utils import reverse_complement
1423
15- # This is the number of output sequences expected from each type of augmentation
24+ # This is the number of output sequences expected from a single input sequence using each type of augmentation
1625AUGMENTATION_MULTIPLIER_FUNCS = {
1726 "rc" : lambda x : 2 ** x ,
1827 "max_seq_shift" : lambda x : (2 * x ) + 1 ,
@@ -197,7 +206,7 @@ def __call__(
197206 else :
198207 raise NotImplementedError
199208
200- # Augment the sequence
209+ # Apply all sequence augmentation functions here
201210
202211 # Shift sequence
203212 if self .shift_seq :
@@ -224,7 +233,7 @@ def __call__(
224233 return seq
225234
226235 else :
227- # Augment the label too
236+ # Apply all label augmentation functions here
228237 if self .shift_label :
229238 # Shift label
230239 label = shift (label , seq_len = self .label_len , idx = pair_shift_idx )
0 commit comments