Skip to content

Commit f7cea21

Browse files
authored
Added support for nucleic acids
Added support for nucleic acids in AF3 or Boltz1 output. The minimum value of d0 is set to 2.0 for nucleic acid (NA) chain pairs and protein/NA chain pairs. This value is arbitrary, not trained on any data, and may be changed in the calc_d0 and calc_d0_array functions.
1 parent cc42919 commit f7cea21

File tree

1 file changed

+108
-44
lines changed

1 file changed

+108
-44
lines changed

ipsae.py

Lines changed: 108 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -98,19 +98,24 @@ def ptm_func(x,d0):
9898
ptm_func_vec=np.vectorize(ptm_func) # vector version
9999

100100
# Define the d0 functions for numbers and arrays; minimum value = 1.0 (2.0 when pair_type is 'nucleic_acid'); formula from Yang and Skolnick, PROTEINS: Structure, Function, and Bioinformatics 57:702–710 (2004)
101-
def calc_d0(L, pair_type='protein'):
    """Return the d0 scaling distance for a single length value.

    Uses d0 = 1.24 * (L - 15) ** (1/3) - 1.8, with L raised to at least 27
    before the formula is applied so the cube root is taken of a positive
    number.

    Parameters
    ----------
    L : number
        Length value the scale is computed for; converted to float.
    pair_type : str, optional
        'nucleic_acid' raises the floor of the result to 2.0; any other
        value (default 'protein') uses a floor of 1.0.  The default keeps
        one-argument calls, calc_d0(L), working.

    Returns
    -------
    float
        max(floor, d0).
    """
    L = max(float(L), 27.0)
    # The 2.0 floor for nucleic-acid pairs is arbitrary (not fitted to data).
    min_value = 2.0 if pair_type == 'nucleic_acid' else 1.0
    d0 = 1.24 * (L - 15) ** (1.0 / 3.0) - 1.8
    return max(min_value, d0)
105108

106-
def calc_d0_array(L, pair_type='protein'):
    """Return d0 scaling distances for an array of length values.

    Element-wise counterpart of calc_d0: every length is raised to at least
    27, d0 = 1.24 * (L - 15) ** (1/3) - 1.8 is evaluated, and the result is
    floored at a minimum value.

    Parameters
    ----------
    L : array-like of numbers
        Length values; converted to a float NumPy array (lists and scalars
        are accepted).
    pair_type : str, optional
        'nucleic_acid' raises the floor of the result to 2.0; any other
        value (default 'protein') uses a floor of 1.0.  The default keeps
        one-argument calls, calc_d0_array(L), working.

    Returns
    -------
    numpy.ndarray
        Array of d0 values with the same shape as the converted input.
    """
    # np.maximum allocates a new array, so the caller's input is not modified.
    L = np.maximum(27, np.asarray(L, dtype=float))
    # The 2.0 floor for nucleic-acid pairs is arbitrary (not fitted to data).
    min_value = 2.0 if pair_type == 'nucleic_acid' else 1.0
    # Calculate d0 using the vectorized operation
    return np.maximum(min_value, 1.24 * (L - 15) ** (1.0 / 3.0) - 1.8)
114119

115120

116121
# Define the parse_atom_line function for PDB lines (by column) and mmCIF lines (split by white_space)
@@ -267,6 +272,37 @@ def format_range(start, end):
267272
string='+'.join(ranges)
268273
return(string)
269274

275+
# Initializes a nested dictionary with all values set to 0
276+
def init_chainpairdict_zeros(chainlist):
    """Build a two-level dict keyed [chain1][chain2] with every entry set to 0.

    Pairs of a chain with itself are left out of the inner dictionaries.
    """
    pair_table = {}
    for first in chainlist:
        row = {}
        for second in chainlist:
            if first != second:
                row[second] = 0
        pair_table[first] = row
    return pair_table
278+
279+
# Initializes a nested dictionary with NumPy arrays of zeros of a specified size
280+
def init_chainpairdict_npzeros(chainlist, arraysize):
    """Build a two-level dict keyed [chain1][chain2] of zero-filled arrays.

    Each ordered pair of distinct chains gets its own np.zeros(arraysize)
    array; pairs of a chain with itself are left out.
    """
    pair_table = {}
    for first in chainlist:
        row = {}
        for second in chainlist:
            if first != second:
                row[second] = np.zeros(arraysize)
        pair_table[first] = row
    return pair_table
282+
283+
# Initializes a nested dictionary with empty sets.
284+
def init_chainpairdict_set(chainlist):
    """Build a two-level dict keyed [chain1][chain2] of empty sets.

    Each ordered pair of distinct chains gets its own new set; pairs of a
    chain with itself are left out.
    """
    pair_table = {}
    for first in chainlist:
        row = {}
        for second in chainlist:
            if first != second:
                row[second] = set()
        pair_table[first] = row
    return pair_table
286+
287+
288+
def classify_chains(chains, residue_types):
    """Label each chain as 'nucleic_acid' or 'protein'.

    A chain is labelled 'nucleic_acid' if at least one of its residue names
    is a DNA/RNA residue name (DA, DC, DT, DG, A, C, U, G); otherwise it is
    labelled 'protein'.

    Parameters
    ----------
    chains : array-like of str
        Chain identifier for each residue.
    residue_types : array-like of str
        Residue name for each residue, in the same order as ``chains``.

    Returns
    -------
    dict
        Maps each unique chain identifier (as returned by np.unique) to
        'nucleic_acid' or 'protein'.
    """
    nuc_residue_set = {"DA", "DC", "DT", "DG", "A", "C", "U", "G"}

    # Coerce to arrays so plain Python lists work with the element-wise
    # comparison and index-array lookup below.
    chains = np.asarray(chains)
    residue_types = np.asarray(residue_types)

    chain_types = {}
    for chain in np.unique(chains):
        # Positions of this chain's residues, then their residue names
        indices = np.flatnonzero(chains == chain)
        chain_residues = residue_types[indices]
        # One nucleic acid residue is enough; any() stops at the first hit
        is_nucleic = any(residue in nuc_residue_set for residue in chain_residues)
        chain_types[chain] = 'nucleic_acid' if is_nucleic else 'protein'

    return chain_types
270306

271307

272308
# Load residues from AlphaFold PDB or mmCIF file into lists; each residue is a dictionary
@@ -284,7 +320,10 @@ def format_range(start, end):
284320
residue_set= {"ALA", "ARG", "ASN", "ASP", "CYS",
285321
"GLN", "GLU", "GLY", "HIS", "ILE",
286322
"LEU", "LYS", "MET", "PHE", "PRO",
287-
"SER", "THR", "TRP", "TYR", "VAL"}
323+
"SER", "THR", "TRP", "TYR", "VAL",
324+
"DA", "DC", "DT", "DG", "A", "C", "U", "G"}
325+
326+
nuc_residue_set = {"DA", "DC", "DT", "DG", "A", "C", "U", "G"}
288327

289328
with open(pdb_path, 'r') as PDB:
290329
for line in PDB:
@@ -300,12 +339,11 @@ def format_range(start, end):
300339
atom=parse_cif_atom_line(line, atomsitefield_dict)
301340
else:
302341
atom=parse_pdb_atom_line(line)
303-
304342
if atom is None: # ligand atom
305343
token_mask.append(0)
306344
continue
307345

308-
if atom['atom_name'] == "CA":
346+
if atom['atom_name'] == "CA" or "C1" in atom['atom_name']:
309347
token_mask.append(1)
310348
residues.append({
311349
'atom_num': atom['atom_num'],
@@ -317,7 +355,7 @@ def format_range(start, end):
317355
})
318356
chains.append(atom['chain_id'])
319357

320-
if atom['atom_name'] == "CB" or (atom['residue_name']=="GLY" and atom['atom_name']=="CA"):
358+
if atom['atom_name'] == "CB" or "C3" in atom['atom_name'] or (atom['residue_name']=="GLY" and atom['atom_name']=="CA"):
321359
cb_residues.append({
322360
'atom_num': atom['atom_num'],
323361
'coor': np.array([atom['x'], atom['y'], atom['z']]),
@@ -328,9 +366,9 @@ def format_range(start, end):
328366
})
329367

330368
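# add non-CA atoms of residues not in residue_set (e.g. PTM residues) to tokens (as 0), whether labeled as "HETATM" (af3) or as "ATOM" (boltz1); nucleic acid residues are now in residue_set, and atoms whose name contains "C1" are excluded here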
331-
if atom['atom_name'] != "CA" and atom['residue_name'] not in residue_set:
369+
if atom['atom_name'] != "CA" and "C1" not in atom['atom_name'] and atom['residue_name'] not in residue_set:
332370
token_mask.append(0)
333-
371+
334372
# Convert structure information to numpy arrays
335373
numres = len(residues)
336374
CA_atom_num= np.array([res['atom_num']-1 for res in residues]) # for AF3 atom indexing from 0
@@ -340,7 +378,22 @@ def format_range(start, end):
340378
unique_chains = np.unique(chains)
341379
token_array=np.array(token_mask)
342380
ntokens=np.sum(token_array)
381+
residue_types=np.array([res['res'] for res in residues])
343382

383+
# chain types (nucleic acid (NA) or protein) and chain_pair_types ('nucleic_acid' if either chain is NA) for d0 calculation
384+
# arbitrarily setting the minimum d0 to 2.0 for NA/protein or NA/NA chain pairs (approximately 21 base pairs)
385+
d0_nucleic_acid=2.0
386+
chain_dict = classify_chains(chains, residue_types)
387+
chain_pair_type = init_chainpairdict_zeros(unique_chains)
388+
for chain1 in unique_chains:
389+
for chain2 in unique_chains:
390+
if chain1==chain2: continue
391+
if chain_dict[chain1] == 'nucleic_acid' or chain_dict[chain2] == 'nucleic_acid':
392+
chain_pair_type[chain1][chain2]='nucleic_acid'
393+
else:
394+
chain_pair_type[chain1][chain2]='protein'
395+
print(chain1, chain2, chain_dict[chain1], chain_dict[chain2], chain_pair_type[chain1][chain2])
396+
344397
# Calculate distance matrix using NumPy broadcasting
345398
distances = np.sqrt(((coordinates[:, np.newaxis, :] - coordinates[np.newaxis, :, :])**2).sum(axis=2))
346399

@@ -349,10 +402,19 @@ def format_range(start, end):
349402
if os.path.exists(pae_file_path):
350403
with open(pae_file_path, 'r') as file:
351404
data = json.load(file)
352-
iptm_af2 = float(data['iptm'])
353-
ptm_af2 = float(data['ptm'])
354-
plddt = np.array(data['plddt'])
355-
cb_plddt = np.array(data['plddt']) # for pDockQ
405+
406+
if 'iptm' in data: iptm_af2 = float(data['iptm'])
407+
else: iptm_af2=-1.0
408+
if 'ptm' in data: ptm_af2 = float(data['ptm'])
409+
else: ptm_af2=-1.0
410+
411+
if 'plddt' in data:
412+
plddt = np.array(data['plddt'])
413+
cb_plddt = np.array(data['plddt']) # for pDockQ
414+
else:
415+
plddt = np.zeros(numres)
416+
cb_plddt = np.zeros(numres)
417+
356418
pae_matrix = np.array(data['pae'])
357419
else:
358420
print("AF2 PAE file does not exist: ", pae_file_path)
@@ -434,7 +496,6 @@ def format_range(start, end):
434496
# Set pae_matrix for AF3 from subset of full PAE matrix from json file
435497
token_array=np.array(token_mask)
436498
pae_matrix = pae_matrix_af3[np.ix_(token_array.astype(bool), token_array.astype(bool))]
437-
438499
# Get iptm matrix from AF3 summary_confidences file
439500
iptm_af3= {chain1: {chain2: 0 for chain2 in unique_chains if chain1 != chain2} for chain1 in unique_chains}
440501

@@ -478,18 +539,6 @@ def format_range(start, end):
478539
# n0dom = number of residues in chain pair that have good PAE values (<cutoff)
479540
# n0res = number of residues in chain2 that have good PAE residues for each residue of chain1
480541

481-
# Initializes a nested dictionary with all values set to 0
482-
def init_chainpairdict_zeros(chainlist):
483-
return {chain1: {chain2: 0 for chain2 in chainlist if chain1 != chain2} for chain1 in chainlist}
484-
485-
# Initializes a nested dictionary with NumPy arrays of zeros of a specified size
486-
def init_chainpairdict_npzeros(chainlist, arraysize):
487-
return {chain1: {chain2: np.zeros(arraysize) for chain2 in chainlist if chain1 != chain2} for chain1 in chainlist}
488-
489-
# Initializes a nested dictionary with empty sets.
490-
def init_chainpairdict_set(chainlist):
491-
return {chain1: {chain2: set() for chain2 in chainlist if chain1 != chain2} for chain1 in chainlist}
492-
493542
iptm_d0chn_byres = init_chainpairdict_npzeros(unique_chains, numres)
494543
ipsae_d0chn_byres = init_chainpairdict_npzeros(unique_chains, numres)
495544
ipsae_d0dom_byres = init_chainpairdict_npzeros(unique_chains, numres)
@@ -541,18 +590,15 @@ def init_chainpairdict_set(chainlist):
541590
pDockQ2 = init_chainpairdict_zeros(unique_chains)
542591
LIS = init_chainpairdict_zeros(unique_chains)
543592

544-
545593
# pDockQ
546594
pDockQ_cutoff=8.0
547595

548596
for chain1 in unique_chains:
549597
for chain2 in unique_chains:
550-
if chain1 == chain2:
551-
continue
598+
if chain1 == chain2: continue
552599
npairs=0
553600
for i in range(numres):
554-
if chains[i] != chain1:
555-
continue
601+
if chains[i] != chain1: continue
556602
valid_pairs = (chains==chain2) & (distances[i] <= pDockQ_cutoff)
557603
npairs += np.sum(valid_pairs)
558604
if valid_pairs.any():
@@ -600,8 +646,8 @@ def init_chainpairdict_set(chainlist):
600646
else:
601647
mean_plddt=0.0
602648
x=0.0
603-
pDockQ[chain1][chain2]=0.0
604649
nres=0
650+
pDockQ2[chain1][chain2]=0.0
605651

606652
# LIS
607653

@@ -622,7 +668,8 @@ def init_chainpairdict_set(chainlist):
622668
LIS[chain1][chain2] = 0.0 # No valid values
623669
else:
624670
LIS[chain1][chain2]=0.0
625-
671+
672+
626673

627674
# calculate ipTM/ipSAE with and without PAE cutoff
628675

@@ -632,14 +679,16 @@ def init_chainpairdict_set(chainlist):
632679
continue
633680

634681
n0chn[chain1][chain2]=np.sum( chains==chain1) + np.sum(chains==chain2) # total number of residues in chain1 and chain2
635-
d0chn[chain1][chain2]=calc_d0(n0chn[chain1][chain2])
682+
d0chn[chain1][chain2]=calc_d0(n0chn[chain1][chain2], chain_pair_type[chain1][chain2])
636683
ptm_matrix_d0chn=np.zeros((numres,numres))
637684
ptm_matrix_d0chn=ptm_func_vec(pae_matrix,d0chn[chain1][chain2])
638685

639686
valid_pairs_iptm = (chains == chain2)
640687
valid_pairs_matrix = (chains == chain2) & (pae_matrix < pae_cutoff)
641688

642689
for i in range(numres):
690+
691+
643692
if chains[i] != chain1:
644693
continue
645694

@@ -676,7 +725,7 @@ def init_chainpairdict_set(chainlist):
676725
residues_1 = len(unique_residues_chain1[chain1][chain2])
677726
residues_2 = len(unique_residues_chain2[chain1][chain2])
678727
n0dom[chain1][chain2] = residues_1+residues_2
679-
d0dom[chain1][chain2] = calc_d0(n0dom[chain1][chain2])
728+
d0dom[chain1][chain2] = calc_d0(n0dom[chain1][chain2], chain_pair_type[chain1][chain2])
680729

681730
ptm_matrix_d0dom = np.zeros((numres,numres))
682731
ptm_matrix_d0dom = ptm_func_vec(pae_matrix,d0dom[chain1][chain2])
@@ -685,7 +734,7 @@ def init_chainpairdict_set(chainlist):
685734

686735
# Assuming valid_pairs_matrix is already defined
687736
n0res_byres_all = np.sum(valid_pairs_matrix, axis=1)
688-
d0res_byres_all = calc_d0_array(n0res_byres_all)
737+
d0res_byres_all = calc_d0_array(n0res_byres_all, chain_pair_type[chain1][chain2])
689738

690739
n0res_byres[chain1][chain2] = n0res_byres_all
691740
d0res_byres[chain1][chain2] = d0res_byres_all
@@ -802,7 +851,13 @@ def init_chainpairdict_set(chainlist):
802851
d0res_max[chain2][chain1]=maxd0
803852

804853

805-
chaincolor={'A':'magenta', 'B':'marine', 'C':'lime', 'D':'orange', 'E':'yellow', 'F':'cyan', 'G':'lightorange', 'H':'pink'}
854+
chaincolor={'A':'magenta', 'B':'marine', 'C':'lime', 'D':'orange',
855+
'E':'yellow', 'F':'cyan', 'G':'lightorange', 'H':'pink',
856+
'I':'deepteal', 'J':'forest', 'K':'lightblue', 'L':'slate',
857+
'M':'violet', 'N':'arsenic', 'O':'iodine', 'P':'silver',
858+
'Q':'red', 'R':'sulfur', 'S':'purple', 'T':'olive',
859+
'U':'palegreen', 'V':'gray90', 'W':'blue', 'X':'palecyan',
860+
'Y':'yellow', 'Z':'white'}
806861

807862
chainpairs=set()
808863
for chain1 in unique_chains:
@@ -819,8 +874,17 @@ def init_chainpairdict_set(chainlist):
819874
for pair in (pair1, pair2):
820875
chain1=pair[0]
821876
chain2=pair[1]
822-
color1=chaincolor[chain1]
823-
color2=chaincolor[chain2]
877+
878+
if chain1 in chaincolor:
879+
color1=chaincolor[chain1]
880+
else:
881+
color1='magenta'
882+
883+
if chain2 in chaincolor:
884+
color2=chaincolor[chain2]
885+
else:
886+
color2='marine'
887+
824888
residues_1 = len(unique_residues_chain1[chain1][chain2])
825889
residues_2 = len(unique_residues_chain2[chain1][chain2])
826890
dist_residues_1 = len(dist_unique_residues_chain1[chain1][chain2])

0 commit comments

Comments
 (0)