@@ -2205,333 +2205,6 @@ __global__ void kdequant_mm_int32_fp16(
  }
}

-template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols)
-{
-
-  // 0. Load data into 32*32 shared memory tiles
-  // 1. transpose / reorder in shared memory
-  // 2. store
-
-  // COL32 FORMAT:
-  // rows*32 tiles
-
-  // TURING FORMAT:
-  // 8*32 tiles with 4*4 subtiles
-  // the 8*32 tile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements);
-  // the subsequent 4*4 subtiles are for all odd rows; if some rows/columns are empty, the values are zero
-  // the tiles repeat after the 8*32 tile in column-major order, meaning the next 8 rows are A[8:16, 0:32];
-  // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
-  // index increases by 32
-
-  // AMPERE FORMAT:
-  // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with an offset of 8 between pairs:
-  // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
-  // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32]
-
-
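The three target layouts above can be summarized as closed-form index maps. The helper functions below are my own distillation of the store paths in the switch statement further down (the names are invented, and `rows` is assumed to be the padded row count the kernel calls `outRows`); they are a reading aid, not code from this commit:

```cuda
// COL32: panels of 32 columns; each panel is stored contiguously (rows*32
// values), row-major within the panel.
static int col32_index(int row, int col, int rows) {
    return (col / 32) * (32 * rows) + row * 32 + (col % 32);
}

// COL_TURING: 8x32 tiles (256 values) in column-major tile order; within a
// tile, the 4x4 subtiles of the even rows come first, then those of the odd rows.
static int turing_index(int row, int col, int rows) {
    int tile = (col / 32) * (32 * rows) + (row / 8) * 256;
    int sub  = ((col % 32) / 4) * 16 + (col % 4);       // position in a 4x4 subtile group
    int r    = row % 8;
    return tile + sub + ((r % 2) ? 128 + (r - 1) * 2    // odd rows fill the second half
                                 : r * 2);              // even rows fill the first half
}

// COL_AMPERE: 32x32 tiles (1024 values) in column-major tile order; rows are
// interleaved in pairs with stride 8: [0 1 8 9 16 17 24 25][2 3 10 11 ...].
static int ampere_index(int row, int col, int rows) {
    int r = row % 32;
    int interleaved = ((r % 8) / 2) * 8 + (r / 8) * 2 + (r % 2);
    return (col / 32) * (32 * rows) + (row / 32) * 1024 + interleaved * 32 + (col % 32);
}
```

For example, `ampere_index(9, 0, 32)` evaluates to `3*32`, matching the `[0 1 8 9 ...]` pattern quoted above.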
-  // To have efficient loads and stores if we transpose, we need 128 consecutive bytes, which at 1 byte each are 128 values
-  // As such we need:
-  // at least 32*4 shared memory tiles for col32; preferably 32*32
-  // at least 32*6 shared memory tiles for col32_ampere; preferably 32*32
-  // at least 32*8 shared memory tiles for col4_turing; preferably 32*32
-  // for efficient loading of row-major data we need to load 128 elements and repeat this for 32 rows
-  // this would imply a 32x128 shared memory tile -> 4kb
-  // It is more efficient to have more than 1 warp, so with 64 threads we need 32x128 -> 8 kb
-  // we have 64k shared mem per SM in Turing which is 8 blocks per SM which is 2*8 = 32 warps = 100% occupancy
-  // for Turing and 50% for A100 and 75% for RTX 30s / A40 which is probably good enough
-  // register pressure should be low with: 8 registers from local memory per block and 64 registers per SM
-  //
-  // to make the shared memory work with that occupancy we might need to union the block loads/stores
-
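As a sanity check on the shared-memory numbers above, here is the arithmetic for the `THREADS = 256, ITEMS_PER_THREAD = 8` instantiations at the bottom of this file (my figures, derived from the `smem_data` declaration below):

```cuda
constexpr int ITEMS      = 8;                // ITEMS_PER_THREAD in the instantiations below
constexpr int TILE_BYTES = 32 * 32 * ITEMS;  // 8192 B payload: a 32x256 char tile
constexpr int SMEM_BYTES = 32 * 33 * ITEMS;  // 8448 B as declared: one padding char per
                                             // 32-wide row slice to sidestep bank conflicts
static_assert(SMEM_BYTES - TILE_BYTES == 256, "padding is 256 bytes per block");
// With 64 KB of shared memory per SM, at most 65536 / 8448 = 7 such blocks fit.
```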
-  // each block loads TILE_COLS columns and TILE_ROWS rows
-  // after reading a tile, the row counter increases by TILE_ROWS
-  // the col counter resets after reading TILE_COLS elements
-  const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
-  // col increases by TILE_COLS for each block and wraps back to 0 after tiledCols is reached
-  const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
-  const int base_idx = (base_row*cols) + base_col;
-
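A quick worked example of this block-to-tile mapping, with illustrative values `tiledCols = 512`, `TILE_ROWS = 32`, `TILE_COLS = 256` (chosen by me, not from the source):

```cuda
// blockIdx.x = 0 -> base_row =  0, base_col =   0   (first tile of rows 0..31)
// blockIdx.x = 1 -> base_row =  0, base_col = 256   (same rows, next 256 columns)
// blockIdx.x = 2 -> base_row = 32, base_col =   0   (col counter wraps, row band advances)
// blockIdx.x = 3 -> base_row = 32, base_col = 256
```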
-  // we load 128 bytes per warp, and 32 rows for transposes that fill col32 types,
-  // so that we can have contiguous stores
-  __shared__ char smem_data[32*33*ITEMS_PER_THREAD];
-  char local_data[ITEMS_PER_THREAD];
-  typedef cub::BlockExchange<char, THREADS, ITEMS_PER_THREAD> BlockExchange;
-
-  // we load row after row from the base position
-  int warps = blockDim.x/32;
-  int warp_id = threadIdx.x/32;
-  int warp_lane = threadIdx.x % 32;
-  int offset = 0;
-
-  int smem_row = 0;
-  // each warp loads one row of 128 bytes
-  for (int row = warp_id; row < TILE_ROWS; row += warps)
-  {
-    int i = base_idx + (row*cols);
-    // we load up to 128 bytes/items per load
-    int valid_items = cols - base_col > 32*ITEMS_PER_THREAD ? 32*ITEMS_PER_THREAD : cols - base_col;
-
-    // 0. Load data into 32*32 shared memory tiles
-    if (base_row + row < rows)
-    {
-      #pragma unroll ITEMS_PER_THREAD
-      for (int j = 0; j < ITEMS_PER_THREAD; j++)
-      {
-        int col_idx = warp_lane + (j*32);
-        if (col_idx < valid_items)
-          local_data[j] = A[i + col_idx];
-        else
-          local_data[j] = 0;
-      }
-    }
-    else
-    {
-      #pragma unroll ITEMS_PER_THREAD
-      for (int j = 0; j < ITEMS_PER_THREAD; j++)
-        local_data[j] = 0;
-    }
-
-    if (TRANSPOSE)
-    {
-      #pragma unroll ITEMS_PER_THREAD
-      for (int j = 0; j < ITEMS_PER_THREAD; j++)
-      {
-        int local_col = (32*j) + warp_lane;
-        // int local_row = row;
-        // store as 256x32
-        smem_data[(local_col*33) + row] = local_data[j];
-      }
-    }
-    else
-    {
-      // treat smem as 32x256, that is 32 rows and 256 columns
-      #pragma unroll ITEMS_PER_THREAD
-      for (int j = 0; j < ITEMS_PER_THREAD; j++)
-        smem_data[row*32*ITEMS_PER_THREAD + warp_lane + (j*32)] = local_data[j];
-    }
-
-    smem_row += warps;
-
-    // 1. transpose / reorder in shared memory
-    if (smem_row % 32 == 0)
-    {
-      smem_row = 0;
-      __syncthreads();
-
-      for (int subrow = warp_id; subrow < 32; subrow += warps)
-      {
-        for (int j = 0; j < ITEMS_PER_THREAD; j++)
-        {
-          switch (FORMAT)
-          {
-            case COL32:
-              if (TRANSPOSE)
-              {
-                // data lies in shared memory in the following way:
-                // row0 [col0 col1 ... col31]
-                // row1 [col0 col1 ... col31]
-                // ...
-                //
-                // As such we read consecutive entries with 256 threads (8 rows x 32 columns)
-                // as j increases, the row increases by a factor of 8
-                // We load 8 rows per subrow loop, and subrow increases by 8 per loop
-                // so we have an offset of 8 rows every loop, or (subrow/warps)*8 = (subrow/8)*8
-                const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
-                const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
-                // const int local_row = warp_id; // each warp_id is one row
-                // const int block_row = base_col; // block offset for row
-                // const int local_col = warp_lane
-                // const int global_col = base_row; // block offset for col
-                if ((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row + warp_lane < rows))
-                {
-                  // each row has 32 columns and is offset by 1 to prevent bank conflicts during stores into smem
-                  char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
-
-                  // every 32 columns we have a new tile
-                  // each tile has size outRows*32, and base_row advances in increments of 32
-                  offset = base_row*outRows;
-                  out[offset + (base_col + jrow + subrow_loop_row)*32 + threadIdx.x] = data;
-                }
-              }
-              else
-              {
-                if (((base_row + subrow) < rows) && (base_col + (j*32) + warp_lane < outCols))
-                {
-                  offset = (base_col/32)*(32*rows);
-                  char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
-                  out[offset + (base_row + subrow)*32 + (j*rows*32) + warp_lane] = data;
-                }
-              }
-              break;
-            case COL_TURING:
-              // TURING FORMAT:
-              // 8*32 tiles with 4*4 subtiles
-              // the 8*32 tile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements);
-              // the subsequent 4*4 subtiles are for all odd rows; if some rows/columns are empty, the values are zero
-              // the tiles repeat after the 8*32 tile in column-major order, meaning the next 8 rows are A[8:16, 0:32];
-              // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
-              // index increases by 32
-              //
-              // [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...]
-              if (TRANSPOSE)
-              {
-                const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
-                const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
-                // const int local_row = warp_id; // each warp_id is one row
-                // const int block_row = base_col; // block offset for row
-                // const int local_col = warp_lane
-                // const int global_col = base_row; // block offset for col
-                if ((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row + warp_lane < rows))
-                {
-                  // each row has 32 columns and is offset by 1 to prevent bank conflicts during stores into smem
-                  char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
-
-                  // every 32 columns we have a new tile
-                  // each tile has a size offset of 8*32 = 256 elements
-                  // for each row offset of 8 we advance the tile first;
-                  // after all rows are exhausted, we increase the col
-                  int row_offset = ((base_col + jrow + subrow_loop_row + warp_id)/8)*256; // global_row + jrow + subrow_loop_row + local_row; advance one tile (=256) every 8 rows
-
-                  // we increase by row_tile_column every 32 columns
-                  // base_row increases in increments of 32
-                  // int row_tile_column = 256*outRows/8; // there are outRows/8 row tiles, and each tile is 256 elements
-                  // int col_offset = (base_row/32)*row_tile_column;
-                  // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8
-                  // 256*outRows/8*base_row/32 = outRows*base_row
-                  int col_offset = outRows*base_row;
-
-                  offset = row_offset + col_offset;
-
-                  // since we process an even number of rows with each j (8) and with each subrow (8j), we can determine
-                  // odd or even rows with the warp_id (each warp processes one row)
-                  // the col is warp_lane (max 32 columns per row) and the row is warp_id
-                  if (warp_id % 2 == 1)
-                    // odd
-                    offset += 128 + (warp_lane/4)*16 + (warp_lane%4) + (((warp_id%8)-1)*2);
-                  else
-                    // even
-                    offset += 0 + (warp_lane/4)*16 + (warp_lane%4) + ((warp_id%8)*2);
-
-                  out[offset] = data;
-                }
-              }
-              else
-              {
-                if (((base_row + subrow) < rows) && (base_col + (j*32) + warp_lane < outCols))
-                {
-                  char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
-                  // the offset designates the tile offset among the 8*32 tiles
-                  // we first increase rows and then columns. Since we load 128 columns at once,
-                  // we increase the offset by outRows*32 every 32 columns
-                  // additionally, we increase the offset by 8*32 = 256 every 8 rows
-                  offset = ((base_col + (j*32))/32)*outRows*32 + (((base_row + subrow)/8)*256); // global offset (8x32 tile)
-                  // the first 4 rows are reserved for even rows [0, 2, 4, 6], the next 4 for odd rows
-                  // each of these has 32 values, in total 32*4 = 128 as offset if odd
-                  // every set of 4 columns increases the total offset by 16
-                  // each even row increases the offset by 4: for example, row 2 is offset by 4, row 4 by 8, etc., so (subrow/2)*4 = subrow*2
-                  // this happens anew every 8 rows (subrow % 8)
-                  // one writes 4 columns at once, that is (col % 4), for the particular index in the subtile
-                  int subcol = warp_lane;
-
-                  // add local offset (4x4 subtile)
-                  if (subrow % 2 == 1)
-                    // odd
-                    offset += 128 + (subcol/4)*16 + (subcol%4) + (((subrow%8)-1)*2);
-                  else
-                    // even
-                    offset += 0 + (subcol/4)*16 + (subcol%4) + ((subrow%8)*2);
-
-                  out[offset] = data;
-                }
-              }
-              break;
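To make the even/odd subtile arithmetic concrete, here is one hand-computed example following the expressions above (my numbers, not from the source):

```cuda
// Element (row 5, col 13) inside the first 8x32 tile: subrow = 5 (odd), subcol = 13
//   offset = 128              // odd rows occupy the second half of the 256-value tile
//          + (13/4)*16        // = 48, selects the fourth 4x4 subtile group
//          + (13%4)           // =  1, column within the subtile
//          + ((5%8)-1)*2      // =  8, row within the subtile group
//          = 185
```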
-            case COL_AMPERE:
-              // AMPERE FORMAT:
-              // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with an offset of 8 between pairs:
-              // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
-              // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32]
-              if (TRANSPOSE)
-              {
-                const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
-                const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
-                // const int local_row = warp_id; // each warp_id is one row
-                // const int block_row = base_col; // block offset for row
-                // const int local_col = warp_lane
-                // const int global_col = base_row; // block offset for col
-                if ((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row + warp_lane < rows))
-                {
-                  // each row has 32 columns and is offset by 1 to prevent bank conflicts during stores into smem
-                  char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
-
-                  // every 32 columns we have a new tile
-                  // each tile has a size offset of 32*32 = 1024 elements
-                  // for each row offset of 32 we advance the tile first;
-                  // after all rows are exhausted, we increase the col
-                  int row_offset = ((base_col + jrow + subrow_loop_row + warp_id)/32)*1024; // global_row + jrow + subrow_loop_row + local_row; advance one tile (=1024) every 32 rows
-
-                  // we increase by row_tile_column every 32 columns
-                  // base_row increases in increments of 32
-                  // int row_tile_column = 1024*outRows/32; // there are outRows/32 row tiles, and each tile is 1024 elements
-                  // int col_offset = (base_row/32)*row_tile_column;
-                  // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8
-                  // 1024*outRows/32*base_row/32 = outRows*base_row
-                  int col_offset = outRows*base_row;
-
-                  offset = row_offset + col_offset;
-
-                  // same as in the non-transpose case (see below);
-                  // the difference is that now rows = cols
-                  // in this case warp_id = subrow
-
-                  // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
-                  // subrow % 8 -> [0,1] in tile0, [2,3] in tile 1, etc.
-                  // subrow % 2 -> 0 for the 1st row in the pair, 1 for the 2nd row
-                  // every 2 rows, the offset increases by two [0, 1, 8, 9...]
-                  // every 2 rows, the row index increases by 8 [0, 1, 8, 9...]
-                  int local_row = (jrow + warp_id) % 32; // offset for row > 32 is already calculated into row_offset
-                  int ampere_row = ((local_row % 8)/2)*8 + (local_row/8)*2 + (local_row % 2);
-
-                  // global offset + row with 32 cols each + 32 cols per j + col_idx = warp_lane
-                  out[offset + (ampere_row*32) + warp_lane] = data;
-                }
-              }
-              else
-              {
-                if (((base_row + subrow) < rows) && (base_col + (j*32) + warp_lane < outCols))
-                {
-                  char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
-
-                  // the offset designates the tile offset among the 32*32 tiles
-                  // we first increase rows and then columns. Since we load 128 columns at once,
-                  // we increase the offset by outRows*32 every 32 columns
-                  // additionally, we increase the offset by 32*32 = 1024 every 32 rows
-                  offset = ((base_col + (j*32))/32)*outRows*32 + (((base_row + subrow)/32)*1024); // global offset (32x32 tile)
-
-                  // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
-                  // subrow % 8 -> [0,1] in tile0, [2,3] in tile 1, etc.
-                  // subrow % 2 -> 0 for the 1st row in the pair, 1 for the 2nd row
-                  // every 2 rows, the offset increases by two [0, 1, 8, 9...]
-                  // every 2 rows, the row index increases by 8 [0, 1, 8, 9...]
-                  int local_row = ((subrow % 8)/2)*8 + (subrow/8)*2 + (subrow % 2);
-
-                  // global offset + row with 32 cols each + 32 cols per j + col_idx
-                  out[offset + (local_row*32) + warp_lane] = data;
-                }
-              }
-              break;
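Evaluating the `local_row` expression for the first few values of `subrow` reproduces the interleaving pattern quoted in the comments (a hand check):

```cuda
// subrow:    0  1  2  3   4   5   6   7   8  9  10  11 ...
// local_row: 0  1  8  9  16  17  24  25   2  3  10  11 ...
// e.g. subrow = 8: ((8%8)/2)*8 + (8/8)*2 + (8%2) = 0 + 2 + 0 = 2
```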
-          }
-        }
-      }
-    }
-  }
-}
-
#define DENORM 1.0f/127.0f
#define MAX_SPARSE_COUNT 32
#define SMEM_SIZE 8*256
@@ -3386,13 +3059,6 @@ template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max
template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);

-template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
-template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
-template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
-template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
-template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
-template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
-
template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n);

template __device__ unsigned char dQuantize<0>(float * smem_code, const float rand, float x);