
Commit 23eba7a ("Cleanup")
Parent: db07f4e

This cleanup removes a large block of commented-out legacy ctypes dispatch code from gemv_4bit and folds the nested absmax dequantization into a single expression.


bitsandbytes/functional.py

Lines changed: 1 addition & 89 deletions
@@ -1821,21 +1821,12 @@ def gemv_4bit(
     transposed_B=False,
     state=None,
 ):
-    # sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
     if state is None:
         raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()")
 
-    if A.numel() != A.shape[-1]:
-        raise ValueError(
-            'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]',
-        )
-
-    # Bshape = state.shape
-    # bout = Bshape[0]
     absmax = state.absmax
     if state.nested:
-        absmax = dequantize_blockwise(state.absmax, state.state2)
-        absmax += state.offset
+        absmax = dequantize_blockwise(absmax, state.state2) + state.offset
 
     return torch.ops.bitsandbytes.gemv_4bit(
         A,
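
The only behavioral change in this hunk is the nested-statistics branch: with double quantization (compress_statistics=True in quantize_4bit), the per-block absmax values are themselves stored blockwise-quantized, so they must be dequantized via state.state2 and re-centered by state.offset before use. The old two-statement form and the new one-liner are equivalent. A minimal standalone sketch of that recovery step (the helper name recover_absmax is hypothetical, not part of the library):

from bitsandbytes.functional import dequantize_blockwise

def recover_absmax(state):
    # `state` is a QuantState from quantize_4bit(). When state.nested is set,
    # state.absmax holds quantized statistics; state.state2 describes how to
    # dequantize them and state.offset re-centers the result.
    absmax = state.absmax
    if state.nested:
        absmax = dequantize_blockwise(absmax, state.state2) + state.offset
    return absmax
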
@@ -1846,85 +1837,6 @@ def gemv_4bit(
         state.blocksize,
     )
 
-    # if out is None:
-    #     if len(A.shape) == 3:
-    #         out = torch.empty(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device)
-    #     else:
-    #         out = torch.empty(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)
-
-    # n = 1
-    # m = Bshape[0]
-    # k = Bshape[1]
-    # lda = Bshape[0]
-    # ldc = Bshape[0]
-    # ldb = (A.shape[-1] + 1) // 2
-    # is_on_gpu([B, A, out, absmax, state.code])
-    # m = ct.c_int32(m)
-    # n = ct.c_int32(n)
-    # k = ct.c_int32(k)
-    # lda = ct.c_int32(lda)
-    # ldb = ct.c_int32(ldb)
-    # ldc = ct.c_int32(ldc)
-    # stream = _get_tensor_stream(A)
-
-    # with _cuda_device_of(A):
-    #     if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]:
-    #         if A.dtype == torch.float16:
-    #             lib.cgemm_4bit_inference_naive_fp16(
-    #                 m,
-    #                 n,
-    #                 k,
-    #                 get_ptr(A),
-    #                 get_ptr(B),
-    #                 get_ptr(absmax),
-    #                 get_ptr(state.code),
-    #                 get_ptr(out),
-    #                 lda,
-    #                 ldb,
-    #                 ldc,
-    #                 ct.c_int32(state.blocksize),
-    #                 stream,
-    #             )
-    #         elif A.dtype == torch.bfloat16:
-    #             lib.cgemm_4bit_inference_naive_bf16(
-    #                 m,
-    #                 n,
-    #                 k,
-    #                 get_ptr(A),
-    #                 get_ptr(B),
-    #                 get_ptr(absmax),
-    #                 get_ptr(state.code),
-    #                 get_ptr(out),
-    #                 lda,
-    #                 ldb,
-    #                 ldc,
-    #                 ct.c_int32(state.blocksize),
-    #                 stream,
-    #             )
-    #         elif A.dtype == torch.float32:
-    #             lib.cgemm_4bit_inference_naive_fp32(
-    #                 m,
-    #                 n,
-    #                 k,
-    #                 get_ptr(A),
-    #                 get_ptr(B),
-    #                 get_ptr(absmax),
-    #                 get_ptr(state.code),
-    #                 get_ptr(out),
-    #                 lda,
-    #                 ldb,
-    #                 ldc,
-    #                 ct.c_int32(state.blocksize),
-    #                 stream,
-    #             )
-    #         else:
-    #             raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}")
-
-    #     else:
-    #         raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}")
-
-    # return out
-
 
 def igemm(
     A: Tensor,
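
For context on the path this function now delegates to: gemv_4bit consumes the QuantState produced by quantize_4bit and forwards everything to the registered torch.ops.bitsandbytes.gemv_4bit op. A hedged usage sketch, loosely mirroring how the library's matmul_4bit path typically drives it for single-token inference; the shapes, dtype, and the nf4 quant type are illustrative assumptions, and a CUDA build of bitsandbytes is assumed:

import torch
import bitsandbytes.functional as F

# Quantize a weight matrix to 4-bit NF4; `state` carries absmax/code/blocksize.
W = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
W_4bit, state = F.quantize_4bit(W, quant_type="nf4")

# Single-token activation: gemv_4bit expects a vector-shaped A, e.g. [1, 1, k],
# per the error message removed from the Python layer in this commit.
x = torch.randn(1, 1, 4096, dtype=torch.float16, device="cuda")
y = F.gemv_4bit(x, W_4bit.t(), state=state)  # -> [1, 1, 4096]
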
