Skip to content

Commit 264a948

Browse files
committed
4-bit draft; 128 vector load 240.
1 parent 869b7e8 commit 264a948

File tree

4 files changed

+278
-136
lines changed

4 files changed

+278
-136
lines changed

bitsandbytes/functional.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1385,10 +1385,12 @@ def cutlass3_gemm(
1385 1385
#sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
1386 1386
if state is None:
1387 1387
Bshape = B.shape
1388+
bout = Bshape[1]
1388 1389
else:
1389 1390
Bshape = state[1]
1391+
bout = Bshape[0]
1390 1392
if out is None:
1391-
out = torch.zeros(size=(A.shape[0], Bshape[1]), dtype=A.dtype, device=A.device)
1393+
out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)
1392 1394

1393 1395
sA = A.shape
1394 1396
sB = B.shape
@@ -1464,7 +1466,7 @@ def cutlass3_gemm(
1464 1466
if state is not None:
1465 1467
m = Bshape[0]
1466 1468
k = Bshape[1]
1467-
lda = Bshape[1]
1469+
lda = Bshape[0]
1468 1470
ldc = Bshape[0]
1469 1471
ldb = (ldb+1)//2
1470 1472
#print(m, n, k, lda, ldb, ldc)

0 commit comments

Comments (0)