@@ -1241,7 +1241,7 @@ def inner(inputs, *args, **kwargs):
12411241 wrapped_layers .append (pair_bias_attn_wrapper (pair_bias_attn ))
12421242 wrapped_layers .append (single_transition_wrapper (single_transition ))
12431243
1244- single_repr , pairwise_repr , _ = checkpoint_sequential (wrapped_layers , self .checkpoint_segments , inputs )
1244+ single_repr , pairwise_repr , _ = checkpoint_sequential (wrapped_layers , self .checkpoint_segments , inputs , use_reentrant = False )
12451245
12461246 return single_repr , pairwise_repr
12471247
@@ -1615,6 +1615,8 @@ def __init__(
16151615 serial = False ,
16161616 add_residual = True ,
16171617 use_linear_attn = False ,
1618+ checkpoint = False ,
1619+ checkpoint_segments = 1 ,
16181620 linear_attn_kwargs = dict (
16191621 heads = 8 ,
16201622 dim_head = 16
@@ -1689,6 +1691,9 @@ def __init__(
16891691 conditionable_transition
16901692 ]))
16911693
1694+ self .checkpoint = checkpoint
1695+ self .checkpoint_segments = checkpoint_segments
1696+
16921697 self .layers = layers
16931698
16941699 self .serial = serial
@@ -1703,7 +1708,7 @@ def __init__(
17031708 self .registers = nn .Parameter (torch .zeros (num_register_tokens , dim ))
17041709
17051710 @typecheck
1706- def forward (
1711+ def to_checkpointed_serial_layers (
17071712 self ,
17081713 noised_repr : Float ['b n d' ],
17091714 * ,
@@ -1712,32 +1717,92 @@ def forward(
17121717 mask : Bool ['b n' ] | None = None ,
17131718 windowed_mask : Bool ['b nw w (w*2)' ] | None = None
17141719 ):
1715- w = self .attn_window_size
1716- has_windows = exists (w )
17171720
1718- serial = self . serial
1721+ inputs = ( noised_repr , single_repr , pairwise_repr , mask , windowed_mask )
17191722
1720- # handle windowing
1723+ wrapped_layers = []
17211724
1722- pairwise_is_windowed = pairwise_repr .ndim == 5
1725+ def efficient_attn_wrapper (fn ):
1726+ def inner (inputs ):
1727+ noised_repr , single_repr , pairwise_repr , mask , windowed_mask = inputs
1728+ noised_repr = fn (noised_repr , mask = mask ) + noised_repr
1729+ return noised_repr , single_repr , pairwise_repr , mask , windowed_mask
1730+ return inner
17231731
1724- if has_windows and not pairwise_is_windowed :
1725- pairwise_repr = full_pairwise_repr_to_windowed (pairwise_repr , window_size = w )
1732+ def attn_wrapper (fn ):
1733+ def inner (inputs ):
1734+ noised_repr , single_repr , pairwise_repr , mask , windowed_mask = inputs
1735+ noised_repr = fn (noised_repr , cond = single_repr , pairwise_repr = pairwise_repr , mask = mask , windowed_mask = windowed_mask ) + noised_repr
1736+ return noised_repr , single_repr , pairwise_repr , mask , windowed_mask
1737+ return inner
17261738
1727- # register tokens
1739+ def transition_wrapper (fn ):
1740+ def inner (inputs ):
1741+ noised_repr , single_repr , pairwise_repr , mask , windowed_mask = inputs
1742+ noised_repr = fn (noised_repr , cond = single_repr ) + noised_repr
1743+ return noised_repr , single_repr , pairwise_repr , mask , windowed_mask
1744+ return inner
17281745
1729- if self .has_registers :
1730- num_registers = self .num_registers
1731- registers = repeat (self .registers , 'r d -> b r d' , b = noised_repr .shape [0 ])
1732- noised_repr , registers_ps = pack ((registers , noised_repr ), 'b * d' )
1746+ for linear_attn , colt5_attn , attn , transition in self .layers :
17331747
1734- single_repr = F . pad ( single_repr , ( 0 , 0 , num_registers , 0 ), value = 0. )
1735- pairwise_repr = F . pad ( pairwise_repr , ( 0 , 0 , num_registers , 0 , num_registers , 0 ), value = 0. )
1748+ if exists ( linear_attn ):
1749+ wrapped_layers . append ( efficient_attn_wrapper ( linear_attn ) )
17361750
1737- if exists (mask ):
1738- mask = F . pad ( mask , ( num_registers , 0 ), value = True )
1751+ if exists (colt5_attn ):
1752+ wrapped_layers . append ( efficient_attn_wrapper ( colt5_attn ) )
17391753
1740- # main transformer
1754+ wrapped_layers .append (attn_wrapper (attn ))
1755+ wrapped_layers .append (transition_wrapper (transition ))
1756+
1757+ out = checkpoint_sequential (wrapped_layers , self .checkpoint_segments , inputs , use_reentrant = False )
1758+
1759+ noised_repr , * _ = out
1760+ return noised_repr
1761+
@typecheck
def to_serial_layers(
    self,
    noised_repr: Float['b n d'],
    *,
    single_repr: Float['b n ds'],
    pairwise_repr: Float['b n n dp'] | Float['b nw w (w*2) dp'],
    mask: Bool['b n'] | None = None,
    windowed_mask: Bool['b nw w (w*2)'] | None = None
):
    """
    Run every entry of `self.layers` one after another, without checkpointing.

    Each entry is a `(linear_attn, colt5_attn, attn, transition)` group. Every
    sub-layer that is applied has its output added back onto the running
    representation before the next sub-layer is called.

    Args:
        noised_repr: representation that is updated by each sub-layer.
        single_repr: passed as `cond` to both `attn` and `transition`.
        pairwise_repr: passed to `attn` only.
        mask: forwarded to the efficient attention branches and to `attn`.
        windowed_mask: forwarded to `attn` only.

    Returns:
        The representation after the last group has been applied.
    """

    hidden = noised_repr

    for linear_attn, colt5_attn, attn, transition in self.layers:

        # the two efficient attention branches are optional and share one call shape
        # linear attention goes first, then colt5 attention

        for efficient_attn in (linear_attn, colt5_attn):
            if not exists(efficient_attn):
                continue

            hidden = efficient_attn(hidden, mask = mask) + hidden

        # main attention, conditioned on the single repr and biased by the pairwise repr

        attn_out = attn(
            hidden,
            cond = single_repr,
            pairwise_repr = pairwise_repr,
            mask = mask,
            windowed_mask = windowed_mask
        )

        hidden = attn_out + hidden

        # conditioned transition, again with a residual

        ff_out = transition(hidden, cond = single_repr)
        hidden = ff_out + hidden

    return hidden
1795+
1796+ @typecheck
1797+ def to_parallel_layers (
1798+ self ,
1799+ noised_repr : Float ['b n d' ],
1800+ * ,
1801+ single_repr : Float ['b n ds' ],
1802+ pairwise_repr : Float ['b n n dp' ] | Float ['b nw w (w*2) dp' ],
1803+ mask : Bool ['b n' ] | None = None ,
1804+ windowed_mask : Bool ['b nw w (w*2)' ] | None = None
1805+ ):
17411806
17421807 for linear_attn , colt5_attn , attn , transition in self .layers :
17431808
@@ -1755,25 +1820,72 @@ def forward(
17551820 windowed_mask = windowed_mask
17561821 )
17571822
1758- if serial :
1759- noised_repr = attn_out + noised_repr
1760-
17611823 ff_out = transition (
17621824 noised_repr ,
17631825 cond = single_repr
17641826 )
17651827
1766- if serial :
1767- noised_repr = ff_out + noised_repr
1768-
17691828 # in the algorithm, they omitted the residual, but it could be an error
17701829 # attn + ff + residual was used in GPT-J and PaLM, but later found to be unstable configuration, so it seems unlikely attn + ff would work
17711830 # but in the case they figured out something we have not, you can use their exact formulation by setting `serial = False` and `add_residual = False`
17721831
17731832 residual = noised_repr if self .add_residual else 0.
17741833
1775- if not serial :
1776- noised_repr = ff_out + attn_out + residual
1834+ noised_repr = ff_out + attn_out + residual
1835+
1836+ return noised_repr
1837+
1838+ @typecheck
1839+ def forward (
1840+ self ,
1841+ noised_repr : Float ['b n d' ],
1842+ * ,
1843+ single_repr : Float ['b n ds' ],
1844+ pairwise_repr : Float ['b n n dp' ] | Float ['b nw w (w*2) dp' ],
1845+ mask : Bool ['b n' ] | None = None ,
1846+ windowed_mask : Bool ['b nw w (w*2)' ] | None = None
1847+ ):
1848+ w = self .attn_window_size
1849+ has_windows = exists (w )
1850+
1851+ serial = self .serial
1852+
1853+ # handle windowing
1854+
1855+ pairwise_is_windowed = pairwise_repr .ndim == 5
1856+
1857+ if has_windows and not pairwise_is_windowed :
1858+ pairwise_repr = full_pairwise_repr_to_windowed (pairwise_repr , window_size = w )
1859+
1860+ # register tokens
1861+
1862+ if self .has_registers :
1863+ num_registers = self .num_registers
1864+ registers = repeat (self .registers , 'r d -> b r d' , b = noised_repr .shape [0 ])
1865+ noised_repr , registers_ps = pack ((registers , noised_repr ), 'b * d' )
1866+
1867+ single_repr = F .pad (single_repr , (0 , 0 , num_registers , 0 ), value = 0. )
1868+ pairwise_repr = F .pad (pairwise_repr , (0 , 0 , num_registers , 0 , num_registers , 0 ), value = 0. )
1869+
1870+ if exists (mask ):
1871+ mask = F .pad (mask , (num_registers , 0 ), value = True )
1872+
1873+ # main transformer
1874+
1875+ if self .serial and should_checkpoint (self , (noised_repr , single_repr , pairwise_repr )):
1876+ to_layers_fn = self .to_checkpointed_serial_layers
1877+ elif self .serial :
1878+ to_layers_fn = self .to_serial_layers
1879+ else :
1880+ to_layers_fn = self .to_parallel_layers
1881+
1882+ noised_repr = to_layers_fn (
1883+ noised_repr ,
1884+ single_repr = single_repr ,
1885+ pairwise_repr = pairwise_repr ,
1886+ mask = mask ,
1887+ windowed_mask = windowed_mask ,
1888+ )
17771889
17781890 # splice out registers
17791891
0 commit comments