@@ -401,23 +401,6 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
401401 return torch .tensor (data , dtype = torch .float32 )
402402
403403
404- @deprecated ("This function is deprecated and will be removed in a future release." , category = FutureWarning )
405- def create_quantile_map (A , total_bits = 8 ):
406- q = estimate_quantiles (A , num_quantiles = 2 ** total_bits - 1 )
407- q = q .tolist ()
408- q .append (0 )
409-
410- gap = 256 - len (q )
411- for i in range (gap ):
412- q .append (0 )
413-
414- q .sort ()
415-
416- q = Tensor (q )
417- q = q / q .abs ().max ()
418- return q
419-
420-
421404def is_on_gpu (tensors : Iterable [Optional [torch .Tensor ]]):
422405 """Verifies that the input tensors are all on the same device.
423406
@@ -474,74 +457,6 @@ def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]:
474457 return ct .c_void_p (A .data_ptr ())
475458
476459
477- @deprecated ("This function is deprecated and will be removed in a future release." , category = FutureWarning )
478- def estimate_quantiles (
479- A : Tensor ,
480- out : Optional [torch .Tensor ] = None ,
481- offset : float = 1 / 512 ,
482- num_quantiles = 256 ,
483- ) -> Tensor :
484- """
485- Estimates 256 equidistant quantiles on the input tensor eCDF.
486-
487- Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles
488- via the eCDF of the input tensor `A`. This is a fast but approximate algorithm
489- and the extreme quantiles close to 0 and 1 have high variance / large estimation
490- errors. These large errors can be avoided by using the offset variable which trims
491- the distribution. The default offset value of 1/512 ensures minimum entropy encoding -- it
492- trims 1/512 = 0.2% from each side of the distrivution. An offset value of 0.01 to 0.02
493- usually has a much lower error but is not a minimum entropy encoding. Given an offset
494- of 0.02 equidistance points in the range [0.02, 0.98] are used for the quantiles.
495-
496- Parameters
497- ----------
498- A : torch.Tensor
499- The input tensor. Any shape.
500- out : torch.Tensor
501- Tensor with the 256 estimated quantiles.
502- offset : float
503- The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles)
504- num_quantiles : int
505- The number of equally spaced quantiles.
506-
507- Returns
508- -------
509- torch.Tensor:
510- The 256 quantiles in float32 datatype.
511- """
512- if A .numel () < 256 :
513- raise NotImplementedError (
514- f"Quantile estimation needs at least 256 values in the Tensor, but Tensor had only { A .numel ()} values." ,
515- )
516- if num_quantiles > 256 :
517- raise NotImplementedError (
518- f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={ num_quantiles } " ,
519- )
520- if num_quantiles < 256 and offset == 1 / (512 ):
521- # override default arguments
522- offset = 1 / (2 * num_quantiles )
523-
524- if out is None :
525- out = torch .zeros ((256 ,), dtype = torch .float32 , device = A .device )
526-
527- with _cuda_device_of (A ):
528- is_on_gpu ([A , out ])
529-
530- if A .dtype == torch .float32 :
531- lib .cestimate_quantiles_fp32 (get_ptr (A ), get_ptr (out ), ct .c_float (offset ), ct .c_int (A .numel ()))
532- elif A .dtype == torch .float16 :
533- lib .cestimate_quantiles_fp16 (get_ptr (A ), get_ptr (out ), ct .c_float (offset ), ct .c_int (A .numel ()))
534- else :
535- raise NotImplementedError (f"Not supported data type { A .dtype } " )
536-
537- if num_quantiles < 256 :
538- step = round (256 / num_quantiles )
539- idx = torch .linspace (0 , 255 , num_quantiles ).long ().to (A .device )
540- out = out [idx ]
541-
542- return out
543-
544-
545460class QuantState :
546461 """container for quantization state components to work with Params4bit and similar classes"""
547462
@@ -1601,25 +1516,6 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
16011516 return current_gnorm , clip_value , gnorm_scale
16021517
16031518
1604- @deprecated ("This function is deprecated and will be removed in a future release." , category = FutureWarning )
1605- def histogram_scatter_add_2d (histogram : Tensor , index1 : Tensor , index2 : Tensor , source : Tensor ):
1606- assert len (histogram .shape ) == 2
1607- assert histogram .dtype == torch .float32
1608- assert source .dtype == torch .float32
1609- assert index1 .dtype == torch .int32
1610- assert index2 .dtype == torch .int32
1611-
1612- assert histogram .device .type == "cuda"
1613- assert index1 .device .type == "cuda"
1614- assert index2 .device .type == "cuda"
1615- assert source .device .type == "cuda"
1616-
1617- maxdim1 = ct .c_int32 (histogram .shape [0 ])
1618- n = ct .c_int32 (index1 .numel ())
1619- is_on_gpu ([histogram , index1 , index2 , source ])
1620- lib .chistogram_scatter_add_2d (get_ptr (histogram ), get_ptr (index1 ), get_ptr (index2 ), get_ptr (source ), maxdim1 , n )
1621-
1622-
16231519def check_matmul (A , B , out , transposed_A , transposed_B , expected_type = torch .int8 ):
16241520 if not torch .cuda .is_initialized ():
16251521 torch .cuda .init ()
@@ -2426,118 +2322,6 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
24262322C = 127.0
24272323
24282324
2429- @deprecated (
2430- "This function is deprecated and will be removed in a future release. "
2431- "Consider using `int8_vectorwise_quant` instead." ,
2432- category = FutureWarning ,
2433- )
2434- def vectorwise_quant (x , dim = 1 , quant_type = "vector" ):
2435- if quant_type == "linear" :
2436- max1 = torch .abs (x ).max ().float ()
2437- xq = torch .round (x / max1 * 127 ).to (torch .int8 )
2438- return xq , max1
2439- elif quant_type in ["vector" , "row" ]:
2440- max1 = torch .amax (torch .abs (x ), dim = dim , keepdim = True )
2441- xq = torch .round (x * (C / max1 )).to (torch .int8 )
2442- return xq , max1
2443- elif quant_type == "zeropoint" :
2444- dtype = x .dtype
2445- x = x .float ()
2446- dyna = x .max () - x .min ()
2447- if dyna == 0 :
2448- dyna = 1
2449- qx = 255.0 / dyna
2450- minx = x .min ()
2451- zpx = torch .round (minx * qx )
2452- x = torch .round (qx * x - zpx ) + zpx
2453- return x , qx
2454- elif quant_type in ["vector-zeropoint" , "row-zeropoint" ]:
2455- dtype = x .dtype
2456- x = x .float ()
2457- dyna = torch .amax (x , dim = dim , keepdim = True ) - torch .amin (x , dim = dim , keepdim = True )
2458- dyna [dyna == 0 ] = 1
2459- qx = 255.0 / dyna
2460- minx = torch .amin (x , dim = dim , keepdim = True )
2461- zpx = torch .round (minx * qx )
2462- x = torch .round (qx * x - zpx ) + zpx
2463- return x , qx
2464- elif quant_type == "truncated-vector" :
2465- with torch .no_grad ():
2466- absx = torch .abs (x )
2467- max1 = torch .amax (absx , dim = dim , keepdim = True )
2468- max1 = max1 * 0.7
2469- idx = absx > max1 .expand_as (absx )
2470- sign = torch .sign (x [idx ])
2471- x [idx ] = max1 .expand_as (absx )[idx ] * sign
2472- xq = torch .round (x / max1 * C ).to (torch .int8 )
2473- return xq , max1
2474- else :
2475- return None
2476-
2477-
2478- @deprecated (
2479- "This function is deprecated and will be removed in a future release." ,
2480- category = FutureWarning ,
2481- )
2482- def vectorwise_mm_dequant (xq , S1 , S2 , dtype = torch .half , quant_type = "vector" ):
2483- if quant_type == "linear" :
2484- norm = S1 * S2 / (C * C )
2485- # double cast needed to prevent overflows
2486- return (xq .float () * norm ).to (dtype )
2487- elif quant_type == "zeropoint" :
2488- norm = 1.0 / (S1 * S2 )
2489- return (xq .float () * norm ).to (dtype )
2490- elif quant_type == "row-zeropoint" :
2491- norm = 1.0 / (S1 * S2 )
2492- x = xq .float ()
2493- if len (S1 .shape ) == 3 and len (x .shape ) == 2 :
2494- S1 = S1 .squeeze (0 )
2495- if len (S2 .shape ) == 3 and len (x .shape ) == 2 :
2496- S2 = S2 .squeeze (0 )
2497- if len (S1 .shape ) == 2 :
2498- x *= norm
2499- else :
2500- x *= norm
2501- return x .to (dtype )
2502- elif quant_type == "vector-zeropoint" :
2503- x = xq .float ()
2504- if len (S1 .shape ) == 3 and len (x .shape ) == 2 :
2505- S1 = S1 .squeeze (0 )
2506- if len (S2 .shape ) == 3 and len (x .shape ) == 2 :
2507- S2 = S2 .squeeze (0 )
2508- if len (S1 .shape ) == 2 :
2509- x *= 1.0 / S1
2510- else :
2511- x *= 1.0 / S1
2512- x *= 1.0 / S2 .t ()
2513- return x .to (dtype )
2514- elif quant_type == "row" :
2515- x = xq .float ()
2516- if len (S1 .shape ) == 3 and len (x .shape ) == 2 :
2517- S1 = S1 .squeeze (0 )
2518- if len (S2 .shape ) == 3 and len (x .shape ) == 2 :
2519- S2 = S2 .squeeze (0 )
2520- if len (S1 .shape ) == 2 :
2521- x *= S1 * S2 / (C * C )
2522- else :
2523- x *= S1 * S2 / (C * C )
2524- return x .to (dtype )
2525- elif quant_type in ["truncated-vector" , "vector" ]:
2526- x = xq .float ()
2527- if len (S1 .shape ) == 3 and len (x .shape ) == 2 :
2528- S1 = S1 .squeeze (0 )
2529- if len (S2 .shape ) == 3 and len (x .shape ) == 2 :
2530- S2 = S2 .squeeze (0 )
2531- if len (S1 .shape ) == 2 :
2532- x *= S1 / C
2533- else :
2534- x *= S1 / C
2535- x *= S2 / C
2536- return x .to (dtype )
2537- else :
2538- return None
2539-
2540-
25412325def _enable_ipex_fusion (linear : torch .nn .Module , x : torch .Tensor ):
25422326 quant_state = linear .weight .quant_state
25432327
0 commit comments