@@ -1555,9 +1555,9 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] =
15551555
15561556def optimizer_update_32bit (
15571557 optimizer_name : str ,
1558- g : Tensor ,
1559- p : Tensor ,
1560- state1 : Tensor ,
1558+ g : torch . Tensor ,
1559+ p : torch . Tensor ,
1560+ state1 : torch . Tensor ,
15611561 beta1 : float ,
15621562 eps : float ,
15631563 step : int ,
@@ -1571,6 +1571,7 @@ def optimizer_update_32bit(
15711571 unorm_vec : Optional [torch .Tensor ] = None ,
15721572 max_unorm : float = 0.0 ,
15731573 skip_zeros = False ,
1574+ return_updates : Optional [torch .Tensor ] = None ,
15741575) -> None :
15751576 """
15761577 Performs an inplace optimizer update with one or two optimizer states.
@@ -1613,6 +1614,8 @@ def optimizer_update_32bit(
16131614 The maximum update norm relative to the weight norm.
16141615 skip_zeros : bool
16151616 Whether to skip zero-valued gradients or not (default: False).
1617+ return_updates: Optional[torch.Tensor]
1618+ When provided, updates are written to this tensor and not applied directly to `p`. (default: None)
16161619 """
16171620
16181621 param_norm = 0.0
@@ -1636,6 +1639,7 @@ def optimizer_update_32bit(
16361639 optim_func (
16371640 get_ptr (g ),
16381641 get_ptr (p ),
1642+ get_ptr (return_updates ),
16391643 get_ptr (state1 ),
16401644 get_ptr (state2 ),
16411645 get_ptr (unorm_vec ),
@@ -1658,25 +1662,26 @@ def optimizer_update_32bit(
16581662
16591663def optimizer_update_8bit (
16601664 optimizer_name : str ,
1661- g : Tensor ,
1662- p : Tensor ,
1663- state1 : Tensor ,
1665+ g : torch . Tensor ,
1666+ p : torch . Tensor ,
1667+ state1 : torch . Tensor ,
16641668 state2 : Optional [torch .Tensor ],
16651669 beta1 : float ,
16661670 beta2 : float ,
16671671 eps : float ,
16681672 step : int ,
16691673 lr : float ,
1670- qmap1 : Tensor ,
1674+ qmap1 : torch . Tensor ,
16711675 qmap2 : Optional [torch .Tensor ],
1672- max1 : Tensor ,
1676+ max1 : torch . Tensor ,
16731677 max2 : Optional [torch .Tensor ],
1674- new_max1 : Tensor ,
1678+ new_max1 : torch . Tensor ,
16751679 new_max2 : Optional [torch .Tensor ],
16761680 weight_decay : float = 0.0 ,
16771681 gnorm_scale : float = 1.0 ,
16781682 unorm_vec : Optional [torch .Tensor ] = None ,
16791683 max_unorm : float = 0.0 ,
1684+ return_updates : Optional [torch .Tensor ] = None ,
16801685) -> None :
16811686 """
16821687 Performs an inplace Adam update.
@@ -1726,6 +1731,8 @@ def optimizer_update_8bit(
17261731 The tensor for the update norm.
17271732 max_unorm : float
17281733 The maximum update norm relative to the weight norm.
1734+ return_updates: Optional[torch.Tensor]
1735+ When provided, updates are written to this tensor and not applied directly to `p`. (default: None)
17291736 """
17301737
17311738 param_norm = 0.0
@@ -1738,6 +1745,7 @@ def optimizer_update_8bit(
17381745 str2optimizer8bit [optimizer_name ][0 ](
17391746 get_ptr (p ),
17401747 get_ptr (g ),
1748+ get_ptr (return_updates ),
17411749 get_ptr (state1 ),
17421750 get_ptr (state2 ),
17431751 get_ptr (unorm_vec ),
@@ -1762,6 +1770,7 @@ def optimizer_update_8bit(
17621770 str2optimizer8bit [optimizer_name ][1 ](
17631771 get_ptr (p ),
17641772 get_ptr (g ),
1773+ get_ptr (return_updates ),
17651774 get_ptr (state1 ),
17661775 get_ptr (state2 ),
17671776 get_ptr (unorm_vec ),
@@ -1809,6 +1818,7 @@ def optimizer_update_8bit_blockwise(
18091818 weight_decay : float = 0.0 ,
18101819 gnorm_scale : float = 1.0 ,
18111820 skip_zeros = False ,
1821+ return_updates : Optional [torch .Tensor ] = None ,
18121822) -> None :
18131823 optim_func = None
18141824 prev_device = pre_call (g .device )
@@ -1835,6 +1845,7 @@ def optimizer_update_8bit_blockwise(
18351845 optim_func (
18361846 get_ptr (p ),
18371847 get_ptr (g ),
1848+ get_ptr (return_updates ),
18381849 get_ptr (state1 ),
18391850 get_ptr (state2 ),
18401851 ct .c_float (beta1 ),
0 commit comments