@@ -4518,6 +4518,7 @@ def __init__(
45184518 checkpoint_input_embedding = False ,
45194519 checkpoint_trunk_pairformer = False ,
45204520 checkpoint_diffusion_token_transformer = False ,
4521+ detach_when_recycling = True
45214522 ):
45224523 super ().__init__ ()
45234524
@@ -4642,6 +4643,8 @@ def __init__(
46424643
46434644 # recycling related
46444645
4646+ self .detach_when_recycling = detach_when_recycling
4647+
46454648 self .recycle_single = nn .Sequential (
46464649 nn .LayerNorm (dim_single ),
46474650 LinearNoBias (dim_single , dim_single )
@@ -4835,7 +4838,8 @@ def forward(
48354838 return_present_sampled_atoms : bool = False ,
48364839 return_confidence_head_logits : bool = False ,
48374840 num_rollout_steps : int | None = None ,
4838- rollout_show_tqdm_pbar : bool = False
4841+ rollout_show_tqdm_pbar : bool = False ,
4842+ detach_when_recycling : bool = None
48394843 ) -> (
48404844 Float ['b m 3' ] |
48414845 Float ['l 3' ] |
@@ -4995,6 +4999,9 @@ def forward(
49954999
49965000 # init recycled single and pairwise
49975001
5002+ detach_when_recycling = default (detach_when_recycling , self .detach_when_recycling )
5003+ maybe_recycling_detach = torch .detach if detach_when_recycling else identity
5004+
49985005 recycled_pairwise = recycled_single = None
49995006 single = pairwise = None
50005007
@@ -5007,9 +5014,11 @@ def forward(
50075014 recycled_single = recycled_pairwise = 0.
50085015
50095016 if exists (single ):
5017+ single = maybe_recycling_detach (single )
50105018 recycled_single = self .recycle_single (single )
50115019
50125020 if exists (pairwise ):
5021+ pairwise = maybe_recycling_detach (pairwise )
50135022 recycled_pairwise = self .recycle_pairwise (pairwise )
50145023
50155024 single = single_init + recycled_single
0 commit comments