@@ -981,6 +981,8 @@ def __init__(
981981 msa_pwa_dropout_row_prob = 0.15 ,
982982 msa_pwa_heads = 8 ,
983983 msa_pwa_dim_head = 32 ,
984+ checkpoint = False ,
985+ checkpoint_segments = 1 ,
984986 pairwise_block_kwargs : dict = dict (),
985987 max_num_msa : int | None = None ,
986988 layerscale_output : bool = True
@@ -1028,10 +1030,105 @@ def __init__(
10281030 pairwise_block
10291031 ]))
10301032
1033+ self .checkpoint = checkpoint
1034+ self .checkpoint_segments = checkpoint_segments
1035+
10311036 self .layers = layers
10321037
10331038 self .layerscale_output = nn .Parameter (torch .zeros (dim_pairwise )) if layerscale_output else 1.
10341039
@typecheck
def to_layers(
    self,
    *,
    pairwise_repr: Float['b n n dp'],
    msa: Float['b s n dm'],
    mask: Bool['b n'] | None = None,
    msa_mask: Bool['b s'] | None = None,
) -> Float['b n n dp']:
    """Run every entry of `self.layers` in order, without checkpointing.

    Each entry of `self.layers` is unpacked into four callables: an outer
    product mean, an msa pair weighted average, an msa transition and a
    pairwise block. The first three are applied with a residual connection;
    the output of the pairwise block replaces `pairwise_repr` directly.

    Only the final `pairwise_repr` is returned - the updated `msa` is
    discarded once the last layer has run.
    """

    for layer in self.layers:
        opm, pair_weighted_avg, transition, block = layer

        # msa -> pairwise: residual update of the pairwise representation

        pairwise_update = opm(msa, mask = mask, msa_mask = msa_mask)
        pairwise_repr = pairwise_update + pairwise_repr

        # pairwise -> msa: residual update, followed by the msa transition

        msa_update = pair_weighted_avg(msa = msa, pairwise_repr = pairwise_repr, mask = mask)
        msa = msa_update + msa

        msa = transition(msa) + msa

        # pairwise block - no residual is added here, its output is used as is

        pairwise_repr = block(pairwise_repr = pairwise_repr, mask = mask)

    return pairwise_repr
1069+
@typecheck
def to_checkpointed_layers(
    self,
    *,
    pairwise_repr: Float['b n n dp'],
    msa: Float['b s n dm'],
    mask: Bool['b n'] | None = None,
    msa_mask: Bool['b s'] | None = None,
) -> Float['b n n dp']:
    """Run every entry of `self.layers` through `checkpoint_sequential`.

    Each sub-module of each layer is wrapped into a function that takes and
    returns the same 4-tuple `(pairwise_repr, mask, msa, msa_mask)`, so that
    the whole stack can be handed over as one flat list of steps. The steps
    apply the same updates, in the same order, as `to_layers`.

    Only the first element of the final tuple (`pairwise_repr`) is returned.

    NOTE(review): `checkpoint_sequential` appears to run its last segment
    without checkpointing, so `checkpoint_segments = 1` would likely save no
    memory - confirm against the torch documentation.
    """

    # every step consumes and produces a tuple laid out exactly like this

    state = (pairwise_repr, mask, msa, msa_mask)

    def wrap_outer_product_mean(fn):
        # residual update of the pairwise representation from the msa
        @wraps(fn)
        def step(state):
            pairwise_repr, mask, msa, msa_mask = state
            pairwise_update = fn(msa = msa, mask = mask, msa_mask = msa_mask)
            return pairwise_update + pairwise_repr, mask, msa, msa_mask
        return step

    def wrap_msa_pair_weighted_avg(fn):
        # residual update of the msa, conditioned on the pairwise representation
        @wraps(fn)
        def step(state):
            pairwise_repr, mask, msa, msa_mask = state
            msa_update = fn(msa = msa, pairwise_repr = pairwise_repr, mask = mask)
            return pairwise_repr, mask, msa_update + msa, msa_mask
        return step

    def wrap_msa_transition(fn):
        # residual transition on the msa alone
        @wraps(fn)
        def step(state):
            pairwise_repr, mask, msa, msa_mask = state
            return pairwise_repr, mask, fn(msa) + msa, msa_mask
        return step

    def wrap_pairwise_block(fn):
        # the pairwise block output replaces the pairwise representation (no residual added)
        @wraps(fn)
        def step(state):
            pairwise_repr, mask, msa, msa_mask = state
            return fn(pairwise_repr = pairwise_repr, mask = mask), mask, msa, msa_mask
        return step

    # flatten all layers into one list of steps, keeping the per-layer order

    steps = []

    for layer in self.layers:
        opm, pair_weighted_avg, transition, block = layer

        steps.extend((
            wrap_outer_product_mean(opm),
            wrap_msa_pair_weighted_avg(pair_weighted_avg),
            wrap_msa_transition(transition),
            wrap_pairwise_block(block),
        ))

    final_state = checkpoint_sequential(steps, self.checkpoint_segments, state, use_reentrant = False)

    # the pairwise representation sits first in the state tuple

    pairwise_repr, *_ = final_state

    return pairwise_repr
1131+
10351132 @typecheck
10361133 def forward (
10371134 self ,
@@ -1073,23 +1170,21 @@ def forward(
10731170
10741171 msa = rearrange (single_msa_feats , 'b n d -> b 1 n d' ) + msa
10751172
1076- for (
1077- outer_product_mean ,
1078- msa_pair_weighted_avg ,
1079- msa_transition ,
1080- pairwise_block
1081- ) in self .layers :
1082-
1083- # communication between msa and pairwise rep
1084-
1085- pairwise_repr = outer_product_mean (msa , mask = mask , msa_mask = msa_mask ) + pairwise_repr
1173+ # going through the layers
10861174
1087- msa = msa_pair_weighted_avg (msa = msa , pairwise_repr = pairwise_repr , mask = mask ) + msa
1088- msa = msa_transition (msa ) + msa
1175+ if should_checkpoint (self , (pairwise_repr , msa )):
1176+ to_layers_fn = self .to_checkpointed_layers
1177+ else :
1178+ to_layers_fn = self .to_layers
10891179
1090- # pairwise block
1180+ pairwise_repr = to_layers_fn (
1181+ msa = msa ,
1182+ mask = mask ,
1183+ pairwise_repr = pairwise_repr ,
1184+ msa_mask = msa_mask
1185+ )
10911186
1092- pairwise_repr = pairwise_block ( pairwise_repr = pairwise_repr , mask = mask )
1187+ # final masking and then layer scale
10931188
10941189 if exists (msa_mask ):
10951190 pairwise_repr = einx .where (
@@ -1208,20 +1303,23 @@ def to_checkpointed_layers(
12081303 inputs = (single_repr , pairwise_repr , mask )
12091304
12101305 def pairwise_block_wrapper (layer ):
1306+ @wraps (layer )
12111307 def inner (inputs , * args , ** kwargs ):
12121308 single_repr , pairwise_repr , mask = inputs
12131309 pairwise_repr = layer (pairwise_repr = pairwise_repr , mask = mask )
12141310 return single_repr , pairwise_repr , mask
12151311 return inner
12161312
12171313 def pair_bias_attn_wrapper (layer ):
1314+ @wraps (layer )
12181315 def inner (inputs , * args , ** kwargs ):
12191316 single_repr , pairwise_repr , mask = inputs
12201317 single_repr = layer (single_repr , pairwise_repr = pairwise_repr , mask = mask ) + single_repr
12211318 return single_repr , pairwise_repr , mask
12221319 return inner
12231320
12241321 def single_transition_wrapper (layer ):
1322+ @wraps (layer )
12251323 def inner (inputs , * args , ** kwargs ):
12261324 single_repr , pairwise_repr , mask = inputs
12271325 single_repr = layer (single_repr ) + single_repr
@@ -1725,20 +1823,23 @@ def to_checkpointed_serial_layers(
17251823 wrapped_layers = []
17261824
17271825 def efficient_attn_wrapper (fn ):
1826+ @wraps (fn )
17281827 def inner (inputs ):
17291828 noised_repr , single_repr , pairwise_repr , mask , windowed_mask = inputs
17301829 noised_repr = fn (noised_repr , mask = mask ) + noised_repr
17311830 return noised_repr , single_repr , pairwise_repr , mask , windowed_mask
17321831 return inner
17331832
17341833 def attn_wrapper (fn ):
1834+ @wraps (fn )
17351835 def inner (inputs ):
17361836 noised_repr , single_repr , pairwise_repr , mask , windowed_mask = inputs
17371837 noised_repr = fn (noised_repr , cond = single_repr , pairwise_repr = pairwise_repr , mask = mask , windowed_mask = windowed_mask ) + noised_repr
17381838 return noised_repr , single_repr , pairwise_repr , mask , windowed_mask
17391839 return inner
17401840
17411841 def transition_wrapper (fn ):
1842+ @wraps (fn )
17421843 def inner (inputs ):
17431844 noised_repr , single_repr , pairwise_repr , mask , windowed_mask = inputs
17441845 noised_repr = fn (noised_repr , cond = single_repr ) + noised_repr
0 commit comments