@@ -98,19 +98,24 @@ def ptm_func(x,d0):
9898ptm_func_vec = np .vectorize (ptm_func ) # vector version
9999
# Define the d0 functions for numbers and arrays; from Yang and Skolnick, PROTEINS: Structure, Function, and Bioinformatics 57:702–710 (2004)
# d0 is never allowed below 1.0, or below 2.0 when pair_type is 'nucleic_acid'
def calc_d0(L, pair_type='protein'):
    """Return the d0 normalization distance for a scalar length L.

    L is converted to float and raised to at least 27 before evaluating
    d0 = 1.24 * (L - 15)**(1/3) - 1.8.  The result is then floored at
    2.0 when pair_type == 'nucleic_acid' and at 1.0 for any other
    pair_type.

    pair_type defaults to 'protein' so that one-argument calls
    (the signature used before pair types were introduced) keep working.
    """
    L = max(float(L), 27.0)
    min_value = 2.0 if pair_type == 'nucleic_acid' else 1.0
    d0 = 1.24 * (L - 15) ** (1.0 / 3.0) - 1.8
    return max(min_value, d0)
105108
def calc_d0_array(L, pair_type='protein'):
    """Vectorized counterpart of calc_d0: return d0 for every entry of L.

    L may be a scalar, list or NumPy array; it is converted to a float
    array.  Each value is raised to at least 27 before evaluating
    d0 = 1.24 * (L - 15)**(1/3) - 1.8, and each result is floored at
    2.0 when pair_type == 'nucleic_acid' and at 1.0 otherwise.

    pair_type defaults to 'protein' so that one-argument calls
    (the signature used before pair types were introduced) keep working.
    """
    # Convert L to a NumPy array if it isn't already one (enables flexibility in input types)
    L = np.array(L, dtype=float)
    # Clamp lengths from below so the cube root is taken of a positive number
    L = np.maximum(27, L)
    min_value = 2.0 if pair_type == 'nucleic_acid' else 1.0

    # Calculate d0 using the vectorized operation
    return np.maximum(min_value, 1.24 * (L - 15) ** (1.0 / 3.0) - 1.8)
114119
115120
116121# Define the parse_atom_line function for PDB lines (by column) and mmCIF lines (split by white_space)
@@ -267,6 +272,37 @@ def format_range(start, end):
267272 string = '+' .join (ranges )
268273 return (string )
269274
# Initializes a nested dictionary with all values set to 0
def init_chainpairdict_zeros(chainlist):
    """Return {chain1: {chain2: 0}} for every ordered pair of distinct chains.

    chainlist is iterated once per outer key, so it must be re-iterable
    (list, tuple, array), not a one-shot iterator.
    """
    return {
        chain1: {chain2: 0 for chain2 in chainlist if chain1 != chain2}
        for chain1 in chainlist
    }
278+
# Initializes a nested dictionary with NumPy arrays of zeros of a specified size
def init_chainpairdict_npzeros(chainlist, arraysize):
    """Return {chain1: {chain2: np.zeros(arraysize)}} for every ordered pair of distinct chains.

    Each pair gets its own freshly allocated array, so writing to one
    entry does not affect any other.  chainlist must be re-iterable.
    """
    return {
        chain1: {chain2: np.zeros(arraysize) for chain2 in chainlist if chain1 != chain2}
        for chain1 in chainlist
    }
282+
# Initializes a nested dictionary with empty sets.
def init_chainpairdict_set(chainlist):
    """Return {chain1: {chain2: set()}} for every ordered pair of distinct chains.

    Each pair gets its own empty set object.  chainlist must be re-iterable.
    """
    return {
        chain1: {chain2: set() for chain2 in chainlist if chain1 != chain2}
        for chain1 in chainlist
    }
286+
287+
def classify_chains(chains, residue_types):
    """Label each chain as 'nucleic_acid' or 'protein'.

    chains and residue_types are parallel per-residue sequences: entry i
    gives the chain ID and residue name of residue i.  Both are coerced
    to NumPy arrays, so plain lists are accepted as well as arrays.

    A chain is labelled 'nucleic_acid' as soon as one of its residues has
    a name in the DNA/RNA set below; every other chain is 'protein'.

    Returns a dict keyed by the unique chain IDs (as produced by np.unique).
    """
    nuc_residue_set = {"DA", "DC", "DT", "DG", "A", "C", "U", "G"}
    chains = np.asarray(chains)
    residue_types = np.asarray(residue_types)
    chain_types = {}

    # Get unique chains and iterate over them
    for chain in np.unique(chains):
        # Residue names belonging to the current chain
        chain_residues = residue_types[chains == chain]
        # A single nucleotide residue is enough to call the chain a nucleic acid
        is_nucleic = any(residue in nuc_residue_set for residue in chain_residues)
        chain_types[chain] = 'nucleic_acid' if is_nucleic else 'protein'

    return chain_types
270306
271307
272308# Load residues from AlphaFold PDB or mmCIF file into lists; each residue is a dictionary
@@ -284,7 +320,10 @@ def format_range(start, end):
# DNA (DA/DC/DT/DG) and RNA (A/C/U/G) residue names
nuc_residue_set = {"DA", "DC", "DT", "DG", "A", "C", "U", "G"}

# Standard amino-acid residue names plus the nucleotide names above
residue_set = {"ALA", "ARG", "ASN", "ASP", "CYS",
               "GLN", "GLU", "GLY", "HIS", "ILE",
               "LEU", "LYS", "MET", "PHE", "PRO",
               "SER", "THR", "TRP", "TYR", "VAL"} | nuc_residue_set
288327
289328with open (pdb_path , 'r' ) as PDB :
290329 for line in PDB :
@@ -300,12 +339,11 @@ def format_range(start, end):
300339 atom = parse_cif_atom_line (line , atomsitefield_dict )
301340 else :
302341 atom = parse_pdb_atom_line (line )
303-
304342 if atom is None : # ligand atom
305343 token_mask .append (0 )
306344 continue
307345
308- if atom ['atom_name' ] == "CA" :
346+ if atom ['atom_name' ] == "CA" or "C1" in atom [ 'atom_name' ] :
309347 token_mask .append (1 )
310348 residues .append ({
311349 'atom_num' : atom ['atom_num' ],
@@ -317,7 +355,7 @@ def format_range(start, end):
317355 })
318356 chains .append (atom ['chain_id' ])
319357
320- if atom ['atom_name' ] == "CB" or (atom ['residue_name' ]== "GLY" and atom ['atom_name' ]== "CA" ):
358+ if atom ['atom_name' ] == "CB" or "C3" in atom [ 'atom_name' ] or (atom ['residue_name' ]== "GLY" and atom ['atom_name' ]== "CA" ):
321359 cb_residues .append ({
322360 'atom_num' : atom ['atom_num' ],
323361 'coor' : np .array ([atom ['x' ], atom ['y' ], atom ['z' ]]),
@@ -328,9 +366,9 @@ def format_range(start, end):
328366 })
329367
330368 # add nucleic acids and non-CA atoms in PTM residues to tokens (as 0), whether labeled as "HETATM" (af3) or as "ATOM" (boltz1)
331- if atom ['atom_name' ] != "CA" and atom ['residue_name' ] not in residue_set :
369+ if atom ['atom_name' ] != "CA" and "C1" not in atom [ 'atom_name' ] and atom ['residue_name' ] not in residue_set :
332370 token_mask .append (0 )
333-
371+
334372# Convert structure information to numpy arrays
335373numres = len (residues )
336374CA_atom_num = np .array ([res ['atom_num' ]- 1 for res in residues ]) # for AF3 atom indexing from 0
@@ -340,7 +378,22 @@ def format_range(start, end):
340378unique_chains = np .unique (chains )
341379token_array = np .array (token_mask )
342380ntokens = np .sum (token_array )
381+ residue_types = np .array ([res ['res' ] for res in residues ])
343382
# chain types (nucleic acid (NA) or protein) and chain_pair_types ('nucleic_acid' if either chain is NA) for d0 calculation
# arbitrarily setting d0 to 2.0 for NA/protein or NA/NA chain pairs (approximately 21 base pairs)
# NOTE(review): d0_nucleic_acid is not read anywhere in the visible code; calc_d0 hard-codes the same 2.0 floor -- confirm whether it is still needed
d0_nucleic_acid = 2.0
chain_dict = classify_chains(chains, residue_types)
chain_pair_type = init_chainpairdict_zeros(unique_chains)
for chain1 in unique_chains:
    for chain2 in unique_chains:
        if chain1 == chain2:
            continue
        # A pair is treated as 'nucleic_acid' when either partner is a nucleic acid chain
        if chain_dict[chain1] == 'nucleic_acid' or chain_dict[chain2] == 'nucleic_acid':
            chain_pair_type[chain1][chain2] = 'nucleic_acid'
        else:
            chain_pair_type[chain1][chain2] = 'protein'
        # NOTE(review): looks like leftover debug output (one line per ordered chain pair on stdout) -- confirm it is wanted
        print(chain1, chain2, chain_dict[chain1], chain_dict[chain2], chain_pair_type[chain1][chain2])
396+
344397# Calculate distance matrix using NumPy broadcasting
345398distances = np .sqrt (((coordinates [:, np .newaxis , :] - coordinates [np .newaxis , :, :])** 2 ).sum (axis = 2 ))
346399
@@ -349,10 +402,19 @@ def format_range(start, end):
349402 if os .path .exists (pae_file_path ):
350403 with open (pae_file_path , 'r' ) as file :
351404 data = json .load (file )
352- iptm_af2 = float (data ['iptm' ])
353- ptm_af2 = float (data ['ptm' ])
354- plddt = np .array (data ['plddt' ])
355- cb_plddt = np .array (data ['plddt' ]) # for pDockQ
405+
406+ if 'iptm' in data : iptm_af2 = float (data ['iptm' ])
407+ else : iptm_af2 = - 1.0
408+ if 'ptm' in data : ptm_af2 = float (data ['ptm' ])
409+ else : ptm_af2 = - 1.0
410+
411+ if 'plddt' in data :
412+ plddt = np .array (data ['plddt' ])
413+ cb_plddt = np .array (data ['plddt' ]) # for pDockQ
414+ else :
415+ plddt = np .zeros (numres )
416+ cb_plddt = np .zeros (numres )
417+
356418 pae_matrix = np .array (data ['pae' ])
357419 else :
358420 print ("AF2 PAE file does not exist: " , pae_file_path )
@@ -434,7 +496,6 @@ def format_range(start, end):
434496 # Set pae_matrix for AF3 from subset of full PAE matrix from json file
435497 token_array = np .array (token_mask )
436498 pae_matrix = pae_matrix_af3 [np .ix_ (token_array .astype (bool ), token_array .astype (bool ))]
437-
438499 # Get iptm matrix from AF3 summary_confidences file
439500 iptm_af3 = {chain1 : {chain2 : 0 for chain2 in unique_chains if chain1 != chain2 } for chain1 in unique_chains }
440501
@@ -478,18 +539,6 @@ def format_range(start, end):
478539# n0dom = number of residues in chain pair that have good PAE values (<cutoff)
479540# n0res = number of residues in chain2 that have good PAE residues for each residue of chain1
480541
481- # Initializes a nested dictionary with all values set to 0
482- def init_chainpairdict_zeros (chainlist ):
483- return {chain1 : {chain2 : 0 for chain2 in chainlist if chain1 != chain2 } for chain1 in chainlist }
484-
485- # Initializes a nested dictionary with NumPy arrays of zeros of a specified size
486- def init_chainpairdict_npzeros (chainlist , arraysize ):
487- return {chain1 : {chain2 : np .zeros (arraysize ) for chain2 in chainlist if chain1 != chain2 } for chain1 in chainlist }
488-
489- # Initializes a nested dictionary with empty sets.
490- def init_chainpairdict_set (chainlist ):
491- return {chain1 : {chain2 : set () for chain2 in chainlist if chain1 != chain2 } for chain1 in chainlist }
492-
493542iptm_d0chn_byres = init_chainpairdict_npzeros (unique_chains , numres )
494543ipsae_d0chn_byres = init_chainpairdict_npzeros (unique_chains , numres )
495544ipsae_d0dom_byres = init_chainpairdict_npzeros (unique_chains , numres )
@@ -541,18 +590,15 @@ def init_chainpairdict_set(chainlist):
541590pDockQ2 = init_chainpairdict_zeros (unique_chains )
542591LIS = init_chainpairdict_zeros (unique_chains )
543592
544-
545593# pDockQ
546594pDockQ_cutoff = 8.0
547595
548596for chain1 in unique_chains :
549597 for chain2 in unique_chains :
550- if chain1 == chain2 :
551- continue
598+ if chain1 == chain2 : continue
552599 npairs = 0
553600 for i in range (numres ):
554- if chains [i ] != chain1 :
555- continue
601+ if chains [i ] != chain1 : continue
556602 valid_pairs = (chains == chain2 ) & (distances [i ] <= pDockQ_cutoff )
557603 npairs += np .sum (valid_pairs )
558604 if valid_pairs .any ():
@@ -600,8 +646,8 @@ def init_chainpairdict_set(chainlist):
600646 else :
601647 mean_plddt = 0.0
602648 x = 0.0
603- pDockQ [chain1 ][chain2 ]= 0.0
604649 nres = 0
650+ pDockQ2 [chain1 ][chain2 ]= 0.0
605651
606652# LIS
607653
@@ -622,7 +668,8 @@ def init_chainpairdict_set(chainlist):
622668 LIS [chain1 ][chain2 ] = 0.0 # No valid values
623669 else :
624670 LIS [chain1 ][chain2 ]= 0.0
625-
671+
672+
626673
627674# calculate ipTM/ipSAE with and without PAE cutoff
628675
@@ -632,14 +679,16 @@ def init_chainpairdict_set(chainlist):
632679 continue
633680
634681 n0chn [chain1 ][chain2 ]= np .sum ( chains == chain1 ) + np .sum (chains == chain2 ) # total number of residues in chain1 and chain2
635- d0chn [chain1 ][chain2 ]= calc_d0 (n0chn [chain1 ][chain2 ])
682+ d0chn [chain1 ][chain2 ]= calc_d0 (n0chn [chain1 ][chain2 ], chain_pair_type [ chain1 ][ chain2 ] )
636683 ptm_matrix_d0chn = np .zeros ((numres ,numres ))
637684 ptm_matrix_d0chn = ptm_func_vec (pae_matrix ,d0chn [chain1 ][chain2 ])
638685
639686 valid_pairs_iptm = (chains == chain2 )
640687 valid_pairs_matrix = (chains == chain2 ) & (pae_matrix < pae_cutoff )
641688
642689 for i in range (numres ):
690+
691+
643692 if chains [i ] != chain1 :
644693 continue
645694
@@ -676,7 +725,7 @@ def init_chainpairdict_set(chainlist):
676725 residues_1 = len (unique_residues_chain1 [chain1 ][chain2 ])
677726 residues_2 = len (unique_residues_chain2 [chain1 ][chain2 ])
678727 n0dom [chain1 ][chain2 ] = residues_1 + residues_2
679- d0dom [chain1 ][chain2 ] = calc_d0 (n0dom [chain1 ][chain2 ])
728+ d0dom [chain1 ][chain2 ] = calc_d0 (n0dom [chain1 ][chain2 ], chain_pair_type [ chain1 ][ chain2 ] )
680729
681730 ptm_matrix_d0dom = np .zeros ((numres ,numres ))
682731 ptm_matrix_d0dom = ptm_func_vec (pae_matrix ,d0dom [chain1 ][chain2 ])
@@ -685,7 +734,7 @@ def init_chainpairdict_set(chainlist):
685734
686735 # Assuming valid_pairs_matrix is already defined
687736 n0res_byres_all = np .sum (valid_pairs_matrix , axis = 1 )
688- d0res_byres_all = calc_d0_array (n0res_byres_all )
737+ d0res_byres_all = calc_d0_array (n0res_byres_all , chain_pair_type [ chain1 ][ chain2 ] )
689738
690739 n0res_byres [chain1 ][chain2 ] = n0res_byres_all
691740 d0res_byres [chain1 ][chain2 ] = d0res_byres_all
@@ -802,7 +851,13 @@ def init_chainpairdict_set(chainlist):
802851 d0res_max [chain2 ][chain1 ]= maxd0
803852
804853
# Color name assigned to each single-letter chain ID (A-Z)
# NOTE(review): 'E' and 'Y' both map to 'yellow' -- confirm whether 'Y' was meant to get a distinct color
chaincolor = {'A': 'magenta',   'B': 'marine',    'C': 'lime',        'D': 'orange',
              'E': 'yellow',    'F': 'cyan',      'G': 'lightorange', 'H': 'pink',
              'I': 'deepteal',  'J': 'forest',    'K': 'lightblue',   'L': 'slate',
              'M': 'violet',    'N': 'arsenic',   'O': 'iodine',      'P': 'silver',
              'Q': 'red',       'R': 'sulfur',    'S': 'purple',      'T': 'olive',
              'U': 'palegreen', 'V': 'gray90',    'W': 'blue',        'X': 'palecyan',
              'Y': 'yellow',    'Z': 'white'}
806861
807862chainpairs = set ()
808863for chain1 in unique_chains :
@@ -819,8 +874,17 @@ def init_chainpairdict_set(chainlist):
819874 for pair in (pair1 , pair2 ):
820875 chain1 = pair [0 ]
821876 chain2 = pair [1 ]
822- color1 = chaincolor [chain1 ]
823- color2 = chaincolor [chain2 ]
877+
878+ if chain1 in chaincolor :
879+ color1 = chaincolor [chain1 ]
880+ else :
881+ color1 = 'magenta'
882+
883+ if chain2 in chaincolor :
884+ color2 = chaincolor [chain2 ]
885+ else :
886+ color2 = 'marine'
887+
824888 residues_1 = len (unique_residues_chain1 [chain1 ][chain2 ])
825889 residues_2 = len (unique_residues_chain2 [chain1 ][chain2 ])
826890 dist_residues_1 = len (dist_unique_residues_chain1 [chain1 ][chain2 ])
0 commit comments