|
| 1 | +""" |
| 2 | +Written by Alex Tseng |
| 3 | +
|
| 4 | +https://gist.github.com/amtseng/010dd522daaabc92b014f075a34a0a0b |
| 5 | +""" |
| 6 | + |
| 7 | +import numpy as np |
| 8 | + |
| 9 | +def dna_to_one_hot(seqs): |
| 10 | + """ |
| 11 | + Converts a list of DNA ("ACGT") sequences to one-hot encodings, where the |
| 12 | + position of 1s is ordered alphabetically by "ACGT". `seqs` must be a list |
| 13 | + of N strings, where every string is the same length L. Returns an N x L x 4 |
| 14 | + NumPy array of one-hot encodings, in the same order as the input sequences. |
| 15 | + All bases will be converted to upper-case prior to performing the encoding. |
| 16 | + Any bases that are not "ACGT" will be given an encoding of all 0s. |
| 17 | + """ |
| 18 | + seq_len = len(seqs[0]) |
| 19 | + assert np.all(np.array([len(s) for s in seqs]) == seq_len) |
| 20 | + |
| 21 | + # Join all sequences together into one long string, all uppercase |
| 22 | + seq_concat = "".join(seqs).upper() + "ACGT" |
| 23 | + # Add one example of each base, so np.unique doesn't miss indices later |
| 24 | + |
| 25 | + one_hot_map = np.identity(5)[:, :-1].astype(np.int8) |
| 26 | + |
| 27 | + # Convert string into array of ASCII character codes; |
| 28 | + base_vals = np.frombuffer(bytearray(seq_concat, "utf8"), dtype=np.int8) |
| 29 | + |
| 30 | + # Anything that's not an A, C, G, or T gets assigned a higher code |
| 31 | + base_vals[~np.isin(base_vals, np.array([65, 67, 71, 84]))] = 85 |
| 32 | + |
| 33 | + # Convert the codes into indices in [0, 4], in ascending order by code |
| 34 | + _, base_inds = np.unique(base_vals, return_inverse=True) |
| 35 | + |
| 36 | + # Get the one-hot encoding for those indices, and reshape back to separate |
| 37 | + return one_hot_map[base_inds[:-4]].reshape((len(seqs), seq_len, 4)) |
| 38 | + |
| 39 | + |
| 40 | +def one_hot_to_dna(one_hot): |
| 41 | + """ |
| 42 | + Converts a one-hot encoding into a list of DNA ("ACGT") sequences, where the |
| 43 | + position of 1s is ordered alphabetically by "ACGT". `one_hot` must be an |
| 44 | + N x L x 4 array of one-hot encodings. Returns a lits of N "ACGT" strings, |
| 45 | + each of length L, in the same order as the input array. The returned |
| 46 | + sequences will only consist of letters "A", "C", "G", "T", or "N" (all |
| 47 | + upper-case). Any encodings that are all 0s will be translated to "N". |
| 48 | + """ |
| 49 | + bases = np.array(["A", "C", "G", "T", "N"]) |
| 50 | + # Create N x L array of all 5s |
| 51 | + one_hot_inds = np.tile(one_hot.shape[2], one_hot.shape[:2]) |
| 52 | + |
| 53 | + # Get indices of where the 1s are |
| 54 | + batch_inds, seq_inds, base_inds = np.where(one_hot) |
| 55 | + |
| 56 | + # In each of the locations in the N x L array, fill in the location of the 1 |
| 57 | + one_hot_inds[batch_inds, seq_inds] = base_inds |
| 58 | + |
| 59 | + # Fetch the corresponding base for each position using indexing |
| 60 | + seq_array = bases[one_hot_inds] |
| 61 | + return ["".join(seq) for seq in seq_array] |
0 commit comments