@@ -1888,7 +1888,6 @@ def __init__(
         attn_num_memory_kv = False,
         trans_expansion_factor = 2,
         num_register_tokens = 0,
-        serial = True,
         add_residual = True,
         use_linear_attn = False,
         checkpoint = False,
@@ -1967,13 +1966,10 @@ def __init__(
                 conditionable_transition
             ]))
 
-        assert not (not serial and checkpoint), 'checkpointing can only be used for serial version of diffusion transformer'
-
         self.checkpoint = checkpoint
 
         self.layers = layers
 
-        self.serial = serial
        self.add_residual = add_residual
 
         self.has_registers = num_register_tokens > 0
@@ -2074,48 +2070,6 @@ def to_serial_layers(
 
         return noised_repr
 
-    @typecheck
-    def to_parallel_layers(
-        self,
-        noised_repr: Float['b n d'],
-        *,
-        single_repr: Float['b n ds'],
-        pairwise_repr: Float['b n n dp'] | Float['b nw w (w*2) dp'],
-        mask: Bool['b n'] | None = None,
-        windowed_mask: Bool['b nw w (w*2)'] | None = None
-    ):
-
-        for linear_attn, colt5_attn, attn, transition in self.layers:
-
-            if exists(linear_attn):
-                noised_repr = linear_attn(noised_repr, mask = mask) + noised_repr
-
-            if exists(colt5_attn):
-                noised_repr = colt5_attn(noised_repr, mask = mask) + noised_repr
-
-            attn_out = attn(
-                noised_repr,
-                cond = single_repr,
-                pairwise_repr = pairwise_repr,
-                mask = mask,
-                windowed_mask = windowed_mask
-            )
-
-            ff_out = transition(
-                noised_repr,
-                cond = single_repr
-            )
-
-            # in the algorithm, they omitted the residual, but it could be an error
-            # attn + ff + residual was used in GPT-J and PaLM, but later found to be unstable configuration, so it seems unlikely attn + ff would work
-            # but in the case they figured out something we have not, you can use their exact formulation by setting `serial = False` and `add_residual = False`
-
-            residual = noised_repr if self.add_residual else 0.
-
-            noised_repr = ff_out + attn_out + residual
-
-        return noised_repr
-
     @typecheck
     def forward(
         self,
@@ -2126,7 +2080,7 @@ def forward(
         mask: Bool['b n'] | None = None,
         windowed_mask: Bool['b nw w (w*2)'] | None = None
     ):
-        w, serial = self.attn_window_size, self.serial
+        w = self.attn_window_size
         has_windows = exists(w)
 
         # handle windowing
@@ -2151,12 +2105,10 @@ def forward(
 
         # main transformer
 
-        if serial and should_checkpoint(self, (noised_repr, single_repr, pairwise_repr)):
+        if should_checkpoint(self, (noised_repr, single_repr, pairwise_repr)):
             to_layers_fn = self.to_checkpointed_serial_layers
-        elif serial:
-            to_layers_fn = self.to_serial_layers
         else:
-            to_layers_fn = self.to_parallel_layers
+            to_layers_fn = self.to_serial_layers
 
         noised_repr = to_layers_fn(
             noised_repr,
@@ -2230,7 +2182,6 @@ def __init__(
         token_transformer_heads = 16,
         atom_decoder_depth = 3,
         atom_decoder_heads = 4,
-        serial = True,
         atom_encoder_kwargs: dict = dict(),
         atom_decoder_kwargs: dict = dict(),
         token_transformer_kwargs: dict = dict(),
@@ -2298,7 +2249,6 @@ def __init__(
             attn_window_size = atoms_per_window,
             depth = atom_encoder_depth,
             heads = atom_encoder_heads,
-            serial = serial,
             use_linear_attn = use_linear_attn,
             linear_attn_kwargs = linear_attn_kwargs,
             checkpoint = checkpoint,
@@ -2323,7 +2273,6 @@ def __init__(
             dim_pairwise = dim_pairwise,
             depth = token_transformer_depth,
             heads = token_transformer_heads,
-            serial = serial,
             checkpoint = checkpoint,
             **token_transformer_kwargs
         )
@@ -2341,7 +2290,6 @@ def __init__(
             attn_window_size = atoms_per_window,
             depth = atom_decoder_depth,
             heads = atom_decoder_heads,
-            serial = serial,
             use_linear_attn = use_linear_attn,
             linear_attn_kwargs = linear_attn_kwargs,
             checkpoint = checkpoint,