|
| 1 | +from collections import defaultdict |
| 2 | + |
| 3 | + |
def filter_results(res_haps, extra_gl):
    """
    Filter res_haps in place, keeping only haplotype pairs consistent with extra_gl.

    res_haps: dict with parallel lists, e.g.
        {
         'Haps': [['A*01:01~B*08:01~...', 'A*33:03~B*44:03~...'], ...],
         'Probs': [1.23e-16, ...],
         'Pops': [['CAU', 'CAU'], ...]
        }
    extra_gl: GL string of extra loci, e.g. 'C*07:01+C*07:01^DQB1*02:01+DQB1*02:01'
        ('^' separates loci, '+' separates the two genotype sides, '/' lists
        allele ambiguities within a side).

    A pair is kept when, for every locus present in extra_gl, its two alleles
    match the two genotype sides in either orientation.
    Returns res_haps (mutated), or a fresh empty result dict when nothing survives.
    """
    # allowed[locus] = [set of side-1 alleles, set of side-2 alleles]
    allowed = {}
    for locus_gl in extra_gl.split("^"):
        sides = locus_gl.split("+")
        allowed[locus_gl.split("*")[0]] = [set(sides[0].split("/")),
                                           set(sides[1].split("/"))]

    def _consistent(hap_a, hap_b):
        # Walk the two haplotypes locus by locus; loci absent from extra_gl
        # are unconstrained.
        for al_a, al_b in zip(hap_a.split("~"), hap_b.split("~")):
            locus = al_a.split("*")[0]
            if locus not in allowed:
                continue
            first, second = allowed[locus]
            straight = al_a in first and al_b in second
            crossed = al_a in second and al_b in first
            if not (straight or crossed):
                return False
        return True

    keep = [i for i, pair in enumerate(res_haps["Haps"])
            if _consistent(pair[0], pair[1])]

    res_haps["Haps"] = [res_haps["Haps"][i] for i in keep]
    res_haps["Probs"] = [res_haps["Probs"][i] for i in keep]
    res_haps["Pops"] = [res_haps["Pops"][i] for i in keep]

    if not res_haps["Probs"]:
        return {"Haps": [], "Probs": [], "Pops": []}
    return res_haps
| 53 | + |
| 54 | + |
def create_subject_dict(file_path):
    """Group the lines of a CSV file by subject id.

    Each non-empty line is expected to start with the subject id followed by a
    comma; the full (stripped) line is appended to that subject's list, so one
    subject may accumulate several rows.

    :param file_path: path to the CSV file to read.
    :return: dict mapping subject_id -> list of raw CSV lines (insertion order).
    """
    # defaultdict (already imported at file top) removes the manual
    # "if key not in dict" grouping boilerplate.
    subject_dict = defaultdict(list)

    with open(file_path, 'r') as file:
        for raw_line in file:
            line = raw_line.strip()
            if not line:
                continue  # skip blank lines
            subject_id = line.split(',', 1)[0]
            subject_dict[subject_id].append(line)

    # Return a plain dict so callers don't get auto-vivification on lookup.
    return dict(subject_dict)
| 73 | + |
def create_haps(path_pmug):
    """Read the pmug CSV and build the per-subject haplotype results.

    Each CSV row is 'subject_id,hap1;pop1,hap2;pop2,prob'. Rows are grouped by
    subject via create_subject_dict, then folded into one res_haps dict per
    subject with parallel 'Haps'/'Pops'/'Probs' lists.

    :param path_pmug: path to the pmug CSV file.
    :return: {'subject_id': [id, ...], 'res_haps': [res_haps_dict, ...]}
    """
    subject_dict = create_subject_dict(path_pmug)
    all_haps = {"subject_id": [], "res_haps": []}

    # Iterate items() directly: the old enumerate index was unused and the
    # loop variable 'id' shadowed the builtin.
    for subject_id, rows in subject_dict.items():
        res_haps = {"Haps": [], "Probs": [], "Pops": []}
        for row in rows:
            fields = row.split(',')
            # fields[1]/fields[2] are 'haplotype;population' pairs.
            hap_pop_1 = fields[1].split(';')
            hap_pop_2 = fields[2].split(';')

            res_haps["Haps"].append([hap_pop_1[0], hap_pop_2[0]])
            res_haps["Pops"].append([hap_pop_1[1], hap_pop_2[1]])
            res_haps["Probs"].append(float(fields[3]))

        all_haps["subject_id"].append(subject_id)
        all_haps["res_haps"].append(res_haps)

    return all_haps
| 97 | + |
def is_subarray_unordered(large_array, small_array):
    """Return True when every element of small_array occurs in large_array,
    ignoring order and multiplicity."""
    return set(small_array).issubset(large_array)
| 105 | + |
def write_best_hap_race_pairs(name_gl, haps, pops, probs, fout, numOfReasults):
    """Write the top-probability haplotype/population pairs for one subject.

    Each output line is 'name_gl,hap1;pop1,hap2;pop2,prob,rank', sorted by
    probability descending, capped at numOfReasults lines.
    """
    # Pair each probability with its formatted 'hap;pop,hap;pop' string.
    ranked = [
        [prob, haps[i][0] + ";" + pops[i][0] + "," + haps[i][1] + ";" + pops[i][1]]
        for i, prob in enumerate(probs)
    ]
    ranked.sort(key=lambda entry: entry[0], reverse=True)

    for rank in range(min(numOfReasults, len(ranked))):
        prob, pair = ranked[rank][0], ranked[rank][1]
        fout.write(f"{name_gl},{pair},{prob},{rank}\n")
def write_best_prob(name_gl, res, probs, fout, number_of_pop_results, sign=","):
    """Sum probabilities per unordered pair and write the top results.

    res is a list of 2-element pairs (e.g. [pop1, pop2]); probs is the parallel
    probability list. 'a<sign>b' and 'b<sign>a' are treated as the same pair:
    whichever orientation is seen first becomes the stored key. Output lines
    are 'name_gl,key,summed_prob,rank' sorted by summed probability
    descending, capped at number_of_pop_results lines.
    """
    # BUGFIX: this was defaultdict(list), but the dict holds scalar floats and
    # the explicit membership checks never triggered the (wrong) list factory.
    # A plain dict states the intent correctly.
    summed = {}
    for k, pair in enumerate(res):
        key = pair[0] + sign + pair[1]
        reversed_key = pair[1] + sign + pair[0]
        if key in summed:
            summed[key] += probs[k]
        elif reversed_key in summed:
            # Same unordered pair seen in the other orientation.
            summed[reversed_key] += probs[k]
        else:
            summed[key] = probs[k]

    # Stable sort keeps first-seen order among equal probabilities.
    ranked = sorted(summed.items(), key=lambda item: item[1], reverse=True)

    for rank in range(min(len(ranked), number_of_pop_results)):
        key, total = ranked[rank]
        fout.write(
            name_gl
            + ","
            + str(key)
            + ","
            + str(total)
            + ","
            + str(rank)
            + "\n"
        )
| 162 | + |
def write_umug(id, res_haps, fout, numOfResults):
    """Collapse haplotype pairs into multi-locus unphased genotypes and write the top-k.

    For each pair, the two alleles at every locus position are sorted and
    joined with '+', loci joined with '^'; probabilities of pairs collapsing
    to the same genotype are summed. Output lines are
    'id,genotype,summed_prob,rank' sorted by probability descending.
    """
    genotype_probs = {}
    for i in range(len(res_haps["Haps"])):
        loci_a = res_haps["Haps"][i][0].split('~')
        loci_b = res_haps["Haps"][i][1].split('~')

        # Build 'allele+allele' per locus position, alleles sorted so the
        # genotype string is orientation-independent.
        locus_genotypes = []
        for j in range(len(loci_a)):
            low, high = sorted([loci_a[j], loci_b[j]])
            locus_genotypes.append(low + "+" + high)
        muug = "^".join(locus_genotypes)

        genotype_probs[muug] = genotype_probs.get(muug, 0) + res_haps["Probs"][i]

    ranked = sorted(genotype_probs.items(), key=lambda kv: kv[1], reverse=True)
    for rank in range(min(numOfResults, len(ranked))):
        fout.write(
            id
            + ","
            + str(ranked[rank][0])
            + ","
            + str(ranked[rank][1])
            + ","
            + str(rank)
            + "\n"
        )
| 197 | + |
def write_umug_pops(id, res_haps, fout, numOfResults):
    """Sum probabilities per unordered population pair and write the top-k.

    Each pair's two populations are sorted and joined with ',' so
    ('CAU','HIS') and ('HIS','CAU') share one key. Output lines are
    'id,pop1,pop2,summed_prob,rank' sorted by probability descending.
    """
    pop_probs = {}
    for i in range(len(res_haps["Haps"])):
        first, second = sorted([res_haps["Pops"][i][0], res_haps["Pops"][i][1]])
        key = first + ',' + second
        pop_probs[key] = pop_probs.get(key, 0) + res_haps["Probs"][i]

    ranked = sorted(pop_probs.items(), key=lambda kv: kv[1], reverse=True)
    for rank in range(min(numOfResults, len(ranked))):
        fout.write(
            id
            + ","
            + str(ranked[rank][0])
            + ","
            + str(ranked[rank][1])
            + ","
            + str(rank)
            + "\n"
        )
| 226 | + |
def write_filter(subject_id, res_haps, fout_hap_haplo, fout_pop_haplo, fout_hap_muug, fout_pop_muug, number_of_results, number_of_pop_results, MUUG_output, haps_output):
    """Dispatch one subject's (already filtered) results to the output files.

    haps_output controls the haplotype-pair files; MUUG_output controls the
    unphased-genotype files. File handles for disabled outputs are never used.
    """
    haps = res_haps["Haps"]
    probs = res_haps["Probs"]
    pops = res_haps["Pops"]

    if haps_output:
        write_best_hap_race_pairs(subject_id, haps, pops, probs, fout_hap_haplo, number_of_results)
        # NOTE(review): the population output for haplotypes is capped at 1
        # result rather than number_of_pop_results — looks deliberate, but
        # confirm with the original author.
        write_best_prob(subject_id, pops, probs, fout_pop_haplo, 1)

    if MUUG_output:
        write_umug(subject_id, res_haps, fout_hap_muug, number_of_results)
        write_umug_pops(subject_id, res_haps, fout_pop_muug, number_of_pop_results)
| 244 | + |
| 245 | + |
def change_output_by_extra_gl(config, gls, path_pmug, path_umug, path_umug_pops, path_pmug_pops, path_miss):
    """Re-filter the pmug results by each subject's extra GL and rewrite the outputs.

    Reads the haplotype results from path_pmug, aligns them with the
    subjects/extra GLs in gls, filters each subject's results with
    filter_results, and rewrites the four output files (plus an appended
    'miss' file for subjects whose results were entirely filtered out).

    config keys used: output_MUUG, output_haplotypes, number_of_results,
    number_of_pop_results.
    """
    res_haps = create_haps(path_pmug)
    all_data = {"subject_id": [], "res_haps": [], "extra_gl": [], "short_gl": []}

    # Every id coming out of the pmug file must be known to gls, otherwise
    # the extra/short GL lookup below would fail.
    if is_subarray_unordered(gls["subject_id"], res_haps["subject_id"]):
        for idx, subject_id in enumerate(res_haps["subject_id"]):
            all_data["subject_id"].append(subject_id)
            all_data["res_haps"].append(res_haps["res_haps"][idx])
            gl_idx = gls["subject_id"].index(subject_id)
            all_data["extra_gl"].append(gls["extra_gl"][gl_idx])
            all_data["short_gl"].append(gls["short_gl"][gl_idx])
    else:
        # Original behavior: report and continue with empty all_data, which
        # still truncates the output files below.
        print("error we got umug has ids that are not from the gls")

    MUUG_output = config["output_MUUG"]
    haps_output = config["output_haplotypes"]
    number_of_results = config["number_of_results"]
    number_of_pop_results = config["number_of_pop_results"]

    fout_hap_haplo = fout_pop_haplo = fout_hap_muug = fout_pop_muug = None

    # BUGFIX: track every opened handle and close them in a finally block so
    # an exception mid-loop no longer leaks file descriptors.
    open_files = []
    try:
        if haps_output:
            fout_hap_haplo = open(path_pmug, "w")
            fout_pop_haplo = open(path_pmug_pops, "w")
            open_files += [fout_hap_haplo, fout_pop_haplo]
        if MUUG_output:
            fout_hap_muug = open(path_umug, "w")
            fout_pop_muug = open(path_umug_pops, "w")
            open_files += [fout_hap_muug, fout_pop_muug]
        miss = open(path_miss, "a")  # append: keep misses from earlier runs
        open_files.append(miss)

        for idx, subject_id in enumerate(all_data["subject_id"]):
            res_haps = all_data["res_haps"][idx]
            extra_gl = all_data["extra_gl"][idx]

            if len(extra_gl) > 0:
                res_haps = filter_results(res_haps, extra_gl)

            if len(res_haps["Haps"]) == 0:
                # Everything was filtered out: record the subject as a miss.
                gl_idx = gls["subject_id"].index(subject_id)
                miss.write(str(gl_idx) + "," + str(subject_id) + "\n")
            else:
                write_filter(subject_id, res_haps, fout_hap_haplo, fout_pop_haplo,
                             fout_hap_muug, fout_pop_muug, number_of_results,
                             number_of_pop_results, MUUG_output, haps_output)
    finally:
        for handle in open_files:
            handle.close()
0 commit comments