@@ -1821,21 +1821,12 @@ def gemv_4bit(
     transposed_B=False,
     state=None,
 ):
-    # sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
     if state is None:
         raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()")

-    if A.numel() != A.shape[-1]:
-        raise ValueError(
-            'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]',
-        )
-
-    # Bshape = state.shape
-    # bout = Bshape[0]
     absmax = state.absmax
     if state.nested:
-        absmax = dequantize_blockwise(state.absmax, state.state2)
-        absmax += state.offset
+        absmax = dequantize_blockwise(absmax, state.state2) + state.offset

     return torch.ops.bitsandbytes.gemv_4bit(
         A,
@@ -1846,85 +1837,6 @@ def gemv_4bit(
         state.blocksize,
     )

-    # if out is None:
-    #     if len(A.shape) == 3:
-    #         out = torch.empty(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device)
-    #     else:
-    #         out = torch.empty(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)
-
-    # n = 1
-    # m = Bshape[0]
-    # k = Bshape[1]
-    # lda = Bshape[0]
-    # ldc = Bshape[0]
-    # ldb = (A.shape[-1] + 1) // 2
-    # is_on_gpu([B, A, out, absmax, state.code])
-    # m = ct.c_int32(m)
-    # n = ct.c_int32(n)
-    # k = ct.c_int32(k)
-    # lda = ct.c_int32(lda)
-    # ldb = ct.c_int32(ldb)
-    # ldc = ct.c_int32(ldc)
-    # stream = _get_tensor_stream(A)
-
-    # with _cuda_device_of(A):
-    #     if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]:
-    #         if A.dtype == torch.float16:
-    #             lib.cgemm_4bit_inference_naive_fp16(
-    #                 m,
-    #                 n,
-    #                 k,
-    #                 get_ptr(A),
-    #                 get_ptr(B),
-    #                 get_ptr(absmax),
-    #                 get_ptr(state.code),
-    #                 get_ptr(out),
-    #                 lda,
-    #                 ldb,
-    #                 ldc,
-    #                 ct.c_int32(state.blocksize),
-    #                 stream,
-    #             )
-    #         elif A.dtype == torch.bfloat16:
-    #             lib.cgemm_4bit_inference_naive_bf16(
-    #                 m,
-    #                 n,
-    #                 k,
-    #                 get_ptr(A),
-    #                 get_ptr(B),
-    #                 get_ptr(absmax),
-    #                 get_ptr(state.code),
-    #                 get_ptr(out),
-    #                 lda,
-    #                 ldb,
-    #                 ldc,
-    #                 ct.c_int32(state.blocksize),
-    #                 stream,
-    #             )
-    #         elif A.dtype == torch.float32:
-    #             lib.cgemm_4bit_inference_naive_fp32(
-    #                 m,
-    #                 n,
-    #                 k,
-    #                 get_ptr(A),
-    #                 get_ptr(B),
-    #                 get_ptr(absmax),
-    #                 get_ptr(state.code),
-    #                 get_ptr(out),
-    #                 lda,
-    #                 ldb,
-    #                 ldc,
-    #                 ct.c_int32(state.blocksize),
-    #                 stream,
-    #             )
-    #         else:
-    #             raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}")
-
-    #     else:
-    #         raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}")
-
-    # return out
-


 def igemm(
     A: Tensor,
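
For context, here is a minimal usage sketch of the path this function implements, not taken from the diff: the weight shapes, the quant_type choice, and the B.t() call convention are assumptions based on how single-token inputs are typically dispatched to gemv_4bit rather than a documented calling contract.

    import torch
    import bitsandbytes.functional as F

    # Quantize a weight matrix to 4-bit. quantize_4bit returns the packed data
    # plus the quantization state (absmax, code, blocksize, ...) that gemv_4bit
    # reads back in the snippet above.
    W = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
    W_4bit, quant_state = F.quantize_4bit(W, quant_type="nf4")

    # Single-token activation: gemv_4bit covers the batch-size-1 inference case.
    # The transpose of the packed tensor is an assumption mirroring the usual
    # matmul-style dispatch; adapt to your own call site.
    x = torch.randn(1, 1, 4096, dtype=torch.float16, device="cuda")
    y = F.gemv_4bit(x, W_4bit.t(), state=quant_state)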