@@ -2330,7 +2330,7 @@ def igemmlt(A, B, out=None, Sout=None, dtype=torch.int32):
     ldb = shapeB[-1]  # Activations (batch, tokens, inputs)
     ldc = shapeC[-1]  # Output (batch, tokens, outputs)

-    assert lda == ldb, f"igemmlt only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ B={shapeA}"
+    assert lda == ldb, f"igemmlt only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}"

     prev_device = A.device
     torch.cuda.set_device(A.device)
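The assertion enforces the kernel's contract: the two int8 operands must share their last (inner) dimension. A rough pure-PyTorch emulation of the result, assuming A holds an (outputs, inputs) weight matrix and B the (batch, tokens, inputs) activations as the comments above suggest — a hypothetical CPU sketch with integer accumulation, not the cuBLASLt path:

import torch

def igemmlt_reference(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # A: int8 weights, shape (outputs, inputs)
    # B: int8 activations, shape (batch, tokens, inputs)
    assert A.shape[-1] == B.shape[-1], "inner dimensions must match"
    flat = B.reshape(-1, B.shape[-1]).long()   # 2D view for integer matmul
    acc = flat @ A.long().t()                  # exact integer accumulation on CPU
    return acc.reshape(*B.shape[:-1], A.shape[0]).to(torch.int32)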
@@ -2361,18 +2361,25 @@ def igemmlt(A, B, out=None, Sout=None, dtype=torch.int32):
     return out, Sout


-def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None):
+def mm_dequant_torch(
+    A: torch.Tensor,
+    quant_state: Optional[Tuple[torch.Size, str]],  # TODO: deprecate. (shape, format)
+    row_stats: torch.Tensor,
+    col_stats: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    new_row_stats=None,  # TODO: unused
+    new_col_stats=None,  # TODO: unused
+    bias: Optional[torch.Tensor] = None,
+):
     assert A.dtype == torch.int32

-    compute_dtype = torch.float32
-
-    A_calc = A.view(-1, A.shape[-1]).to(compute_dtype)
-    row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype)
-    col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype)
+    A_calc = A.view(-1, A.shape[-1])
+    row_stats = row_stats.reshape(-1).unsqueeze(-1)
+    col_stats = col_stats.reshape(-1).unsqueeze(0)

     # TODO support out != None

-    out = A_calc * (row_stats * col_stats) * 6.200124e-5  # .to(torch.float16)
+    out = A_calc * (row_stats * col_stats) * 6.200124e-5

     if bias is not None:
         # assert bias.dtype == torch.float16
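The torch path above is the scale inversion for symmetric int8 quantization: each int32 entry, accumulated from row-quantized activations and column-quantized weights, is rescaled by row_stats * col_stats / 127^2 (the hardcoded 6.200124e-5 approximates 1/127^2 ≈ 6.2000124e-5). A self-contained sketch of the same arithmetic — a hypothetical helper, not part of the library API:

import torch

def mm_dequant_reference(acc: torch.Tensor, row_stats: torch.Tensor,
                         col_stats: torch.Tensor, bias=None) -> torch.Tensor:
    # acc: int32 result of (row-quantized X) @ (col-quantized W)^T, shape (rows, cols)
    out = acc.float() * row_stats.view(-1, 1) * col_stats.view(1, -1) / 127.0**2
    if bias is not None:
        out += bias.float()
    return out.to(torch.float16)

Because row_stats and col_stats stay in fp32, broadcasting promotes the int32 accumulator to fp32 automatically, which is why the explicit compute_dtype casts could be dropped in the diff.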
@@ -2381,42 +2388,40 @@ def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=Non
     return out.to(torch.float16)


-def mm_dequant_old(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None):
+def mm_dequant(
+    A: torch.Tensor,
+    quant_state: Optional[Tuple[torch.Size, str]],  # TODO: deprecate. (shape, format)
+    row_stats: torch.Tensor,
+    col_stats: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    new_row_stats=None,  # TODO: unused
+    new_col_stats=None,  # TODO: unused
+    bias: Optional[torch.Tensor] = None,
+):
     assert A.dtype == torch.int32
+
     if bias is not None:
         assert bias.dtype == torch.float16
-    out_shape = quant_state[0]
-    if len(out_shape) == 3:
-        out_shape = (out_shape[0] * out_shape[1], out_shape[2])

     if out is None:
-        out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
-    if new_row_stats is None:
-        new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
-    if new_col_stats is None:
-        new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
-    assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}"
-    assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}"
+        out = torch.empty_like(A, dtype=torch.float16)

-    prev_device = pre_call(A.device)
     ptrA = get_ptr(A)
     ptrOut = get_ptr(out)
     ptrRowStats = get_ptr(row_stats)
     ptrColStats = get_ptr(col_stats)
-    ptrNewRowStats = get_ptr(new_row_stats)
-    ptrNewColStats = get_ptr(new_col_stats)
     ptrBias = get_ptr(bias)
-    numRows = ct.c_int32(out_shape[0])
-    numCols = ct.c_int32(out_shape[1])
+    numRows = ct.c_int32(prod(A.shape[:-1]))
+    numCols = ct.c_int32(A.shape[-1])
+
+    is_on_gpu([A, row_stats, col_stats, out, bias])

-    is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias])
+    prev_device = pre_call(A.device)
     lib.cdequant_mm_int32_fp16(
         ptrA,
         ptrRowStats,
         ptrColStats,
         ptrOut,
-        ptrNewRowStats,
-        ptrNewColStats,
         ptrBias,
         numRows,
         numCols,
@@ -2426,7 +2431,33 @@ def mm_dequant_old(A, quant_state, row_stats, col_stats, out=None, new_row_stats
     return out


-def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0):
+def get_colrow_absmax(
+    A: torch.Tensor,
+    row_stats: torch.Tensor = None,
+    col_stats: torch.Tensor = None,
+    nnz_block_ptr: torch.Tensor = None,
+    threshold=0.0,
+):
+    # Note: prior impl only works with fp16
+    assert A.is_floating_point()
+
+    if row_stats is None or col_stats is None:
+        absA = A.abs().view(-1, A.shape[-1])  # view as 2D
+        if row_stats is None:
+            # shape [rows]; unsqueeze(-1) gives [rows,1]
+            row_stats = absA.amax(dim=1, keepdim=False).float()
+        if col_stats is None:
+            # shape [cols]; unsqueeze(0) gives [1,cols]
+            col_stats = absA.amax(dim=0, keepdim=False).float()
+
+    # TODO: threshold support
+    if nnz_block_ptr is None and threshold > 0.0:
+        nnz_block_ptr = torch.zeros_like(A, dtype=torch.int32)
+
+    return row_stats, col_stats, nnz_block_ptr
+
+
+def get_colrow_absmax_old(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0):
     assert A.dtype == torch.float16
     device = A.device

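The rewritten get_colrow_absmax reduces the row/column statistics to two amax reductions over a 2D view, which is also why the fp16-only dtype assert could relax to is_floating_point(). A quick shape walkthrough with made-up sizes:

import torch

A = torch.randn(2, 3, 5, dtype=torch.float16)  # (batch, tokens, features)
absA = A.abs().view(-1, A.shape[-1])           # collapse leading dims -> (6, 5)

row_stats = absA.amax(dim=1).float()           # one absmax per row    -> shape [6]
col_stats = absA.amax(dim=0).float()           # one absmax per column -> shape [5]

# Downstream code broadcasts these as [rows, 1] and [1, cols]:
scale = row_stats.unsqueeze(-1) * col_stats.unsqueeze(0)   # shape (6, 5)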
@@ -2543,19 +2574,27 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
     return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)


+@torch.compile
 def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
-    # TODO: Optimize/write CUDA kernel for this. Currently vectorwise_quant will recalculate row/col stats.
+    # TODO: Optimize/write CUDA kernel for this
     # TODO: Support threshold

-    # if out_col is None:
-    #     out_col = torch.zeros(A.shape, device=A.device, dtype=torch.int8)
-    # if out_row is None:
-    #     out_row = torch.zeros(A.shape, device=A.device, dtype=torch.int8)
+    if row_stats is None or col_stats is None:
+        row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)
+
+    scaled_A = A.mul(C)
+
+    # quant_row = torch.round(A * (C / row_stats.unsqueeze(-1))).to(torch.int8)
+    # quant_col = torch.round(A * (C / col_stats.unsqueeze(0))).to(torch.int8)
+    quant_row = torch.round(scaled_A / row_stats.unsqueeze(-1)).to(torch.int8)
+    quant_col = torch.round(scaled_A / col_stats.unsqueeze(0)).to(torch.int8)

-    out_col, Scol = vectorwise_quant(A, dim=0)
-    out_row, Srow = vectorwise_quant(A, dim=1)
+    if out_row is not None:
+        quant_row = out_row.copy_(quant_row)
+    if out_col is not None:
+        quant_col = out_col.copy_(quant_col)

-    return out_row, out_col, Srow.flatten().float(), Scol.flatten().float(), None  # coo_tensor
+    return quant_row, quant_col, row_stats.flatten().float(), col_stats.flatten().float(), None


 def double_quant_old(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
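With those stats, the new double_quant symmetrically quantizes A twice, mapping each row/column absmax to ±127. (C is not defined in this hunk; it is assumed here to be a module-level 127.0 clipping constant.) An equivalent standalone sketch of the same math, with a hypothetical helper name:

import torch

C = 127.0  # assumed int8 clipping constant; defined outside this hunk in the diff

def double_quant_reference(A: torch.Tensor):
    A2 = A.float().reshape(-1, A.shape[-1])
    row_stats = A2.abs().amax(dim=1)                                  # [rows]
    col_stats = A2.abs().amax(dim=0)                                  # [cols]
    quant_row = torch.round(A2 * C / row_stats.unsqueeze(-1)).to(torch.int8)
    quant_col = torch.round(A2 * C / col_stats.unsqueeze(0)).to(torch.int8)
    return quant_row, quant_col, row_stats, col_stats

qr, qc, rs, cs = double_quant_reference(torch.randn(4, 8, dtype=torch.float16))
assert int(qr.abs().max()) <= 127 and int(qc.abs().max()) <= 127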