File tree Expand file tree Collapse file tree 1 file changed +8
-3
lines changed
bitsandbytes/backends/triton Expand file tree Collapse file tree 1 file changed +8
-3
lines changed Original file line number Diff line number Diff line change @@ -114,7 +114,6 @@ def optimizer_update_8bit_blockwise_pytorch(
114114 gnorm_scale : float ,
115115 skip_zeros : bool ,
116116 # ADEMIX
117- n : int ,
118117 * ,
119118 optimizer_name : str ,
120119) -> None :
@@ -262,7 +261,6 @@ def optimizer_update_8bit_blockwise_triton_quant(
262261 gnorm_scale : float ,
263262 skip_zeros : bool ,
264263 # ADEMIX
265- n : int ,
266264 * ,
267265 optimizer_name : str ,
268266) -> None :
@@ -627,7 +625,7 @@ def _optimizer_update_2state_8bit_blockwise_triton_kernel(
627625}
628626
629627
630- def optimizer_update_8bit_blockwise_impl (
628+ def optimizer_update_8bit_blockwise_triton_impl (
631629 optimizer_name : str ,
632630 g : torch .Tensor ,
633631 p : torch .Tensor ,
@@ -699,3 +697,10 @@ def optimizer_update_8bit_blockwise_impl(
699697 OPTIMIZER_ID = optimizer_id ,
700698 num_warps = 2 ,
701699 )
700+
701+
702+ # optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_pytorch
703+ # optimizer_update_8bit_blockwise_impl = torch.compile(optimizer_update_8bit_blockwise_pytorch_impl)
704+ # optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_triton_quant
705+ # optimizer_update_8bit_blockwise_impl = torch.compile(optimizer_update_8bit_blockwise_triton_quant)
706+ optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_triton_impl
You can’t perform that action at this time.
0 commit comments