 i - residue sequence length (source)
 j - residue sequence length (target)
 m - atom sequence length
+c - coordinates (3 for spatial)
 d - feature dimension
 ds - feature dimension (single)
 dp - feature dimension (pairwise)
@@ -861,7 +862,7 @@ def __init__(
         # final projection of mean pooled repr -> out
 
         self.to_out = nn.Sequential(
-            LinearNoBias(dim, dim),
+            LinearNoBias(dim, dim_pairwise),
             nn.ReLU()
         )
 
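The projection fix above matters because the template embedder's pooled output is added residually to the pairwise representation in the trunk, so its last dimension has to be dim_pairwise rather than the embedder's internal dim. A minimal shape sketch, with made-up sizes, not part of the diff:

    import torch

    # hypothetical sizes, for illustration only
    b, n, dim_pairwise = 1, 16, 128

    pairwise_repr = torch.randn(b, n, n, dim_pairwise)
    template_out = torch.randn(b, n, n, dim_pairwise)  # shape after the corrected projection

    # the residual addition in the trunk requires matching trailing dimensions
    assert (template_out + pairwise_repr).shape == (b, n, n, dim_pairwise)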
@@ -873,7 +874,7 @@ def forward(
         template_mask: Bool['b t'],
         pairwise_repr: Float['b n n dp'],
         mask: Bool['b n'] | None = None,
-    ) -> Float['b n n d']:
+    ) -> Float['b n n dp']:
 
         num_templates = templates.shape[1]
 
@@ -884,7 +885,8 @@ def forward(
 
         v, merged_batch_ps = pack_one(v, '* i j d')
 
-        mask = repeat(mask, 'b n -> (b t) n', t = num_templates)
+        if exists(mask):
+            mask = repeat(mask, 'b n -> (b t) n', t = num_templates)
 
         for block in self.pairformer_stack:
             v = block(
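The added guard is needed because mask is declared optional (Bool['b n'] | None = None) in the signature above, and einops.repeat would raise on a None input. A self-contained sketch of the pattern, with exists re-declared here as the usual one-liner:

    import torch
    from einops import repeat

    def exists(v):
        return v is not None

    num_templates = 2

    # optional path: mask was never provided, so the repeat is skipped entirely
    mask = None
    if exists(mask):
        mask = repeat(mask, 'b n -> (b t) n', t = num_templates)

    # provided path: the per-batch mask is tiled across templates
    mask = torch.ones(1, 16).bool()
    if exists(mask):
        mask = repeat(mask, 'b n -> (b t) n', t = num_templates)

    assert mask.shape == (2, 16)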
@@ -1815,7 +1817,7 @@ def forward(
         pairwise_repr: Float['b n n dp'],
         pred_atom_pos: Float['b n c'],
         mask: Bool['b n'] | None = None,
-        calc_pae_logits_and_loss = True
+        return_pae_logits = True
 
     ) -> ConfidenceHeadLogits[
         Float['b pae n n'] | None,
@@ -1854,7 +1856,7 @@ def forward(
 
         pae_logits = None
 
-        if calc_pae_logits_and_loss:
+        if return_pae_logits:
             pae_logits = self.to_pae_logits(pairwise_repr)
 
         # return all logits
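The rename reflects that the flag now only gates whether PAE logits are produced; the head no longer computes any loss itself, so a caller checks for a None pae slot before building a PAE loss. An illustrative sketch of that consuming pattern, using a stand-in namedtuple with hypothetical field names rather than the real ConfidenceHeadLogits:

    from collections import namedtuple
    import torch

    # stand-in with made-up field names, only to show the unpacking pattern
    Logits = namedtuple('Logits', ['pae', 'pde', 'plddt', 'resolved'])

    # roughly what a call with return_pae_logits = False would leave in the pae slot
    logits = Logits(
        pae = None,
        pde = torch.randn(1, 64, 16, 16),
        plddt = torch.randn(1, 50, 16),
        resolved = torch.randn(1, 2, 16)
    )

    if logits.pae is not None:
        ...  # only then compute a cross entropy against pae_labels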
@@ -1863,21 +1865,248 @@ def forward(
 
 # main class
 
+LossBreakdown = namedtuple('LossBreakdown', [
+    'distogram',
+    'pae',
+    'pde',
+    'plddt',
+    'resolved'
+])
+
 class Alphafold3(Module):
+    """ Algorithm 1 """
+
+    @typecheck
     def __init__(
         self,
         *,
+        dim_atom_inputs,
+        dim_additional_residue_feats,
+        dim_template_feats,
+        dim_template_model = 64,
+        atoms_per_window = 27,
+        dim_atom = 128,
+        dim_atompair = 16,
+        dim_input_embedder_token = 384,
+        dim_single = 384,
+        dim_pairwise = 128,
+        atompair_dist_bins: Float[' dist_bins'] = torch.linspace(3, 20, 37),
+        ignore_index = -1,
+        num_dist_bins = 38,
+        num_plddt_bins = 50,
+        num_pde_bins = 64,
+        num_pae_bins = 64,
         loss_confidence_weight = 1e-4,
         loss_distogram_weight = 1e-2,
-        loss_diffusion = 4.
+        loss_diffusion_weight = 4.,
+        input_embedder_kwargs: dict = dict(
+            atom_transformer_blocks = 3,
+            atom_transformer_heads = 4,
+            atom_transformer_kwargs = dict()
+        ),
+        confidence_head_kwargs: dict = dict(
+            pairformer_depth = 4
+        ),
+        template_embedder_kwargs: dict = dict(
+            pairformer_stack_depth = 2,
+            pairwise_block_kwargs = dict(),
+        ),
+        msa_module_kwargs: dict = dict(
+            depth = 4,
+            dim_msa = 64,
+            dim_msa_input = None,
+            outer_product_mean_dim_hidden = 32,
+            msa_pwa_dropout_row_prob = 0.15,
+            msa_pwa_heads = 8,
+            msa_pwa_dim_head = 32,
+            pairwise_block_kwargs = dict()
+        ),
+        pairformer_stack: dict = dict(
+            depth = 48,
+            pair_bias_attn_dim_head = 64,
+            pair_bias_attn_heads = 16,
+            dropout_row_prob = 0.25,
+            pairwise_block_kwargs = dict()
+        )
     ):
         super().__init__()
 
+        self.atoms_per_window = atoms_per_window
+
+        # input feature embedder
+
+        self.input_embedder = InputFeatureEmbedder(
+            dim_atom_inputs = dim_atom_inputs,
+            dim_additional_residue_feats = dim_additional_residue_feats,
+            atoms_per_window = atoms_per_window,
+            dim_atom = dim_atom,
+            dim_atompair = dim_atompair,
+            dim_token = dim_input_embedder_token,
+            dim_single = dim_single,
+            dim_pairwise = dim_pairwise,
+            **input_embedder_kwargs
+        )
+
+        dim_single_inputs = dim_input_embedder_token + dim_additional_residue_feats
+
+        # templates
+
+        self.template_embedder = TemplateEmbedder(
+            dim_template_feats = dim_template_feats,
+            dim = dim_template_model,
+            dim_pairwise = dim_pairwise,
+            **template_embedder_kwargs
+        )
+
+        # msa
+
+        self.msa_module = MSAModule(
+            dim_single = dim_single,
+            dim_pairwise = dim_pairwise,
+            **msa_module_kwargs
+        )
+
+        # main pairformer trunk, 48 layers
+
+        self.pairformer = PairformerStack(
+            dim_single = dim_single,
+            dim_pairwise = dim_pairwise,
+            **pairformer_stack
+        )
+
+        # recycling related
+
+        self.recycle_single = nn.Sequential(
+            nn.LayerNorm(dim_single),
+            LinearNoBias(dim_single, dim_single)
+        )
+
+        self.recycle_pairwise = nn.Sequential(
+            nn.LayerNorm(dim_pairwise),
+            LinearNoBias(dim_pairwise, dim_pairwise)
+        )
+
+        # logit heads
+
+        self.distogram_head = DistogramHead(
+            dim_pairwise = dim_pairwise,
+            num_dist_bins = num_dist_bins
+        )
+
+        self.confidence_head = ConfidenceHead(
+            dim_single_inputs = dim_single_inputs,
+            atompair_dist_bins = atompair_dist_bins,
+            dim_single = dim_single,
+            dim_pairwise = dim_pairwise,
+            num_plddt_bins = num_plddt_bins,
+            num_pde_bins = num_pde_bins,
+            num_pae_bins = num_pae_bins,
+            **confidence_head_kwargs
+        )
+
+        # loss related
+
+        self.ignore_index = ignore_index
+        self.loss_distogram_weight = loss_distogram_weight
+        self.loss_confidence_weight = loss_confidence_weight
+        self.loss_diffusion_weight = loss_diffusion_weight
 
     @typecheck
     def forward(
         self,
         *,
-        include_pae_loss = False  # turned on in latter part of training
-    ):
-        return
+        atom_inputs: Float['b m dai'],
+        atom_mask: Bool['b m'],
+        atompair_feats: Float['b m m dap'],
+        additional_residue_feats: Float['b n rf'],
+        msa: Float['b s n d'],
+        templates: Float['b t n n dt'],
+        template_mask: Bool['b t'],
+        num_recycling_steps: int = 1,
+        distance_labels: Int['b n n'] | None = None,
+        pae_labels: Int['b n n'] | None = None,
+        pde_labels: Int['b n n'] | None = None,
+        plddt_labels: Int['b n'] | None = None,
+        resolved_labels: Int['b n'] | None = None,
+    ) -> Float['b m c'] | Float['']:
+
+        # embed inputs
+
+        (
+            single_inputs,
+            single_init,
+            pairwise_init,
+            atom_feats,
+            atompair_feats
+        ) = self.input_embedder(
+            atom_inputs = atom_inputs,
+            atom_mask = atom_mask,
+            atompair_feats = atompair_feats,
+            additional_residue_feats = additional_residue_feats
+        )
+
+        w = self.atoms_per_window
+
+        mask = reduce(atom_mask, 'b (n w) -> b n', w = w, reduction = 'any')
+
+        # init recycled single and pairwise
+
+        recycled_pairwise = recycled_single = None
+        single = pairwise = None
+
+        # for each recycling step
+
+        for _ in range(num_recycling_steps):
+
+            # handle recycled single and pairwise if not first step
+
+            recycled_single = recycled_pairwise = 0.
+
+            if exists(single):
+                recycled_single = self.recycle_single(single)
+
+            if exists(pairwise):
+                recycled_pairwise = self.recycle_pairwise(pairwise)
+
+            single = single_init + recycled_single
+            pairwise = pairwise_init + recycled_pairwise
+
+            # else go through main transformer trunk from alphafold2
+
+            # templates
+
+            embedded_template = self.template_embedder(
+                templates = templates,
+                template_mask = template_mask,
+                pairwise_repr = pairwise,
+                mask = mask
+            )
+
+            pairwise = embedded_template + pairwise
+
+            # msa
+
+            embedded_msa = self.msa_module(
+                msa = msa,
+                single_repr = single,
+                pairwise_repr = pairwise,
+                mask = mask
+            )
+
+            pairwise = embedded_msa + pairwise
+
+            # main attention trunk (pairformer)
+
+            single, pairwise = self.pairformer(
+                single_repr = single,
+                pairwise_repr = pairwise,
+                mask = mask
+            )
+
+        # determine whether to return a loss, if any labels were passed in,
+        # otherwise sample the atomic coordinates
+
+        labels = (distance_labels, pae_labels, pde_labels, plddt_labels, resolved_labels)
+        return_loss = any([*map(exists, labels)])
+
+        return torch.tensor(0.)
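For orientation, here is a hypothetical smoke test of the Alphafold3 class added in this hunk, in the spirit of the repository's README examples. Every tensor size and feature dimension below is a placeholder rather than a documented requirement; the one constraint visible in the diff is that the atom sequence length must equal the residue length times atoms_per_window (default 27) so the window reduction of atom_mask works. Note also that, as of this commit, forward still ends in a placeholder return of torch.tensor(0.).

    import torch

    # placeholder feature dimensions (dai = 77, rf = 33, dt = 44 are arbitrary here)
    alphafold3 = Alphafold3(
        dim_atom_inputs = 77,
        dim_additional_residue_feats = 33,
        dim_template_feats = 44
    )

    seq_len = 16
    atom_seq_len = seq_len * 27   # residues * atoms_per_window

    atom_inputs = torch.randn(2, atom_seq_len, 77)
    atom_mask = torch.ones(2, atom_seq_len).bool()
    atompair_feats = torch.randn(2, atom_seq_len, atom_seq_len, 16)   # dap assumed = dim_atompair
    additional_residue_feats = torch.randn(2, seq_len, 33)

    msa = torch.randn(2, 7, seq_len, 64)                  # s = 7 sequences, d assumed = dim_msa
    templates = torch.randn(2, 2, seq_len, seq_len, 44)   # t = 2 templates
    template_mask = torch.ones(2, 2).bool()

    out = alphafold3(
        atom_inputs = atom_inputs,
        atom_mask = atom_mask,
        atompair_feats = atompair_feats,
        additional_residue_feats = additional_residue_feats,
        msa = msa,
        templates = templates,
        template_mask = template_mask,
        num_recycling_steps = 2
    )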