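"""Export intermediate activations from block 0 of the ESMFold folding trunk.

Runs the ESM-2 language model and the first trunk block on a single sequence,
capturing per-step tensors (sequence attention, triangle multiplication,
triangle attention) so they can be diffed against another implementation.
Reuses the loaders from export_esmfold_full_from_safetensors.
"""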
import argparse
from pathlib import Path

import numpy as np
import torch

from export_esmfold_full_from_safetensors import (
    SafeTensorsReader,
    build_model,
    load_esm2_weights,
    load_rest_weights,
    compute_language_model_representations,
    infer_config,
)

RESTYPES = [
    "A","R","N","D","C","Q","E","G","H","I","L","K","M","F","P","S","T","W","Y","V",
]
RESTYPES_WITH_X = RESTYPES + ["X"]
RESTYPE_ORDER_WITH_X = {aa: i for i, aa in enumerate(RESTYPES_WITH_X)}


def sequence_to_af2_indices(seq: str):
    return [RESTYPE_ORDER_WITH_X.get(ch, RESTYPE_ORDER_WITH_X["X"]) for ch in seq]

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--safetensors", required=True)
    parser.add_argument("--sequence", default="ELLKKLLEELKG")
    parser.add_argument("--output", default="esmfold_block0_debug.npz")
    args = parser.parse_args()

    torch.set_grad_enabled(False)

    # Rebuild the model from the safetensors checkpoint and put every submodule in eval mode.
    with SafeTensorsReader(Path(args.safetensors)) as reader:
        model = build_model(reader, use_esm_attn_map=False)
        load_esm2_weights(model.esm, reader)
        load_rest_weights(model, reader)

    model.esm.eval()
    model.esm_s_mlp.eval()
    model.embedding.eval()
    model.trunk.eval()
    model.distogram_head.eval()
    model.ptm_head.eval()
    model.lm_head.eval()
    model.lddt_head.eval()

    # Build the initial single (s_s_0) and pair (s_z_0) representations from the
    # ESM-2 embeddings, mirroring ESMFold's forward pass up to the trunk input.
    seq = args.sequence
    aa = torch.tensor([sequence_to_af2_indices(seq)], dtype=torch.long)
    mask = torch.ones_like(aa)

    esmaa = model.af2_to_esm[(aa + 1).masked_fill(mask != 1, 0)]
    esm_s, _ = compute_language_model_representations(model, esmaa, use_esm_attn_map=False)
    esm_s = esm_s.to(model.esm_s_combine.dtype).detach()
    esm_s = (torch.softmax(model.esm_s_combine, 0).unsqueeze(0) @ esm_s).squeeze(2)
    s_s_0 = model.esm_s_mlp(esm_s) + model.embedding(aa)
    B, L, _ = s_s_0.shape
    s_z_0 = s_s_0.new_zeros(B, L, L, model.cfg["c_z"])

    block = model.trunk.blocks[0]
    tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1)

    # Sequence self-attention with pair bias: capture the pair bias, the
    # pre-projection q/k/v, raw logits, and the attention output for comparison.
    bias = block.pair_to_sequence(s_z_0)
    seq_ln = block.layernorm_1(s_s_0)
    proj_out = block.seq_attention.proj(seq_ln)
    B, L, C3 = proj_out.shape
    H = block.seq_attention.num_heads
    head_width = block.seq_attention.head_width
    t = proj_out.view(B, L, H, 3 * head_width)
    q = t[..., :head_width]
    k = t[..., head_width:2 * head_width]
    v = t[..., 2 * head_width:3 * head_width]
    seq_attn_out, attn = block.seq_attention(seq_ln, mask=mask, bias=bias)
    g_proj_out = block.seq_attention.g_proj(seq_ln)
    q_py = q.permute(0, 2, 1, 3)
    k_py = k.permute(0, 2, 1, 3)
    logits = torch.einsum("...qc,...kc->...qk", q_py, k_py)
    seq_state_attn = s_s_0 + seq_attn_out
    seq_state_mlp = block.mlp_seq(seq_state_attn)

    # Manual attention to capture pre-o_proj outputs.
    t_attn = block.seq_attention.proj(seq_ln)
    t_attn = t_attn.reshape(B, L, H, 3 * head_width).permute(0, 2, 1, 3)
    q2 = t_attn[..., :head_width]
    k2 = t_attn[..., head_width:2 * head_width]
    v2 = t_attn[..., 2 * head_width:3 * head_width]
    q2 = q2 * block.seq_attention.rescale_factor
    a2 = torch.einsum("...qc,...kc->...qk", q2, k2)
    a2 = a2 + bias.permute(0, 3, 1, 2)
    if mask is not None:
        a2 = a2.masked_fill(mask[:, None, None, :].expand_as(a2) == 0, float("-inf"))
    a2 = torch.softmax(a2, dim=-1)
    y2 = torch.einsum("...hqk,...hkc->...qhc", a2, v2)
    y2 = y2.reshape(B, L, H * head_width)
    if block.seq_attention.gated:
        y2_gated = block.seq_attention.g_proj(seq_ln).sigmoid() * y2
    else:
        y2_gated = y2

    pair_state = s_z_0 + block.sequence_to_pair(seq_state_mlp)

    tri_mul_out = block.tri_mul_out(pair_state, mask=tri_mask)
    pair_state = pair_state + tri_mul_out

    tri_mul_in = block.tri_mul_in(pair_state, mask=tri_mask)
    pair_state = pair_state + tri_mul_in

    pair_state_before_start = pair_state
    tri_att_start = block.tri_att_start(pair_state_before_start, mask=tri_mask, chunk_size=None)
    pair_state = pair_state_before_start + tri_att_start

    pair_state_before_end = pair_state
    tri_att_end = block.tri_att_end(pair_state_before_end, mask=tri_mask, chunk_size=None)
    pair_state = pair_state_before_end + tri_att_end

    # Triangle attention internals (start/end) to pinpoint mismatches
    from openfold.utils.tensor_utils import permute_final_dims, flatten_final_dims

    def tri_attn_internals(ta, x, mask):
        if not ta.starting:
            x = x.transpose(-2, -3)
            mask = mask.transpose(-1, -2)
        x = ta.layer_norm(x)
        mask_bias = (ta.inf * (mask - 1))[..., :, None, None, :]
        triangle_bias = permute_final_dims(ta.linear(x), (2, 0, 1))
        triangle_bias = triangle_bias.unsqueeze(-4)
        q, k, v = ta.mha._prep_qkv(x, x, apply_scale=True)
        logits = torch.matmul(q, k.transpose(-1, -2))
        logits = logits + mask_bias + triangle_bias
        attn = torch.softmax(logits, dim=-1)
        o = torch.matmul(attn, v)  # [*, H, Q, C]
        o = o.transpose(-2, -3)  # [*, Q, H, C]
        if ta.mha.linear_g is not None:
            g = torch.sigmoid(ta.mha.linear_g(x))
            g = g.view(g.shape[:-1] + (ta.mha.no_heads, -1))
            o = o * g
        o_flat = flatten_final_dims(o, 2)  # [*, Q, H*C]
        return logits, attn, o_flat, mask_bias, triangle_bias, x, q, k, v

    tri_att_start_logits, tri_att_start_attn, tri_att_start_pre_o, tri_att_start_mask_bias, tri_att_start_triangle_bias, tri_att_start_ln, tri_att_start_q, tri_att_start_k, tri_att_start_v = tri_attn_internals(
        block.tri_att_start, pair_state_before_start, tri_mask
    )
    tri_att_end_logits, tri_att_end_attn, tri_att_end_pre_o, tri_att_end_mask_bias, tri_att_end_triangle_bias, tri_att_end_ln, tri_att_end_q, tri_att_end_k, tri_att_end_v = tri_attn_internals(
        block.tri_att_end, pair_state_before_end, tri_mask
    )

    pair_state_mlp = block.mlp_pair(pair_state)

    export = {
        "s_s_0": s_s_0.cpu().numpy(),
        "s_z_0": s_z_0.cpu().numpy(),
        "bias": bias.cpu().numpy(),
        "seq_ln": seq_ln.cpu().numpy(),
        "proj_out": proj_out.cpu().numpy(),
        "seq_attn_out": seq_attn_out.cpu().numpy(),
        "seq_attn_pre_o_proj": y2_gated.cpu().numpy(),
        "seq_attn_pre_gate": y2.cpu().numpy(),
        "g_proj_out": g_proj_out.cpu().numpy(),
        "logits": logits.cpu().numpy(),
        "q": q.cpu().numpy(),
        "k": k.cpu().numpy(),
        "v": v.cpu().numpy(),
        "attn": attn.cpu().numpy(),
        "seq_state_attn": seq_state_attn.cpu().numpy(),
        "seq_state_mlp": seq_state_mlp.cpu().numpy(),
        "pair_state_seq2pair": (s_z_0 + block.sequence_to_pair(seq_state_mlp)).cpu().numpy(),
        "tri_mul_out": tri_mul_out.cpu().numpy(),
        "tri_mul_in": tri_mul_in.cpu().numpy(),
        "pair_state_before_start": pair_state_before_start.cpu().numpy(),
        "pair_state_before_end": pair_state_before_end.cpu().numpy(),
        "tri_att_start": tri_att_start.cpu().numpy(),
        "tri_att_end": tri_att_end.cpu().numpy(),
        "tri_att_start_logits": tri_att_start_logits.cpu().numpy(),
        "tri_att_start_attn": tri_att_start_attn.cpu().numpy(),
        "tri_att_start_pre_o": tri_att_start_pre_o.cpu().numpy(),
        "tri_att_start_mask_bias": tri_att_start_mask_bias.cpu().numpy(),
        "tri_att_start_triangle_bias": tri_att_start_triangle_bias.cpu().numpy(),
        "tri_att_start_ln": tri_att_start_ln.cpu().numpy(),
        "tri_att_start_q": tri_att_start_q.cpu().numpy(),
        "tri_att_start_k": tri_att_start_k.cpu().numpy(),
        "tri_att_start_v": tri_att_start_v.cpu().numpy(),
        "tri_att_end_logits": tri_att_end_logits.cpu().numpy(),
        "tri_att_end_attn": tri_att_end_attn.cpu().numpy(),
        "tri_att_end_pre_o": tri_att_end_pre_o.cpu().numpy(),
        "tri_att_end_mask_bias": tri_att_end_mask_bias.cpu().numpy(),
        "tri_att_end_triangle_bias": tri_att_end_triangle_bias.cpu().numpy(),
        "tri_att_end_ln": tri_att_end_ln.cpu().numpy(),
        "tri_att_end_q": tri_att_end_q.cpu().numpy(),
        "tri_att_end_k": tri_att_end_k.cpu().numpy(),
        "tri_att_end_v": tri_att_end_v.cpu().numpy(),
        "pair_state_final": pair_state_mlp.cpu().numpy(),
    }
    np.savez(args.output, **export)


if __name__ == "__main__":
    main()
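
# Example follow-up (illustrative): after running this script with --safetensors
# pointing at the checkpoint, the exported arrays can be diffed against a dump
# produced by another implementation. "reference_block0.npz" below is a
# placeholder name, not a file this script produces.
#
#   import numpy as np
#   ours = np.load("esmfold_block0_debug.npz")
#   ref = np.load("reference_block0.npz")
#   for key in ours.files:
#       if key in ref.files:
#           print(key, float(np.max(np.abs(ours[key] - ref[key]))))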