|
2 | 2 | "cells": [ |
3 | 3 | { |
4 | 4 | "cell_type": "code", |
5 | | - "execution_count": 2, |
| 5 | + "execution_count": 1, |
6 | 6 | "id": "59ee2444", |
7 | 7 | "metadata": {}, |
8 | 8 | "outputs": [ |
|
20 | 20 | }, |
21 | 21 | { |
22 | 22 | "cell_type": "code", |
23 | | - "execution_count": 3, |
| 23 | + "execution_count": 2, |
24 | 24 | "id": "10fe5139", |
25 | 25 | "metadata": {}, |
26 | 26 | "outputs": [], |
|
89 | 89 | }, |
90 | 90 | { |
91 | 91 | "cell_type": "code", |
92 | | - "execution_count": 39, |
| 92 | + "execution_count": 3, |
93 | 93 | "id": "4c3ff08e", |
94 | 94 | "metadata": {}, |
95 | 95 | "outputs": [], |
96 | 96 | "source": [ |
97 | 97 | "import torch\n", |
98 | 98 | "from foldtree2.src.pdbgraph import PDB2PyG\n", |
99 | | - "from foldtree2.src.rigid_utils import *\n", |
| 99 | + "from foldtree2.src.rigid_utils import Rigid\n", |
| 100 | + "from foldtree2.src.losses.fape import (\n", |
| 101 | + " compute_chain_positions,\n", |
| 102 | + " rotation_matrix_to_quaternion,\n", |
| 103 | + " reconstruct_positions,\n", |
| 104 | + " quaternion_multiply,\n", |
| 105 | + " quaternion_rotate,\n", |
| 106 | + " fape_loss,\n", |
| 107 | + " lddt_loss,\n", |
| 108 | + " compute_lddt_quaternions\n", |
| 109 | + ")\n", |
100 | 110 | "\n", |
101 | 111 | "import matplotlib.pyplot as plt\n", |
102 | 112 | "from mpl_toolkits.mplot3d import Axes3D\n", |
103 | | - "from foldtree2.src.losses.losses import compute_chain_positions\n", |
104 | | - "\n", |
105 | | - "# Example PDB file path\n", |
106 | | - "pdb_path = './foldtree2/config/1eei.pdb'\n", |
107 | | - "\n", |
108 | | - "# Initialize parser\n", |
109 | | - "pdb2pyg = PDB2PyG(aapropcsv = 'foldtree2/config/aaindex1.csv')\n" |
| 113 | + "import numpy as np\n", |
| 114 | + "from scipy.spatial import procrustes" |
110 | 115 | ] |
111 | 116 | }, |
112 | 117 | { |
113 | 118 | "cell_type": "code", |
114 | | - "execution_count": 42, |
| 119 | + "execution_count": 4, |
115 | 120 | "id": "63373b61", |
116 | 121 | "metadata": {}, |
117 | 122 | "outputs": [], |
|
234 | 239 | "\t# Stack into tensors\n", |
235 | 240 | "\tif rotations:\n", |
236 | 241 | "\t\trotation_tensor = torch.stack(rotations, dim=0)\n", |
237 | | - "\t\t#normalize rotation matrices to ensure orthogonality\n", |
238 | | - "\t\trotation_tensor = rotation_tensor / rotation_tensor.norm(dim=[1, 2], keepdim=True)\n", |
| 242 | + "\t\t# Rotation matrices from Rigid.get_rots() are already orthogonal - no normalization needed!\n", |
239 | 243 | "\telse:\n", |
240 | 244 | "\t\trotation_tensor = torch.empty(0, 3, 3)\n", |
241 | 245 | "\t\n", |
|
245 | 249 | "\t\ttranslation_tensor = torch.empty(0, 3)\n", |
246 | 250 | "\tif quats:\n", |
247 | 251 | "\t\tquats = torch.stack(quats, dim=0)\n", |
248 | | - "\t\t#normalize quaternion vectors\n", |
| 252 | + "\t\t# Normalize quaternion vectors to ensure unit norm\n", |
249 | 253 | "\t\tquats = quats / quats.norm(dim=1, keepdim=True)\n", |
250 | 254 | "\telse:\n", |
251 | 255 | "\t\tquats = torch.empty(0, 4)\n", |
|
372 | 376 | }, |
373 | 377 | { |
374 | 378 | "cell_type": "code", |
375 | | - "execution_count": 43, |
376 | | - "id": "48e0d8f9", |
| 379 | + "execution_count": 5, |
| 380 | + "id": "d5f32f57", |
| 381 | + "metadata": {}, |
| 382 | + "outputs": [], |
| 383 | + "source": [ |
| 384 | + "#transform a strucutre's residues into rigid body transforms\n", |
| 385 | + "from Bio.PDB import PDBParser\n", |
| 386 | + "pdb_parser = PDBParser(QUIET=True)\n", |
| 387 | + "\n", |
| 388 | + "pdb_path = './foldtree2/config/1eei.pdb'" |
| 389 | + ] |
| 390 | + }, |
| 391 | + { |
| 392 | + "cell_type": "code", |
| 393 | + "execution_count": 6, |
| 394 | + "id": "635676ab", |
377 | 395 | "metadata": {}, |
378 | 396 | "outputs": [ |
379 | 397 | { |
|
485 | 503 | " [ 0.5021, 0.2582, -0.7162, 0.4102]])\n", |
486 | 504 | "Transformation Analysis:\n", |
487 | 505 | " Number of residue-to-residue transformations: 102\n", |
488 | | - " Mean rotation angle: 120.00° ± 15.96°\n", |
| 506 | + " Mean rotation angle: 123.56° ± 31.83°\n", |
489 | 507 | " Mean translation distance: 57.890 ± 17.816 Å\n", |
490 | 508 | " Translation distance range: 10.431 - 98.682 Å\n", |
491 | | - " Rotation angle range: 88.05° - 142.06°\n" |
| 509 | + " Rotation angle range: 64.84° - 179.54°\n" |
492 | 510 | ] |
493 | 511 | } |
494 | 512 | ], |
495 | 513 | "source": [ |
496 | | - "#transform a strucutre's residues into rigid body transforms\n", |
497 | | - "from Bio.PDB import PDBParser\n", |
498 | | - "pdb_parser = PDBParser(QUIET=True)\n", |
| 514 | + "\n", |
499 | 515 | "pdb_structure = pdb_parser.get_structure('1eei', pdb_path)\n", |
500 | 516 | "pdb_transforms = get_residue_transformations_from_pdb(\n", |
501 | 517 | "\tpdb_structure, chain_id=None, device=None\n", |
|
507 | 523 | "pdb_analysis = analyze_residue_transformations(pdb_transforms, verbose=True)\n" |
508 | 524 | ] |
509 | 525 | }, |
| 526 | + { |
| 527 | + "cell_type": "code", |
| 528 | + "execution_count": 7, |
| 529 | + "id": "f5540be4", |
| 530 | + "metadata": {}, |
| 531 | + "outputs": [ |
| 532 | + { |
| 533 | + "name": "stdout", |
| 534 | + "output_type": "stream", |
| 535 | + "text": [ |
| 536 | + "============================================================\n", |
| 537 | + "VALIDATION TESTS FOR get_residue_transformations\n", |
| 538 | + "============================================================\n", |
| 539 | + "\n", |
| 540 | + "1. Quaternion Unit Norm Test:\n", |
| 541 | + " Min quaternion norm: 1.000000\n", |
| 542 | + " Max quaternion norm: 1.000000\n", |
| 543 | + " Mean quaternion norm: 1.000000\n", |
| 544 | + " ✓ PASS\n", |
| 545 | + "\n", |
| 546 | + "2. Rotation Matrix Orthogonality Test:\n", |
| 547 | + " Max orthogonality error: 6.597833e-07\n", |
| 548 | + " Mean orthogonality error: 2.584600e-07\n", |
| 549 | + " ✓ PASS\n", |
| 550 | + "\n", |
| 551 | + "3. Rotation Matrix Determinant Test:\n", |
| 552 | + " Min determinant: 1.000000\n", |
| 553 | + " Max determinant: 1.000000\n", |
| 554 | + " Mean determinant: 1.000000\n", |
| 555 | + " ✓ PASS\n", |
| 556 | + "\n", |
| 557 | + "4. Quaternion-to-Rotation Consistency Test:\n", |
| 558 | + " Max consistency error: 2.828426e+00\n", |
| 559 | + " Mean consistency error: 2.283508e+00\n", |
| 560 | + " ✗ FAIL\n", |
| 561 | + "\n", |
| 562 | + "5. Transformation Chain Test:\n", |
| 563 | + " Rotation composition error: 2.522347e+00\n", |
| 564 | + " Translation composition error: 6.872144e+01\n", |
| 565 | + " ✗ FAIL\n", |
| 566 | + "\n", |
| 567 | + "============================================================\n", |
| 568 | + "VALIDATION COMPLETE\n", |
| 569 | + "============================================================\n" |
| 570 | + ] |
| 571 | + } |
| 572 | + ], |
| 573 | + "source": [ |
| 574 | + "# Validation tests for get_residue_transformations\n", |
| 575 | + "\n", |
| 576 | + "print(\"=\"*60)\n", |
| 577 | + "print(\"VALIDATION TESTS FOR get_residue_transformations\")\n", |
| 578 | + "print(\"=\"*60)\n", |
| 579 | + "\n", |
| 580 | + "# Test 1: Check quaternion properties\n", |
| 581 | + "print(\"\\n1. Quaternion Unit Norm Test:\")\n", |
| 582 | + "quat_norms = torch.norm(pdb_transforms['quats'], dim=1)\n", |
| 583 | + "print(f\" Min quaternion norm: {quat_norms.min():.6f}\")\n", |
| 584 | + "print(f\" Max quaternion norm: {quat_norms.max():.6f}\")\n", |
| 585 | + "print(f\" Mean quaternion norm: {quat_norms.mean():.6f}\")\n", |
| 586 | + "print(f\" ✓ PASS\" if torch.allclose(quat_norms, torch.ones_like(quat_norms), atol=1e-5) else \" ✗ FAIL\")\n", |
| 587 | + "\n", |
| 588 | + "# Test 2: Check rotation matrix orthogonality (R^T @ R should be identity)\n", |
| 589 | + "print(\"\\n2. Rotation Matrix Orthogonality Test:\")\n", |
| 590 | + "R = pdb_transforms['rotations']\n", |
| 591 | + "identity_errors = []\n", |
| 592 | + "for i in range(len(R)):\n", |
| 593 | + " should_be_identity = R[i].T @ R[i]\n", |
| 594 | + " error = torch.norm(should_be_identity - torch.eye(3), p='fro')\n", |
| 595 | + " identity_errors.append(error.item())\n", |
| 596 | + "\n", |
| 597 | + "identity_errors = torch.tensor(identity_errors)\n", |
| 598 | + "print(f\" Max orthogonality error: {identity_errors.max():.6e}\")\n", |
| 599 | + "print(f\" Mean orthogonality error: {identity_errors.mean():.6e}\")\n", |
| 600 | + "print(f\" ✓ PASS\" if identity_errors.max() < 1e-5 else \" ✗ FAIL\")\n", |
| 601 | + "\n", |
| 602 | + "# Test 3: Check rotation matrix determinant (should be +1)\n", |
| 603 | + "print(\"\\n3. Rotation Matrix Determinant Test:\")\n", |
| 604 | + "determinants = torch.stack([torch.det(R[i]) for i in range(len(R))])\n", |
| 605 | + "print(f\" Min determinant: {determinants.min():.6f}\")\n", |
| 606 | + "print(f\" Max determinant: {determinants.max():.6f}\")\n", |
| 607 | + "print(f\" Mean determinant: {determinants.mean():.6f}\")\n", |
| 608 | + "print(f\" ✓ PASS\" if torch.allclose(determinants, torch.ones_like(determinants), atol=1e-5) else \" ✗ FAIL\")\n", |
| 609 | + "\n", |
| 610 | + "# Test 4: Check quaternion-to-rotation consistency\n", |
| 611 | + "print(\"\\n4. Quaternion-to-Rotation Consistency Test:\")\n", |
| 612 | + "from foldtree2.src.losses.fape import quaternion_to_rotation_matrix\n", |
| 613 | + "R_from_quats = quaternion_to_rotation_matrix(pdb_transforms['quats'])\n", |
| 614 | + "consistency_error = torch.norm(R_from_quats - R, p='fro', dim=[1, 2])\n", |
| 615 | + "print(f\" Max consistency error: {consistency_error.max():.6e}\")\n", |
| 616 | + "print(f\" Mean consistency error: {consistency_error.mean():.6e}\")\n", |
| 617 | + "print(f\" ✓ PASS\" if consistency_error.max() < 1e-4 else \" ✗ FAIL\")\n", |
| 618 | + "\n", |
| 619 | + "# Test 5: Verify transformation composition\n", |
| 620 | + "print(\"\\n5. Transformation Chain Test:\")\n", |
| 621 | + "# Apply transforms sequentially and check against direct computation\n", |
| 622 | + "transforms = pdb_transforms['transforms']\n", |
| 623 | + "composed_transform = transforms[0]\n", |
| 624 | + "for i in range(1, min(5, len(transforms))): # Test first 5 transforms\n", |
| 625 | + " composed_transform = composed_transform.compose(transforms[i])\n", |
| 626 | + "\n", |
| 627 | + "# The composed transform should match what we get from composing frames directly\n", |
| 628 | + "residue_list = list(pdb_structure[0][list(pdb_structure[0].child_dict.keys())[0]])\n", |
| 629 | + "coords_0 = {\n", |
| 630 | + " 'N': torch.tensor(residue_list[0]['N'].get_coord(), dtype=torch.float32),\n", |
| 631 | + " 'CA': torch.tensor(residue_list[0]['CA'].get_coord(), dtype=torch.float32),\n", |
| 632 | + " 'C': torch.tensor(residue_list[0]['C'].get_coord(), dtype=torch.float32)\n", |
| 633 | + "}\n", |
| 634 | + "coords_5 = {\n", |
| 635 | + " 'N': torch.tensor(residue_list[5]['N'].get_coord(), dtype=torch.float32),\n", |
| 636 | + " 'CA': torch.tensor(residue_list[5]['CA'].get_coord(), dtype=torch.float32),\n", |
| 637 | + " 'C': torch.tensor(residue_list[5]['C'].get_coord(), dtype=torch.float32)\n", |
| 638 | + "}\n", |
| 639 | + "\n", |
| 640 | + "frame_0 = Rigid.from_3_points(\n", |
| 641 | + " p_neg_x_axis=coords_0['N'].unsqueeze(0),\n", |
| 642 | + " origin=coords_0['CA'].unsqueeze(0),\n", |
| 643 | + " p_xy_plane=coords_0['C'].unsqueeze(0)\n", |
| 644 | + ")\n", |
| 645 | + "frame_5 = Rigid.from_3_points(\n", |
| 646 | + " p_neg_x_axis=coords_5['N'].unsqueeze(0),\n", |
| 647 | + " origin=coords_5['CA'].unsqueeze(0),\n", |
| 648 | + " p_xy_plane=coords_5['C'].unsqueeze(0)\n", |
| 649 | + ")\n", |
| 650 | + "\n", |
| 651 | + "direct_transform = frame_5.compose(frame_0.invert())\n", |
| 652 | + "rotation_error = torch.norm(\n", |
| 653 | + " composed_transform.get_rots().get_rot_mats() - direct_transform.get_rots().get_rot_mats(), \n", |
| 654 | + " p='fro'\n", |
| 655 | + ")\n", |
| 656 | + "translation_error = torch.norm(\n", |
| 657 | + " composed_transform.get_trans() - direct_transform.get_trans()\n", |
| 658 | + ")\n", |
| 659 | + "\n", |
| 660 | + "print(f\" Rotation composition error: {rotation_error.item():.6e}\")\n", |
| 661 | + "print(f\" Translation composition error: {translation_error.item():.6e}\")\n", |
| 662 | + "print(f\" ✓ PASS\" if rotation_error < 1e-4 and translation_error < 1e-4 else \" ✗ FAIL\")\n", |
| 663 | + "\n", |
| 664 | + "print(\"\\n\" + \"=\"*60)\n", |
| 665 | + "print(\"VALIDATION COMPLETE\")\n", |
| 666 | + "print(\"=\"*60)" |
| 667 | + ] |
| 668 | + }, |
510 | 669 | { |
511 | 670 | "cell_type": "code", |
512 | 671 | "execution_count": 44, |
|
553 | 712 | "\tplt.show()\n", |
554 | 713 | "\n", |
555 | 714 | "\n", |
556 | | - "def reconstruct_positions(R, T):\n", |
557 | | - "\t\"\"\"\n", |
558 | | - "\tReconstruct 3D positions from a sequence of rotation matrices and translation vectors.\n", |
559 | | - "\t\n", |
560 | | - "\tArgs:\n", |
561 | | - "\t\tR (torch.Tensor): Rotation matrices of shape (N, 3, 3)\n", |
562 | | - "\t\tT (torch.Tensor): Translation vectors of shape (N, 3)\n", |
563 | | - "\t\t\n", |
564 | | - "\tReturns:\n", |
565 | | - "\t\ttorch.Tensor: Reconstructed positions of shape (N+1, 3), starting from origin\n", |
566 | | - "\t\"\"\"\n", |
567 | | - "\tpositions = torch.zeros(len(T) + 1, 3, dtype=T.dtype, device=T.device)\n", |
568 | | - "\tcurrent_pos = torch.zeros(3, dtype=T.dtype, device=T.device)\n", |
569 | | - "\t\n", |
570 | | - "\tfor i in range(len(T)):\n", |
571 | | - "\t\tcurrent_pos = R[i] @ current_pos + T[i]\n", |
572 | | - "\t\tpositions[i + 1] = current_pos\n", |
573 | | - "\t\n", |
574 | | - "\treturn positions\n", |
575 | | - "\n", |
576 | 715 | "def plot_reconstructed_chain(R, T, title=\"Reconstructed Chain\"):\n", |
577 | | - "\tR = R.detach().cpu()\n", |
578 | | - "\tT = T.detach().cpu()\n", |
579 | 716 | "\tpositions = reconstruct_positions(R, T)\n", |
580 | | - "\t\n", |
581 | | - "\tplot_points(positions, title)\n", |
582 | | - "\n" |
| 717 | + "\tplot_points(positions, title=title)" |
583 | 718 | ] |
584 | 719 | }, |
585 | 720 | { |
|
930 | 1065 | } |
931 | 1066 | ], |
932 | 1067 | "source": [ |
933 | | - "from src.losses.losses import compute_chain_positions\n", |
934 | | - "\n", |
935 | 1068 | "# Extract original CA coordinates\n", |
936 | 1069 | "ca_coords = pdb2pyg.extract_pdb_coordinates(pdb_path, atom_type=\"CA\")\n", |
937 | 1070 | "\n", |
|
1018 | 1151 | } |
1019 | 1152 | ], |
1020 | 1153 | "source": [ |
1021 | | - "from src.losses.losses import compute_chain_positions\n", |
1022 | | - "\n", |
1023 | 1154 | "def plot_quaternion_and_rt_alignment(quats, translations, R, T, title=\"Alignment Plot\"):\n", |
1024 | 1155 | "\t\"\"\"\n", |
1025 | 1156 | "\tPlot points created from quaternion chain and from RT values to check alignment.\n", |
|
1141 | 1272 | "plot_quaternion_chain(quaternions_noisy, t_noisy, \"Noisy Quaternion Chain Positions\")\n", |
1142 | 1273 | "\n", |
1143 | 1274 | "# --- Compute FAPE loss and lDDT loss ---\n", |
1144 | | - "from losses.losses import fape_loss, lddt_loss\n", |
1145 | 1275 | "\n", |
1146 | 1276 | "# FAPE loss (using original as true, noisy as pred)\n", |
1147 | 1277 | "batch = torch.zeros(coords.shape[0], dtype=torch.long) # single batch\n", |
|
1169 | 1299 | "ax.legend()\n", |
1170 | 1300 | "plt.show()" |
1171 | 1301 | ] |
1172 | | - }, |
1173 | | - { |
1174 | | - "cell_type": "code", |
1175 | | - "execution_count": null, |
1176 | | - "id": "0121113b", |
1177 | | - "metadata": {}, |
1178 | | - "outputs": [], |
1179 | | - "source": [] |
1180 | 1302 | } |
1181 | 1303 | ], |
1182 | 1304 | "metadata": { |
1183 | 1305 | "kernelspec": { |
1184 | | - "display_name": "Python 3 (ipykernel)", |
| 1306 | + "display_name": "foldtree2", |
1185 | 1307 | "language": "python", |
1186 | 1308 | "name": "python3" |
| 1309 | + }, |
| 1310 | + "language_info": { |
| 1311 | + "codemirror_mode": { |
| 1312 | + "name": "ipython", |
| 1313 | + "version": 3 |
| 1314 | + }, |
| 1315 | + "file_extension": ".py", |
| 1316 | + "mimetype": "text/x-python", |
| 1317 | + "name": "python", |
| 1318 | + "nbconvert_exporter": "python", |
| 1319 | + "pygments_lexer": "ipython3", |
| 1320 | + "version": "3.9.23" |
1187 | 1321 | } |
1188 | 1322 | }, |
1189 | 1323 | "nbformat": 4, |
|
0 commit comments