Skip to content

Commit cc68b22

Browse files
committed
Updated kernels
1 parent fd3ae6d commit cc68b22

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

bitsandbytes/backends/triton/kernels_optim.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ def optimizer_update_8bit_blockwise_pytorch(
114114
gnorm_scale: float,
115115
skip_zeros: bool,
116116
# ADEMIX
117-
n: int,
118117
*,
119118
optimizer_name: str,
120119
) -> None:
@@ -262,7 +261,6 @@ def optimizer_update_8bit_blockwise_triton_quant(
262261
gnorm_scale: float,
263262
skip_zeros: bool,
264263
# ADEMIX
265-
n: int,
266264
*,
267265
optimizer_name: str,
268266
) -> None:
@@ -627,7 +625,7 @@ def _optimizer_update_2state_8bit_blockwise_triton_kernel(
627625
}
628626

629627

630-
def optimizer_update_8bit_blockwise_impl(
628+
def optimizer_update_8bit_blockwise_triton_impl(
631629
optimizer_name: str,
632630
g: torch.Tensor,
633631
p: torch.Tensor,
@@ -699,3 +697,10 @@ def optimizer_update_8bit_blockwise_impl(
699697
OPTIMIZER_ID=optimizer_id,
700698
num_warps=2,
701699
)
700+
701+
702+
# optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_pytorch
703+
# optimizer_update_8bit_blockwise_impl = torch.compile(optimizer_update_8bit_blockwise_pytorch_impl)
704+
# optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_triton_quant
705+
# optimizer_update_8bit_blockwise_impl = torch.compile(optimizer_update_8bit_blockwise_triton_quant)
706+
optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_triton_impl

0 commit comments

Comments
 (0)