Skip to content

Commit 45bc8de

Browse files
author
dmoi
committed
fxing geometry stuff
1 parent 829880f commit 45bc8de

File tree

5 files changed

+2195
-3341
lines changed

5 files changed

+2195
-3341
lines changed

foldtree2/notebooks/experiments/test_fapeloss.ipynb

Lines changed: 195 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 2,
5+
"execution_count": 1,
66
"id": "59ee2444",
77
"metadata": {},
88
"outputs": [
@@ -20,7 +20,7 @@
2020
},
2121
{
2222
"cell_type": "code",
23-
"execution_count": 3,
23+
"execution_count": 2,
2424
"id": "10fe5139",
2525
"metadata": {},
2626
"outputs": [],
@@ -89,29 +89,34 @@
8989
},
9090
{
9191
"cell_type": "code",
92-
"execution_count": 39,
92+
"execution_count": 3,
9393
"id": "4c3ff08e",
9494
"metadata": {},
9595
"outputs": [],
9696
"source": [
9797
"import torch\n",
9898
"from foldtree2.src.pdbgraph import PDB2PyG\n",
99-
"from foldtree2.src.rigid_utils import *\n",
99+
"from foldtree2.src.rigid_utils import Rigid\n",
100+
"from foldtree2.src.losses.fape import (\n",
101+
" compute_chain_positions,\n",
102+
" rotation_matrix_to_quaternion,\n",
103+
" reconstruct_positions,\n",
104+
" quaternion_multiply,\n",
105+
" quaternion_rotate,\n",
106+
" fape_loss,\n",
107+
" lddt_loss,\n",
108+
" compute_lddt_quaternions\n",
109+
")\n",
100110
"\n",
101111
"import matplotlib.pyplot as plt\n",
102112
"from mpl_toolkits.mplot3d import Axes3D\n",
103-
"from foldtree2.src.losses.losses import compute_chain_positions\n",
104-
"\n",
105-
"# Example PDB file path\n",
106-
"pdb_path = './foldtree2/config/1eei.pdb'\n",
107-
"\n",
108-
"# Initialize parser\n",
109-
"pdb2pyg = PDB2PyG(aapropcsv = 'foldtree2/config/aaindex1.csv')\n"
113+
"import numpy as np\n",
114+
"from scipy.spatial import procrustes"
110115
]
111116
},
112117
{
113118
"cell_type": "code",
114-
"execution_count": 42,
119+
"execution_count": 4,
115120
"id": "63373b61",
116121
"metadata": {},
117122
"outputs": [],
@@ -234,8 +239,7 @@
234239
"\t# Stack into tensors\n",
235240
"\tif rotations:\n",
236241
"\t\trotation_tensor = torch.stack(rotations, dim=0)\n",
237-
"\t\t#normalize rotation matrices to ensure orthogonality\n",
238-
"\t\trotation_tensor = rotation_tensor / rotation_tensor.norm(dim=[1, 2], keepdim=True)\n",
242+
"\t\t# Rotation matrices from Rigid.get_rots() are already orthogonal - no normalization needed!\n",
239243
"\telse:\n",
240244
"\t\trotation_tensor = torch.empty(0, 3, 3)\n",
241245
"\t\n",
@@ -245,7 +249,7 @@
245249
"\t\ttranslation_tensor = torch.empty(0, 3)\n",
246250
"\tif quats:\n",
247251
"\t\tquats = torch.stack(quats, dim=0)\n",
248-
"\t\t#normalize quaternion vectors\n",
252+
"\t\t# Normalize quaternion vectors to ensure unit norm\n",
249253
"\t\tquats = quats / quats.norm(dim=1, keepdim=True)\n",
250254
"\telse:\n",
251255
"\t\tquats = torch.empty(0, 4)\n",
@@ -372,8 +376,22 @@
372376
},
373377
{
374378
"cell_type": "code",
375-
"execution_count": 43,
376-
"id": "48e0d8f9",
379+
"execution_count": 5,
380+
"id": "d5f32f57",
381+
"metadata": {},
382+
"outputs": [],
383+
"source": [
384+
"#transform a structure's residues into rigid body transforms\n",
385+
"from Bio.PDB import PDBParser\n",
386+
"pdb_parser = PDBParser(QUIET=True)\n",
387+
"\n",
388+
"pdb_path = './foldtree2/config/1eei.pdb'"
389+
]
390+
},
391+
{
392+
"cell_type": "code",
393+
"execution_count": 6,
394+
"id": "635676ab",
377395
"metadata": {},
378396
"outputs": [
379397
{
@@ -485,17 +503,15 @@
485503
" [ 0.5021, 0.2582, -0.7162, 0.4102]])\n",
486504
"Transformation Analysis:\n",
487505
" Number of residue-to-residue transformations: 102\n",
488-
" Mean rotation angle: 120.00° ± 15.96°\n",
506+
" Mean rotation angle: 123.56° ± 31.83°\n",
489507
" Mean translation distance: 57.890 ± 17.816 Å\n",
490508
" Translation distance range: 10.431 - 98.682 Å\n",
491-
" Rotation angle range: 88.05° - 142.06°\n"
509+
" Rotation angle range: 64.84° - 179.54°\n"
492510
]
493511
}
494512
],
495513
"source": [
496-
"#transform a strucutre's residues into rigid body transforms\n",
497-
"from Bio.PDB import PDBParser\n",
498-
"pdb_parser = PDBParser(QUIET=True)\n",
514+
"\n",
499515
"pdb_structure = pdb_parser.get_structure('1eei', pdb_path)\n",
500516
"pdb_transforms = get_residue_transformations_from_pdb(\n",
501517
"\tpdb_structure, chain_id=None, device=None\n",
@@ -507,6 +523,149 @@
507523
"pdb_analysis = analyze_residue_transformations(pdb_transforms, verbose=True)\n"
508524
]
509525
},
526+
{
527+
"cell_type": "code",
528+
"execution_count": 7,
529+
"id": "f5540be4",
530+
"metadata": {},
531+
"outputs": [
532+
{
533+
"name": "stdout",
534+
"output_type": "stream",
535+
"text": [
536+
"============================================================\n",
537+
"VALIDATION TESTS FOR get_residue_transformations\n",
538+
"============================================================\n",
539+
"\n",
540+
"1. Quaternion Unit Norm Test:\n",
541+
" Min quaternion norm: 1.000000\n",
542+
" Max quaternion norm: 1.000000\n",
543+
" Mean quaternion norm: 1.000000\n",
544+
" ✓ PASS\n",
545+
"\n",
546+
"2. Rotation Matrix Orthogonality Test:\n",
547+
" Max orthogonality error: 6.597833e-07\n",
548+
" Mean orthogonality error: 2.584600e-07\n",
549+
" ✓ PASS\n",
550+
"\n",
551+
"3. Rotation Matrix Determinant Test:\n",
552+
" Min determinant: 1.000000\n",
553+
" Max determinant: 1.000000\n",
554+
" Mean determinant: 1.000000\n",
555+
" ✓ PASS\n",
556+
"\n",
557+
"4. Quaternion-to-Rotation Consistency Test:\n",
558+
" Max consistency error: 2.828426e+00\n",
559+
" Mean consistency error: 2.283508e+00\n",
560+
" ✗ FAIL\n",
561+
"\n",
562+
"5. Transformation Chain Test:\n",
563+
" Rotation composition error: 2.522347e+00\n",
564+
" Translation composition error: 6.872144e+01\n",
565+
" ✗ FAIL\n",
566+
"\n",
567+
"============================================================\n",
568+
"VALIDATION COMPLETE\n",
569+
"============================================================\n"
570+
]
571+
}
572+
],
573+
"source": [
574+
"# Validation tests for get_residue_transformations\n",
575+
"\n",
576+
"print(\"=\"*60)\n",
577+
"print(\"VALIDATION TESTS FOR get_residue_transformations\")\n",
578+
"print(\"=\"*60)\n",
579+
"\n",
580+
"# Test 1: Check quaternion properties\n",
581+
"print(\"\\n1. Quaternion Unit Norm Test:\")\n",
582+
"quat_norms = torch.norm(pdb_transforms['quats'], dim=1)\n",
583+
"print(f\" Min quaternion norm: {quat_norms.min():.6f}\")\n",
584+
"print(f\" Max quaternion norm: {quat_norms.max():.6f}\")\n",
585+
"print(f\" Mean quaternion norm: {quat_norms.mean():.6f}\")\n",
586+
"print(f\" ✓ PASS\" if torch.allclose(quat_norms, torch.ones_like(quat_norms), atol=1e-5) else \" ✗ FAIL\")\n",
587+
"\n",
588+
"# Test 2: Check rotation matrix orthogonality (R^T @ R should be identity)\n",
589+
"print(\"\\n2. Rotation Matrix Orthogonality Test:\")\n",
590+
"R = pdb_transforms['rotations']\n",
591+
"identity_errors = []\n",
592+
"for i in range(len(R)):\n",
593+
" should_be_identity = R[i].T @ R[i]\n",
594+
" error = torch.norm(should_be_identity - torch.eye(3), p='fro')\n",
595+
" identity_errors.append(error.item())\n",
596+
"\n",
597+
"identity_errors = torch.tensor(identity_errors)\n",
598+
"print(f\" Max orthogonality error: {identity_errors.max():.6e}\")\n",
599+
"print(f\" Mean orthogonality error: {identity_errors.mean():.6e}\")\n",
600+
"print(f\" ✓ PASS\" if identity_errors.max() < 1e-5 else \" ✗ FAIL\")\n",
601+
"\n",
602+
"# Test 3: Check rotation matrix determinant (should be +1)\n",
603+
"print(\"\\n3. Rotation Matrix Determinant Test:\")\n",
604+
"determinants = torch.stack([torch.det(R[i]) for i in range(len(R))])\n",
605+
"print(f\" Min determinant: {determinants.min():.6f}\")\n",
606+
"print(f\" Max determinant: {determinants.max():.6f}\")\n",
607+
"print(f\" Mean determinant: {determinants.mean():.6f}\")\n",
608+
"print(f\" ✓ PASS\" if torch.allclose(determinants, torch.ones_like(determinants), atol=1e-5) else \" ✗ FAIL\")\n",
609+
"\n",
610+
"# Test 4: Check quaternion-to-rotation consistency\n",
611+
"print(\"\\n4. Quaternion-to-Rotation Consistency Test:\")\n",
612+
"from foldtree2.src.losses.fape import quaternion_to_rotation_matrix\n",
613+
"R_from_quats = quaternion_to_rotation_matrix(pdb_transforms['quats'])\n",
614+
"consistency_error = torch.norm(R_from_quats - R, p='fro', dim=[1, 2])\n",
615+
"print(f\" Max consistency error: {consistency_error.max():.6e}\")\n",
616+
"print(f\" Mean consistency error: {consistency_error.mean():.6e}\")\n",
617+
"print(f\" ✓ PASS\" if consistency_error.max() < 1e-4 else \" ✗ FAIL\")\n",
618+
"\n",
619+
"# Test 5: Verify transformation composition\n",
620+
"print(\"\\n5. Transformation Chain Test:\")\n",
621+
"# Apply transforms sequentially and check against direct computation\n",
622+
"transforms = pdb_transforms['transforms']\n",
623+
"composed_transform = transforms[0]\n",
624+
"for i in range(1, min(5, len(transforms))): # Test first 5 transforms\n",
625+
" composed_transform = composed_transform.compose(transforms[i])\n",
626+
"\n",
627+
"# The composed transform should match what we get from composing frames directly\n",
628+
"residue_list = list(pdb_structure[0][list(pdb_structure[0].child_dict.keys())[0]])\n",
629+
"coords_0 = {\n",
630+
" 'N': torch.tensor(residue_list[0]['N'].get_coord(), dtype=torch.float32),\n",
631+
" 'CA': torch.tensor(residue_list[0]['CA'].get_coord(), dtype=torch.float32),\n",
632+
" 'C': torch.tensor(residue_list[0]['C'].get_coord(), dtype=torch.float32)\n",
633+
"}\n",
634+
"coords_5 = {\n",
635+
" 'N': torch.tensor(residue_list[5]['N'].get_coord(), dtype=torch.float32),\n",
636+
" 'CA': torch.tensor(residue_list[5]['CA'].get_coord(), dtype=torch.float32),\n",
637+
" 'C': torch.tensor(residue_list[5]['C'].get_coord(), dtype=torch.float32)\n",
638+
"}\n",
639+
"\n",
640+
"frame_0 = Rigid.from_3_points(\n",
641+
" p_neg_x_axis=coords_0['N'].unsqueeze(0),\n",
642+
" origin=coords_0['CA'].unsqueeze(0),\n",
643+
" p_xy_plane=coords_0['C'].unsqueeze(0)\n",
644+
")\n",
645+
"frame_5 = Rigid.from_3_points(\n",
646+
" p_neg_x_axis=coords_5['N'].unsqueeze(0),\n",
647+
" origin=coords_5['CA'].unsqueeze(0),\n",
648+
" p_xy_plane=coords_5['C'].unsqueeze(0)\n",
649+
")\n",
650+
"\n",
651+
"direct_transform = frame_5.compose(frame_0.invert())\n",
652+
"rotation_error = torch.norm(\n",
653+
" composed_transform.get_rots().get_rot_mats() - direct_transform.get_rots().get_rot_mats(), \n",
654+
" p='fro'\n",
655+
")\n",
656+
"translation_error = torch.norm(\n",
657+
" composed_transform.get_trans() - direct_transform.get_trans()\n",
658+
")\n",
659+
"\n",
660+
"print(f\" Rotation composition error: {rotation_error.item():.6e}\")\n",
661+
"print(f\" Translation composition error: {translation_error.item():.6e}\")\n",
662+
"print(f\" ✓ PASS\" if rotation_error < 1e-4 and translation_error < 1e-4 else \" ✗ FAIL\")\n",
663+
"\n",
664+
"print(\"\\n\" + \"=\"*60)\n",
665+
"print(\"VALIDATION COMPLETE\")\n",
666+
"print(\"=\"*60)"
667+
]
668+
},
510669
{
511670
"cell_type": "code",
512671
"execution_count": 44,
@@ -553,33 +712,9 @@
553712
"\tplt.show()\n",
554713
"\n",
555714
"\n",
556-
"def reconstruct_positions(R, T):\n",
557-
"\t\"\"\"\n",
558-
"\tReconstruct 3D positions from a sequence of rotation matrices and translation vectors.\n",
559-
"\t\n",
560-
"\tArgs:\n",
561-
"\t\tR (torch.Tensor): Rotation matrices of shape (N, 3, 3)\n",
562-
"\t\tT (torch.Tensor): Translation vectors of shape (N, 3)\n",
563-
"\t\t\n",
564-
"\tReturns:\n",
565-
"\t\ttorch.Tensor: Reconstructed positions of shape (N+1, 3), starting from origin\n",
566-
"\t\"\"\"\n",
567-
"\tpositions = torch.zeros(len(T) + 1, 3, dtype=T.dtype, device=T.device)\n",
568-
"\tcurrent_pos = torch.zeros(3, dtype=T.dtype, device=T.device)\n",
569-
"\t\n",
570-
"\tfor i in range(len(T)):\n",
571-
"\t\tcurrent_pos = R[i] @ current_pos + T[i]\n",
572-
"\t\tpositions[i + 1] = current_pos\n",
573-
"\t\n",
574-
"\treturn positions\n",
575-
"\n",
576715
"def plot_reconstructed_chain(R, T, title=\"Reconstructed Chain\"):\n",
577-
"\tR = R.detach().cpu()\n",
578-
"\tT = T.detach().cpu()\n",
579716
"\tpositions = reconstruct_positions(R, T)\n",
580-
"\t\n",
581-
"\tplot_points(positions, title)\n",
582-
"\n"
717+
"\tplot_points(positions, title=title)"
583718
]
584719
},
585720
{
@@ -930,8 +1065,6 @@
9301065
}
9311066
],
9321067
"source": [
933-
"from src.losses.losses import compute_chain_positions\n",
934-
"\n",
9351068
"# Extract original CA coordinates\n",
9361069
"ca_coords = pdb2pyg.extract_pdb_coordinates(pdb_path, atom_type=\"CA\")\n",
9371070
"\n",
@@ -1018,8 +1151,6 @@
10181151
}
10191152
],
10201153
"source": [
1021-
"from src.losses.losses import compute_chain_positions\n",
1022-
"\n",
10231154
"def plot_quaternion_and_rt_alignment(quats, translations, R, T, title=\"Alignment Plot\"):\n",
10241155
"\t\"\"\"\n",
10251156
"\tPlot points created from quaternion chain and from RT values to check alignment.\n",
@@ -1141,7 +1272,6 @@
11411272
"plot_quaternion_chain(quaternions_noisy, t_noisy, \"Noisy Quaternion Chain Positions\")\n",
11421273
"\n",
11431274
"# --- Compute FAPE loss and lDDT loss ---\n",
1144-
"from losses.losses import fape_loss, lddt_loss\n",
11451275
"\n",
11461276
"# FAPE loss (using original as true, noisy as pred)\n",
11471277
"batch = torch.zeros(coords.shape[0], dtype=torch.long) # single batch\n",
@@ -1169,21 +1299,25 @@
11691299
"ax.legend()\n",
11701300
"plt.show()"
11711301
]
1172-
},
1173-
{
1174-
"cell_type": "code",
1175-
"execution_count": null,
1176-
"id": "0121113b",
1177-
"metadata": {},
1178-
"outputs": [],
1179-
"source": []
11801302
}
11811303
],
11821304
"metadata": {
11831305
"kernelspec": {
1184-
"display_name": "Python 3 (ipykernel)",
1306+
"display_name": "foldtree2",
11851307
"language": "python",
11861308
"name": "python3"
1309+
},
1310+
"language_info": {
1311+
"codemirror_mode": {
1312+
"name": "ipython",
1313+
"version": 3
1314+
},
1315+
"file_extension": ".py",
1316+
"mimetype": "text/x-python",
1317+
"name": "python",
1318+
"nbconvert_exporter": "python",
1319+
"pygments_lexer": "ipython3",
1320+
"version": "3.9.23"
11871321
}
11881322
},
11891323
"nbformat": 4,

0 commit comments

Comments
 (0)