Skip to content

Commit ed922b8

Browse files
remove unused kernels; improved type annotations
1 parent b1c4adc commit ed922b8

File tree

5 files changed

+4
-319
lines changed

5 files changed

+4
-319
lines changed

csrc/kernels.cu

Lines changed: 2 additions & 264 deletions
Original file line numberDiff line numberDiff line change
@@ -2233,160 +2233,6 @@ __global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshol
22332233
}
22342234
}
22352235

2236-
template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols)
2237-
{
2238-
// 0. reset stats to -FLT_MAX
2239-
// 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD)
2240-
// 2. compute col max (per thread); store in smem due to register pressure
2241-
// 3. compute row max (per block); store in smem to accumulate full global mem transation
2242-
// 4. store data via atomicMax
2243-
2244-
// each block loads TILE_COLs columns and TILE_ROW rows
2245-
// after reading a tile the row counter increase by TILE_ROWS
2246-
// the col counter reset after reading TILE_COL elements
2247-
const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
2248-
// col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached
2249-
const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
2250-
const int base_idx = (base_row*cols) + base_col;
2251-
const int items_per_load = ITEMS_PER_THREAD*THREADS;
2252-
2253-
typedef cub::BlockLoad<T, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadT;
2254-
typedef cub::BlockReduce<float, THREADS> BlockRowReduce;
2255-
typedef cub::BlockReduce<int, THREADS> BlockRowSum;
2256-
typedef cub::BlockExchange<float, THREADS, ITEMS_PER_THREAD> BlockExchange;
2257-
2258-
__shared__ union {
2259-
typename BlockExchange::TempStorage exchange;
2260-
typename BlockRowReduce::TempStorage rowreduce;
2261-
typename BlockRowSum::TempStorage rowsum;
2262-
typename LoadT::TempStorage loadt;
2263-
} temp_storage;
2264-
2265-
__shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS];
2266-
__shared__ int smem_row_nnz_values[TILE_ROWS];
2267-
2268-
half local_data[ITEMS_PER_THREAD];
2269-
float local_data_fp32[ITEMS_PER_THREAD];
2270-
float local_col_absmax_values[ITEMS_PER_THREAD];
2271-
int local_row_nnz_count = 0;
2272-
float row_absmax = -FLT_MAX;
2273-
2274-
// 0. reset stats to -FLT_MAX
2275-
for(int j = 0; j < ITEMS_PER_THREAD; j++)
2276-
{
2277-
//smem_col_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX;
2278-
smem_row_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX;
2279-
// smem_row_nnz_values[threadIdx.x + (j*THREADS)] = 0;
2280-
}
2281-
2282-
#pragma unroll TILE_ROWS
2283-
for (int j = 0; j < TILE_ROWS; j++) {
2284-
smem_row_nnz_values[j] = 0;
2285-
}
2286-
2287-
#pragma unroll ITEMS_PER_THREAD
2288-
for(int j = 0; j < ITEMS_PER_THREAD; j++)
2289-
local_col_absmax_values[j] = -FLT_MAX;
2290-
2291-
__syncthreads();
2292-
2293-
int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col;
2294-
int i = base_idx;
2295-
// we load row after row from the base_position
2296-
// 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD)
2297-
for(int row = 0; row < TILE_ROWS; row++)
2298-
{
2299-
if(base_row+row >= rows){ break; }
2300-
local_row_nnz_count = 0;
2301-
i = base_idx + ((row)*cols);
2302-
// each thread gets data from the same column
2303-
__syncthreads();
2304-
LoadT(temp_storage.loadt).Load(&(A[i]), local_data, valid_items, __float2half(0.0f));
2305-
2306-
#pragma unroll ITEMS_PER_THREAD
2307-
for(int j = 0; j < ITEMS_PER_THREAD; j++)
2308-
local_data[j] = fabsf(local_data[j]);
2309-
2310-
2311-
if(SPARSE_DECOMP)
2312-
#pragma unroll ITEMS_PER_THREAD
2313-
for(int j = 0; j < ITEMS_PER_THREAD; j++)
2314-
{
2315-
if((float)local_data[j] >= nnz_threshold)
2316-
{
2317-
local_row_nnz_count += 1;
2318-
local_data[j] = 0.0f;
2319-
}
2320-
}
2321-
2322-
// 2. compute col max (per thread); store in smem due to register pressure
2323-
#pragma unroll ITEMS_PER_THREAD
2324-
for(int j = 0; j < ITEMS_PER_THREAD; j++)
2325-
// take the col max for this row
2326-
// we use shared memory because register pressure is too high if we do this locally
2327-
//smem_col_absmax_values[threadIdx.x + (j*THREADS)] = fmaxf(smem_col_absmax_values[threadIdx.x + (j*THREADS)], __half2float(local_data[j]));
2328-
local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j]));
2329-
2330-
// 3. compute row max (per block); store in smem to accumulate full global mem transation
2331-
2332-
// this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units)
2333-
#pragma unroll ITEMS_PER_THREAD
2334-
for(int j = 0; j < ITEMS_PER_THREAD; j++)
2335-
local_data_fp32[j] = local_data[j];
2336-
2337-
__syncthreads();
2338-
2339-
row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, cub::Max());
2340-
if(SPARSE_DECOMP)
2341-
{
2342-
__syncthreads();
2343-
local_row_nnz_count = BlockRowSum(temp_storage.rowsum).Sum(local_row_nnz_count);
2344-
}
2345-
// we store the data temporarily in shared memory so we
2346-
// can execute a full atomic block transaction into global memory later
2347-
// we use a striped arrangement [0, 8, 16, 24, ..] for t0 for faster stores
2348-
if(threadIdx.x == 0)
2349-
{
2350-
smem_row_absmax_values[(row % ITEMS_PER_THREAD) + ((row/ITEMS_PER_THREAD)*ITEMS_PER_THREAD)] = row_absmax;
2351-
// each blockIdx.x process 16 rows and 64*4=256 columns -> we sum nnz over 256 columns and have 16 values per block
2352-
smem_row_nnz_values[row] = local_row_nnz_count;
2353-
}
2354-
2355-
__syncthreads();
2356-
2357-
}
2358-
2359-
// 4. store data via atomicMax
2360-
// to store col data efficiently we need to rewrite the smem blocked data [0, 1, 2, 3...] for t0
2361-
// into a striped arrangement: [0, 8, 16, 24, ..] for t0
2362-
__syncthreads();
2363-
BlockExchange(temp_storage.exchange).BlockedToStriped(local_col_absmax_values);
2364-
2365-
#pragma unroll ITEMS_PER_THREAD
2366-
for(int j = 0; j < ITEMS_PER_THREAD; j++)
2367-
if(base_col+threadIdx.x+(j*THREADS) < cols)
2368-
{
2369-
float val = colStats[base_col+(threadIdx.x+(j*THREADS))];
2370-
if(val < local_col_absmax_values[j])
2371-
atomicMax(&colStats[base_col+(threadIdx.x+(j*THREADS))], local_col_absmax_values[j]);
2372-
}
2373-
2374-
for(int j = 0; j < ITEMS_PER_THREAD; j++)
2375-
if(base_row+threadIdx.x+(j*THREADS) < rows)
2376-
{
2377-
float val = rowStats[base_row+(threadIdx.x+(j*THREADS))];
2378-
if(val < smem_row_absmax_values[threadIdx.x+(j*THREADS)])
2379-
atomicMax(&rowStats[base_row+(threadIdx.x+(j*THREADS))], smem_row_absmax_values[threadIdx.x+(j*THREADS)]);
2380-
}
2381-
2382-
if(SPARSE_DECOMP)
2383-
if(threadIdx.x < TILE_ROWS)
2384-
nnz_count_row[blockIdx.x*TILE_ROWS+threadIdx.x+1] = smem_row_nnz_values[threadIdx.x];
2385-
2386-
}
2387-
2388-
template __global__ void kgetColRowStats<half, 64, 4, 16, 64*4, 0>(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
2389-
template __global__ void kgetColRowStats<half, 64, 4, 16, 64*4, 1>(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
23902236
template __global__ void kgetRowStats<half, 1024, 0>(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols);
23912237
template __global__ void kgetRowStats<half, 1024, 1>(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols);
23922238

@@ -2430,16 +2276,15 @@ __global__ void kdequant_mm_int32_fp16(
24302276
row_idx = (block_offset + thread_offset + j) / numCols;
24312277
col_idx = (block_offset + thread_offset + j) % numCols;
24322278

2433-
local_colStats[j] = col_idx >= numCols ? 0.0f : colStats[col_idx];
2434-
local_rowStats[j] = row_idx >= numRows ? 0.0f : rowStats[row_idx];
2279+
local_colStats[j] = col_idx >= numCols ? 0.0f : __ldg(&colStats[col_idx]);
2280+
local_rowStats[j] = row_idx >= numRows ? 0.0f : __ldg(&rowStats[row_idx]);
24352281
local_biasValue[j] = ((bias == nullptr) || col_idx >= numCols) ? 0.0f : __half2float(bias[col_idx]);
24362282
}
24372283

24382284
// Each block loads THREADS * ITEMS_PER_THREAD values from A
24392285
int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out
24402286
? THREADS * ITEMS_PER_THREAD
24412287
: n_out - block_offset;
2442-
__syncthreads();
24432288
LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0);
24442289

24452290
#pragma unroll ITEMS_PER_THREAD
@@ -2458,110 +2303,6 @@ __global__ void kdequant_mm_int32_fp16(
24582303
}
24592304
}
24602305

2461-
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols)
2462-
{
2463-
// assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD
2464-
// Each thread reads the same column but multiple rows
2465-
// Rows are loaded in shared memory and access is shared across the threadblock (broadcast)
2466-
2467-
// 0. Load row stats data into shared memory; load col stat (1 fixed per thread)
2468-
// 1. Load data row by row (should be at least with TILE_SIZE = 512)
2469-
// 2. quantize data with row/col stats
2470-
// 3. Store data (TILE_SIZE = 512 is a bit slow, but should still be close enough to good performance)
2471-
2472-
// each block loads TILE_COLs columns and TILE_ROW rows
2473-
// after reading a tile the row counter increase by TILE_ROWS
2474-
// the col counter reset after reading TILE_COL elements
2475-
const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
2476-
// col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached
2477-
const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
2478-
const int base_idx = (base_row*cols) + base_col;
2479-
const int items_per_load = ITEMS_PER_THREAD*THREADS;
2480-
2481-
typedef cub::BlockLoad<half, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadHalf;
2482-
__shared__ typename LoadHalf::TempStorage loadhalf;
2483-
typedef cub::BlockStore<char, THREADS, ITEMS_PER_THREAD, cub::BLOCK_STORE_VECTORIZE> StoreInt8;
2484-
__shared__ typename StoreInt8::TempStorage storeint8;
2485-
2486-
__shared__ float smem_row_stats[TILE_ROWS];
2487-
__shared__ unsigned int smem_nnz_row_idx[TILE_ROWS];
2488-
2489-
half local_data[ITEMS_PER_THREAD];
2490-
float local_col_stats[ITEMS_PER_THREAD];
2491-
char local_quantized_data[ITEMS_PER_THREAD];
2492-
2493-
// 0. Load row stats data into shared memory; load col stat (1 fixed per thread)
2494-
#pragma unroll ITEMS_PER_THREAD
2495-
for(int j = 0; j < ITEMS_PER_THREAD; j++)
2496-
if(base_col+(threadIdx.x*ITEMS_PER_THREAD) + j < cols)
2497-
local_col_stats[j] = __fdividef(127.0f, colStats[base_col+(threadIdx.x*ITEMS_PER_THREAD)+j]);
2498-
2499-
for(int i = threadIdx.x; i < TILE_ROWS; i+=blockDim.x)
2500-
{
2501-
if(base_row + i < rows)
2502-
smem_row_stats[i] = rowStats[base_row+i];
2503-
2504-
if(SPARSE_DECOMP)
2505-
smem_nnz_row_idx[i] = nnz_block_ptr[(TILE_ROWS*blockIdx.x) + i];
2506-
}
2507-
__syncthreads();
2508-
2509-
// we load row after row from the base_position
2510-
// 1. Load data row by row (should be at least with TILE_SIZE = 512)
2511-
for(int row = 0; row < TILE_ROWS; row++)
2512-
{
2513-
if(base_row + row >= rows){ break; }
2514-
int i = base_idx + (row*cols);
2515-
int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col;
2516-
2517-
2518-
LoadHalf(loadhalf).Load(&(A[i]), local_data, valid_items, 0.0f);
2519-
float row_stat = __fdividef(127.0f, smem_row_stats[row]);
2520-
2521-
// 2. quantize data with row/col stats
2522-
#pragma unroll ITEMS_PER_THREAD
2523-
for(int j = 0; j < ITEMS_PER_THREAD; j++)
2524-
{
2525-
// we already pre-normalized the col/row stat:
2526-
// what this does is float/absmax*127 = int8
2527-
if(SPARSE_DECOMP)
2528-
{
2529-
if(fabsf((float)local_data[j]) >= threshold)
2530-
{
2531-
local_quantized_data[j] = 0;
2532-
2533-
int old_idx = atomicInc(&smem_nnz_row_idx[row], UINT_MAX);
2534-
2535-
rowidx[old_idx] = base_row+row;
2536-
colidx[old_idx] = base_col+(threadIdx.x*ITEMS_PER_THREAD)+j;
2537-
val[old_idx] = local_data[j];
2538-
}
2539-
else
2540-
{
2541-
local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat));
2542-
}
2543-
}
2544-
else
2545-
local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat));
2546-
}
2547-
2548-
StoreInt8(storeint8).Store(&(out_row_normed[i]), local_quantized_data, valid_items);
2549-
2550-
// 2. quantize data with row/col stats
2551-
#pragma unroll ITEMS_PER_THREAD
2552-
for(int j = 0; j < ITEMS_PER_THREAD; j++)
2553-
{
2554-
// we already pre-normalized the col/row stat:
2555-
// what this does is float/absmax*127 = int8
2556-
local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*local_col_stats[j]));
2557-
}
2558-
2559-
__syncthreads();
2560-
StoreInt8(storeint8).Store(&(out_col_normed[i]), local_quantized_data, valid_items);
2561-
2562-
}
2563-
}
2564-
25652306
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols)
25662307
{
25672308

@@ -3864,9 +3605,6 @@ template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(
38643605

38653606
template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n);
38663607

3867-
template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
3868-
template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
3869-
38703608
template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x);
38713609
template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x);
38723610

csrc/kernels.cuh

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,9 @@ template <int ITEMS_PER_THREAD, int THREADS>__global__ void kdequant_mm_int32_fp
116116
int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats,
117117
half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n);
118118

119-
template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
120119
template<typename T, int THREADS, int SPARSE_DECOMP> __global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols);
121120
template<typename T, int THREADS, int SPARSE_DECOMP> __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);
122121

123-
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
124-
125122
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
126123

127124
template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);

csrc/ops.cu

Lines changed: 2 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,8 @@ template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
465465
NULL, NULL, 0, stream
466466
));
467467
} else {
468+
// This path is unlikely to be used, as 8-bit accumulation can lead to likely overflows.
469+
468470
if (!SCALE_ROWS) {
469471
float alpha = 1.0f, beta = 0.0f;
470472
has_error |= checkCublasStatus(cublasLtMatmul(
@@ -532,28 +534,6 @@ void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float
532534
CUDA_CHECK_RETURN(cudaPeekAtLastError());
533535
}
534536

535-
#define STATS_THREADS 64
536-
#define STATS_ITEMS 4
537-
#define STATS_ROWS 16
538-
void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols)
539-
{
540-
int tile_cols = STATS_THREADS*STATS_ITEMS;
541-
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
542-
int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS);
543-
int row_tiles = (tiledRows/STATS_ROWS);
544-
int col_tiles = (tiledCols/tile_cols);
545-
row_tiles = row_tiles > 0 ? row_tiles : 1;
546-
col_tiles = col_tiles > 0 ? col_tiles : 1;
547-
int num_blocks = row_tiles * col_tiles;
548-
549-
if(nnz_threshold == 0.0)
550-
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
551-
else if(nnz_threshold != 0.0)
552-
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 1><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
553-
CUDA_CHECK_RETURN(cudaPeekAtLastError());
554-
555-
}
556-
557537
void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
558538
if (threshold == 0.0)
559539
kgetRowStats<half, 1024, 0><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
@@ -562,29 +542,6 @@ void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols,
562542
CUDA_CHECK_RETURN(cudaPeekAtLastError());
563543
}
564544

565-
void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols)
566-
{
567-
int threads = 64;
568-
int items_per_thread = 4;
569-
int tile_cols = threads*items_per_thread;
570-
int tile_rows = 16;
571-
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
572-
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
573-
int row_tiles = (tiledRows/tile_rows);
574-
int col_tiles = (tiledCols/tile_cols);
575-
row_tiles = row_tiles > 0 ? row_tiles : 1;
576-
col_tiles = col_tiles > 0 ? col_tiles : 1;
577-
int num_blocks = row_tiles * col_tiles;
578-
579-
580-
if(threshold > 0.0f)
581-
kDoubleRowColQuant<64, 4, 16, 64*4, 1><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
582-
else
583-
kDoubleRowColQuant<64, 4, 16, 64*4, 0><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
584-
585-
CUDA_CHECK_RETURN(cudaPeekAtLastError());
586-
}
587-
588545
template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols)
589546
{
590547
int threads = 256;

csrc/ops.cuh

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,10 +176,7 @@ template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle,
176176
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2);
177177
void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
178178
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream);
179-
void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols);
180179
void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream);
181-
void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed,
182-
int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols);
183180
void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream);
184181

185182
template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols);

0 commit comments

Comments
 (0)