
Commit ec38ba9

Added paging.
1 parent 264a948 commit ec38ba9

8 files changed: +167 -90 lines changed


bitsandbytes/cextension.py

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,8 @@
     lib.cadam_8bit_blockwise_fp32
     lib.get_context.restype = ct.c_void_p
     lib.get_cusparse.restype = ct.c_void_p
+    lib.cget_managed_ptr.restype = ct.c_void_p
+    lib.cget_stream.restype = ct.c_void_p
     COMPILED_WITH_CUDA = True
 except AttributeError:
     warn("The installed version of bitsandbytes was compiled without GPU support. "

bitsandbytes/functional.py

Lines changed: 55 additions & 0 deletions
@@ -130,6 +130,61 @@ def get_instance(cls):
             cls._instance.initialize()
         return cls._instance
 
+dtype2bytes = {}
+dtype2bytes[torch.float32] = 4
+dtype2bytes[torch.float16] = 2
+dtype2bytes[torch.bfloat16] = 2
+dtype2bytes[torch.uint8] = 1
+dtype2bytes[torch.int8] = 1
+
+def get_paged(*shape, dtype=torch.float32, device=torch.device('cuda', index=0)):
+    num_bytes = dtype2bytes[dtype]*prod(shape)
+    cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes))
+    c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int))
+    new_array = np.ctypeslib.as_array(c_ptr, shape=shape)
+    out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape))
+    out.is_paged = True
+    out.page_deviceid = device.index
+    return out
+
+def prefetch_tensor(A, to_cpu=False):
+    assert A.is_paged, 'Only paged tensors can be prefetched!'
+    if to_cpu:
+        deviceid = -1
+    else:
+        deviceid = A.page_deviceid
+
+    num_bytes = dtype2bytes[A.dtype]*A.numel()
+    lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid))
+
+def elementwise_func(func_name, A, B, value, prefetch=True):
+    func = None
+    if A.dtype == torch.float32:
+        func = getattr(lib, f'c{func_name}_fp32', None)
+        cvalue = ct.c_float(value)
+    elif A.dtype == torch.uint8:
+        func = getattr(lib, f'c{func_name}_uint8', None)
+        cvalue = ct.c_uint8(value)
+
+    if func is None: raise NotImplementedError(f'Function not implemented: {func_name}')
+
+    is_managed = getattr(A, 'is_managed', False)
+    if is_managed and prefetch:
+        prefetch_tensor(A)
+        if B is not None: prefetch_tensor(B)
+
+    func(get_ptr(A), get_ptr(B), cvalue, ct.c_int64(A.numel()))
+    if A.is_paged or B.is_paged:
+        # paged functions are fully asynchronous
+        # if we return from this function, we want the tensor
+        # to be in its final state after the operation occurred,
+        # so we synchronize
+        torch.cuda.synchronize()
+
+def fill(A, value, device=None, prefetch=True): elementwise_func('fill', A, None, value)
+def arange(A, device=None): elementwise_func('arange', A, None, 0)
+def _mul(A, B, device=None): elementwise_func('_mul', A, B, 0)
+
 
 def create_linear_map(signed=True, total_bits=8, add_zero=True):
     sign = (-1.0 if signed else 0.0)
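
Together these additions give the Python side paged tensors backed by CUDA unified (managed) memory, plus fill/arange/_mul kernels and explicit prefetching. A minimal usage sketch mirroring the new test at the bottom of this commit, assuming a CUDA-enabled build and at least one GPU (not part of the commit itself):

import torch
import bitsandbytes.functional as F

# Allocate a paged tensor; the driver migrates its pages between CPU and GPU on demand.
A = F.get_paged(32, 32, dtype=torch.float32)
assert A.is_paged and A.page_deviceid == 0

F.fill(A, 7.0)                      # launches the asynchronous fill kernel, then synchronizes
F.prefetch_tensor(A)                # optionally migrate the pages to the GPU ahead of use
F.prefetch_tensor(A, to_cpu=True)   # or back to host memory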

csrc/kernels.cu

Lines changed: 19 additions & 57 deletions
@@ -3522,69 +3522,34 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
 //}
 
-__device__ void compute(float* global_out, float const* shared_in)
+template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long n)
 {
-
-}
-template <size_t stages_count /* Pipeline with stages_count stages */>
-__global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz) {
-  auto grid = cooperative_groups::this_grid();
-  auto block = cooperative_groups::this_thread_block();
-  assert(size == batch_sz * grid.size()); // Assume input size fits batch_sz * grid_size
-
-  extern __shared__ float shared[]; // stages_count * block.size() * sizeof(int) bytes
-  size_t shared_offset[stages_count];
-  for (int s = 0; s < stages_count; ++s) shared_offset[s] = s * block.size();
-
-  __shared__ cuda::pipeline_shared_state<
-    cuda::thread_scope::thread_scope_block,
-    stages_count
-  > shared_state;
-  auto pipeline = cuda::make_pipeline(block, &shared_state);
-
-  auto block_batch = [&](size_t batch) -> int {
-    return block.group_index().x * block.size() + grid.size() * batch;
-  };
-
-  // compute_batch: next batch to process
-  // fetch_batch: next batch to fetch from global memory
-  for (size_t compute_batch = 0, fetch_batch = 0; compute_batch < batch_sz; ++compute_batch) {
-    // The outer loop iterates over the computation of the batches
-    for (; fetch_batch < batch_sz && fetch_batch < (compute_batch + stages_count); ++fetch_batch) {
-      // This inner loop iterates over the memory transfers, making sure that the pipeline is always full
-      pipeline.producer_acquire();
-      size_t shared_idx = fetch_batch % stages_count;
-      size_t batch_idx = fetch_batch;
-      size_t block_batch_idx = block_batch(batch_idx);
-      cuda::memcpy_async(block, shared + shared_offset[shared_idx], global_in + block_batch_idx, sizeof(float) * block.size(), pipeline);
-      pipeline.producer_commit();
-    }
-    pipeline.consumer_wait();
-    int shared_idx = compute_batch % stages_count;
-    int batch_idx = compute_batch;
-    compute(global_out + block_batch(batch_idx), shared + shared_offset[shared_idx]);
-    pipeline.consumer_release();
+  for(long i = (blockDim.x*blockIdx.x) + threadIdx.x; i < n; i+=(blockDim.x*gridDim.x))
+  {
+    switch(FUNC)
+    {
+      case FILL:
+        A[i] = (T)value;
+        break;
+      case ARANGE:
+        A[i] = (T)i;
+        break;
+      case _MUL:
+        A[i] = A[i]*B[i];
+        break;
     }
+  }
 }
 
 
 //==============================================================
 // TEMPLATE DEFINITIONS
 //==============================================================
 
-//template <class MShape, class NShape, class KShape,
-//          class TA, class AStride, class ABlockLayout, class AThreadLayout,
-//          class TB, class BStride, class BBlockLayout, class BThreadLayout,
-//          class TC, class CStride, class CBlockLayout, class CThreadLayout,
-//          class Alpha, class Beta>
-//__global__ static
-//__launch_bounds__(decltype(size(CThreadLayout{}))::value)
-//void
-//gemm_device(MShape M, NShape N, KShape K,
-//            TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
-//            TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
-//            TC      * out, CStride dC, CBlockLayout , CThreadLayout tC,
-//            half alpha, half beta);
+template __global__ void kfunc<float, FILL>(float *A, float *B, float value, long n);
+template __global__ void kfunc<unsigned char, FILL>(unsigned char *A, unsigned char *B, unsigned char value, long n);
+template __global__ void kfunc<float, ARANGE>(float *A, float *B, float value, long n);
+template __global__ void kfunc<float, _MUL>(float *A, float *B, float value, long n);
 
 // these are not used and make no sense, but the compiler needs them
 //template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
@@ -3611,9 +3576,6 @@ template __global__ void gemm_device<half, 16, 96>(int M, int N, int K, half * _
 template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 
-
-//template __global__ void kMatmul_inference_4bit<NF4, half, half, half>(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);
-template __global__ void with_staging_unified<2>(float const* global_in, float * global_out, size_t size, size_t batch_sz);
 template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
 template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);

csrc/kernels.cuh

Lines changed: 2 additions & 16 deletions
@@ -122,23 +122,9 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
 
 template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
 
-//template <class MShape, class NShape, class KShape,
-//          class TA, class AStride, class ABlockLayout, class AThreadLayout,
-//          class TB, class BStride, class BBlockLayout, class BThreadLayout,
-//          class TC, class CStride, class CBlockLayout, class CThreadLayout,
-//          class Alpha, class Beta>
-//__global__ static
-//__launch_bounds__(decltype(size(CThreadLayout{}))::value)
-//void
-//gemm_device(MShape M, NShape N, KShape K,
-//            TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
-//            TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
-//            TC      * out, CStride dC, CBlockLayout , CThreadLayout tC,
-//            Alpha alpha, Beta beta);
-template <size_t stages_count /* Pipeline with stages_count stages */>
-__global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz);
 
 template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
 template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
 
+template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long n);
+
 #endif

csrc/ops.cu

Lines changed: 15 additions & 10 deletions
@@ -663,16 +663,6 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
 }
 
 
-void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
-{
-
-  int threads = 256;
-  int num_blocks = (n+(256*batch_size)+1)/(batch_size*256);
-
-  with_staging_unified<2><<<num_blocks, threads>>>(A, B, n, batch_size);
-  CUDA_CHECK_RETURN(cudaPeekAtLastError());
-}
-
 
 
 template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
@@ -717,10 +707,25 @@ template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsi
   //kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
 }
 
+template <typename T, int FUNC> void func(T *A, T *B, T value, long n)
+{
+  int threads = 512;
+  int blocks = n/threads;
+  blocks = n % threads == 0 ? blocks : blocks + 1;
+  blocks = blocks > 65535 ? 65535 : blocks;
+  kfunc<T, FUNC><<<blocks, 512>>>(A, B, value, n);
+  CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
+
 //==============================================================
 // TEMPLATE DEFINITIONS
 //==============================================================
 
+template void func<float, FILL>(float *A, float *B, float value, long n);
+template void func<unsigned char, FILL>(unsigned char *A, unsigned char *B, unsigned char value, long n);
+template void func<float, ARANGE>(float *A, float *B, float value, long n);
+template void func<float, _MUL>(float *A, float *B, float value, long n);
+
 template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 //template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
 template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
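
The new launcher rounds the block count up with a ceiling division and clamps it to 65535; because kfunc iterates with a grid-stride loop, any elements beyond blocks*threads are still covered. A small Python sketch of the same launch arithmetic, for illustration only (not part of the commit):

# Mirrors the block/thread computation in func<T, FUNC> above.
def launch_config(n, threads=512, max_blocks=65535):
    blocks = (n + threads - 1) // threads      # ceiling division, as in ops.cu
    return min(blocks, max_blocks), threads    # the kernel's grid-stride loop handles any remainder

print(launch_config(320 * 320))   # -> (200, 512)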

csrc/ops.cuh

Lines changed: 8 additions & 1 deletion
@@ -93,6 +93,13 @@ typedef enum DataType_t
   NF4 = 2,
 } DataType_t;
 
+typedef enum Funcs_t
+{
+  FILL = 0,
+  ARANGE = 1,
+  _MUL = 2,
+} Funcs_t;
+
 class Context
 {
     public:
@@ -193,6 +200,6 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows
 template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
 template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
 
+template <typename T, int FUNC> void func(T *A, T *B, T value, long n);
 
-void pipeline_test(float *A, float *B, size_t n, size_t batch_size);
 #endif

csrc/pythonInterface.c

Lines changed: 31 additions & 1 deletion
@@ -28,6 +28,14 @@ void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int l
 void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
 { gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
 
+#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
+void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func<ctype, FUNC>(A, B, value, n); } \
+
+MAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL)
+MAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL)
+MAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
+MAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
+
 
 #define MAKE_FUNC32(fname, oname, gtype, gbits) \
 void fname##32bit_g##gbits(gtype *g, gtype *p, \
@@ -314,7 +322,6 @@ extern "C"
 
     void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); }
     void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }
-    void cpipeline_test(float *A, float *B, size_t n, size_t batch_size){ pipeline_test(A, B, n, batch_size); }
 
     //void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
     //{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }
@@ -325,6 +332,29 @@ extern "C"
     void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
     { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
 
+    void *cget_managed_ptr(size_t bytes)
+    {
+      void *ptr;
+      CUDA_CHECK_RETURN(cudaMallocManaged(&ptr, bytes, cudaMemAttachHost));
+      CUDA_CHECK_RETURN(cudaPeekAtLastError());
+
+      return ptr;
+    }
+
+    void cprefetch(void *ptr, size_t bytes, int device)
+    {
+      CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0));
+      CUDA_CHECK_RETURN(cudaPeekAtLastError());
+    }
+
+    #define CMAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
+    void c##fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ fname##_##type_name(A, B, value, n); } \
+
+    CMAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL)
+    CMAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL)
+    CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
+    CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
+
 #endif
     void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
     void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
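
The two macro layers expand to plain exported symbols (fill_fp32/cfill_fp32, fill_uint8/cfill_uint8, arange_fp32/carange_fp32, _mul_fp32/c_mul_fp32), which is exactly what elementwise_func in functional.py resolves via getattr(lib, f'c{func_name}_fp32'). A rough sketch of calling one of them directly through ctypes, assuming the lib handle, get_ptr, and get_paged are importable from bitsandbytes.functional as in this commit (illustrative only):

import ctypes as ct
import torch
from bitsandbytes.functional import lib, get_ptr, get_paged

A = get_paged(1024, dtype=torch.float32)   # paged buffer from this commit's API
# cfill_fp32 is the symbol generated by CMAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL)
lib.cfill_fp32(get_ptr(A), get_ptr(None), ct.c_float(3.0), ct.c_int64(A.numel()))
torch.cuda.synchronize()                   # the kernel launch is asynchronous
assert (A == 3.0).all()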

tests/test_functional.py

Lines changed: 35 additions & 5 deletions
@@ -2489,8 +2489,38 @@ def test_gemm_4bit(dtype):
     print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
     print(dim, (max_err.item(), max_relerr.item()))
 
-def test_pipeline_func():
-    a = torch.rand(2, 4).cuda()
-    out = F.pipeline_test(a, 2)
-    print(a)
-    print(out)
+def test_managed():
+    n = 32*10
+    A = F.get_paged(n, n, dtype=torch.float32)
+    B = F.get_paged(n, n, dtype=torch.uint8)
+    B2 = F.get_paged(n, n, dtype=torch.float32)
+    assert A.is_paged
+    assert B.is_paged
+    assert A.page_deviceid==0
+    assert B.page_deviceid==0
+    F.fill(A, 17.0)
+    F.fill(B, 17)
+    F.fill(B2, 2)
+    assert (A==17).sum().item() == n*n
+    assert (B==17).sum().item() == n*n
+    C = A*B.float()
+    assert (C==289).sum().item() == n*n
+    F._mul(A, B2)
+    F._mul(A, B2)
+    F._mul(A, B2)
+    assert (A==17*(2**3)).sum().item() == n*n
+    # F.prefetch_tensor(A)
+    # F.prefetch_tensor(B)
+
+    # F.fill(B2, 17.0)
+    # F._mul(A, B2)
+
+    # F.prefetch_tensor(A, to_cpu=True)
+    # F.prefetch_tensor(B, to_cpu=True)
+    # F.prefetch_tensor(B2, to_cpu=True)
+    # torch.cuda.synchronize()
+
+    # assert (A==17).sum().item() == n*n
+
+    # torch.testing.assert_allclose(A, torch.ones(A.shape)*289)