@@ -1795,102 +1795,6 @@ def int8_mm_dequant(
17951795 return result
17961796
17971797
1798- @deprecated ("This function is deprecated and will be removed in a future release." , category = FutureWarning )
1799- def get_colrow_absmax (
1800- A : torch .Tensor ,
1801- row_stats : Optional [torch .Tensor ] = None ,
1802- col_stats : Optional [torch .Tensor ] = None ,
1803- nnz_block_ptr : Optional [torch .Tensor ] = None ,
1804- threshold = 0.0 ,
1805- ) -> tuple [torch .Tensor , torch .Tensor , Optional [torch .Tensor ]]:
1806- """ "Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
1807-
1808- The row-wise and column-wise absmax values are determined.
1809-
1810- For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
1811-
1812- <Tip>
1813- This function is useful for training, but for inference it is advised to use [`get_row_absmax`] instead.
1814- The column-wise quantization scales are not typically needed in inference scenarios.
1815- </Tip>
1816-
1817- Args:
1818- A (`torch.Tensor` with dtype `torch.float16`): Input tensor.
1819- row_stats (`torch.Tensor`, *optional*): If provided, calculation of row statistics is skipped.
1820- col_stats (`torch.Tensor`, *optional*): If provided, calculation of column statistics is skipped.
1821- nnz_block_ptr (`torch.Tensor`, *optional*): Not used.
1822- threshold (`float`, *optional*):
1823- An optional threshold for sparse decomposition of outlier features.
1824- No outliers are held back when 0.0. Defaults to 0.0.
1825-
1826- Returns:
1827- `Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing quantization statistics.
1828- - `torch.Tensor` with dtype `torch.float32`: The row-wise quantization statistics.
1829- - `torch.Tensor` with dtype `torch.float32`: The column-wise quantization statistics.
1830- - `torch.Tensor` with dtype `torch.bool`, *optional*: A mask indicating the locations of outliers in the input tensor.
1831- """
1832- assert A .is_floating_point ()
1833-
1834- outlier_mask = None
1835-
1836- if row_stats is None or col_stats is None :
1837- absA = A .abs ().view (- 1 , A .shape [- 1 ])
1838-
1839- if threshold > 0.0 :
1840- # Filter outliers from stats when enabled
1841- outlier_mask = absA >= threshold
1842- absA .masked_fill_ (outlier_mask , 0.0 )
1843-
1844- if row_stats is None :
1845- # shape [rows]; unsqueeze(-1) gives [rows,1]
1846- # We have a CUDA kernel for row max, but not yet for cols.
1847- row_stats = get_row_absmax (A , threshold )
1848-
1849- if col_stats is None :
1850- # shape [cols]; unsqueeze(0) gives [1,cols]
1851- col_stats = absA .amax (dim = 0 , keepdim = False ).float ()
1852-
1853- return row_stats , col_stats , outlier_mask
1854-
1855-
1856- @deprecated ("This function is deprecated and will be removed in a future release." , category = FutureWarning )
1857- def get_row_absmax (A : torch .Tensor , threshold = 0.0 ):
1858- """Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
1859-
1860- For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
1861-
1862- Args:
1863- A (`torch.Tensor` with dtype `torch.float16`): The input matrix.
1864- threshold (`float`, *optional*):
1865- An optional threshold for sparse decomposition of outlier features.
1866- No outliers are held back when 0.0. Defaults to 0.0.
1867-
1868- Returns:
1869- `torch.Tensor` with dtype `torch.float32`: The absolute maximum value for each row, with outliers ignored.
1870- """
1871-
1872- assert A .dtype == torch .float16
1873-
1874- rows = prod (A .shape [:- 1 ])
1875- cols = A .shape [- 1 ]
1876-
1877- row_stats = torch .empty ((rows ,), dtype = torch .float32 , device = A .device )
1878-
1879- is_on_gpu ([A ])
1880-
1881- with _cuda_device_of (A ):
1882- lib .cget_row_stats (
1883- get_ptr (A ),
1884- get_ptr (row_stats ),
1885- ct .c_float (threshold ),
1886- ct .c_int32 (rows ),
1887- ct .c_int32 (cols ),
1888- _get_tensor_stream (A ),
1889- )
1890-
1891- return row_stats
1892-
1893-
18941798class COOSparseTensor :
18951799 def __init__ (
18961800 self , rows : int , cols : int , nnz : int , rowidx : torch .Tensor , colidx : torch .Tensor , values : torch .Tensor
0 commit comments