@@ -251,12 +251,6 @@ def fill(A, value, device=None, prefetch=True):
     elementwise_func("fill", A, None, value)
 
 
-@deprecated("Function will be removed in a future release.", category=FutureWarning)
-def arange(A, device=None):
-    elementwise_func("arange", A, None, 0)
-
-
-@deprecated("Function will be removed in a future release.", category=FutureWarning)
 def _mul(A, B, device=None):
     elementwise_func("_mul", A, B, 0)
 
@@ -407,6 +401,7 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
     return torch.tensor(data, dtype=torch.float32)
 
 
+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def create_quantile_map(A, total_bits=8):
     q = estimate_quantiles(A, num_quantiles=2**total_bits - 1)
     q = q.tolist()
@@ -480,17 +475,6 @@ def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]:
 
 
 @deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def pre_call(device):
-    prev_device = torch.cuda.current_device()
-    torch.cuda.set_device(device)
-    return prev_device
-
-
-@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def post_call(prev_device):
-    torch.cuda.set_device(prev_device)
-
-
 def estimate_quantiles(
     A: Tensor,
     out: Optional[torch.Tensor] = None,
@@ -539,15 +523,16 @@ def estimate_quantiles(
 
     if out is None:
         out = torch.zeros((256,), dtype=torch.float32, device=A.device)
-    is_on_gpu([A, out])
-    device = pre_call(A.device)
-    if A.dtype == torch.float32:
-        lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
-    elif A.dtype == torch.float16:
-        lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
-    else:
-        raise NotImplementedError(f"Not supported data type {A.dtype}")
-    post_call(device)
+
+    with _cuda_device_of(A):
+        is_on_gpu([A, out])
+
+        if A.dtype == torch.float32:
+            lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
+        elif A.dtype == torch.float16:
+            lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
+        else:
+            raise NotImplementedError(f"Not supported data type {A.dtype}")
 
     if num_quantiles < 256:
         step = round(256 / num_quantiles)
@@ -1219,12 +1204,12 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = No
     torch.Tensor:
         Quantized 8-bit tensor.
     """
-    prev_device = pre_call(A.device)
-    if out is None:
-        out = torch.zeros_like(A, dtype=torch.uint8)
-    is_on_gpu([A, out])
-    lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
-    post_call(prev_device)
+    with _cuda_device_of(A):
+        if out is None:
+            out = torch.zeros_like(A, dtype=torch.uint8)
+        is_on_gpu([A, out])
+        lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
+
     return out
 
 
@@ -1250,13 +1235,13 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] =
     torch.Tensor:
         32-bit output tensor.
     """
-    prev_device = pre_call(A.device)
-    if out is None:
-        out = torch.zeros_like(A, dtype=torch.float32)
-    is_on_gpu([code, A, out])
-    stream = _get_tensor_stream(A)
-    lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()), stream)
-    post_call(prev_device)
+    with _cuda_device_of(A):
+        if out is None:
+            out = torch.zeros_like(A, dtype=torch.float32)
+        is_on_gpu([code, A, out])
+        stream = _get_tensor_stream(A)
+        lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()), stream)
+
     return out
 
 
@@ -1444,61 +1429,60 @@ def optimizer_update_8bit(
     if max_unorm > 0.0:
         param_norm = torch.norm(p.data.float())
 
-    prev_device = pre_call(g.device)
-    is_on_gpu([g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2])
-    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
-        str2optimizer8bit[optimizer_name][0](
-            get_ptr(p),
-            get_ptr(g),
-            get_ptr(state1),
-            get_ptr(state2),
-            get_ptr(unorm_vec),
-            ct.c_float(max_unorm),
-            ct.c_float(param_norm),
-            ct.c_float(beta1),
-            ct.c_float(beta2),
-            ct.c_float(eps),
-            ct.c_int32(step),
-            ct.c_float(lr),
-            get_ptr(qmap1),
-            get_ptr(qmap2),
-            get_ptr(max1),
-            get_ptr(max2),
-            get_ptr(new_max1),
-            get_ptr(new_max2),
-            ct.c_float(weight_decay),
-            ct.c_float(gnorm_scale),
-            ct.c_int32(g.numel()),
-        )
-    elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
-        str2optimizer8bit[optimizer_name][1](
-            get_ptr(p),
-            get_ptr(g),
-            get_ptr(state1),
-            get_ptr(state2),
-            get_ptr(unorm_vec),
-            ct.c_float(max_unorm),
-            ct.c_float(param_norm),
-            ct.c_float(beta1),
-            ct.c_float(beta2),
-            ct.c_float(eps),
-            ct.c_int32(step),
-            ct.c_float(lr),
-            get_ptr(qmap1),
-            get_ptr(qmap2),
-            get_ptr(max1),
-            get_ptr(max2),
-            get_ptr(new_max1),
-            get_ptr(new_max2),
-            ct.c_float(weight_decay),
-            ct.c_float(gnorm_scale),
-            ct.c_int32(g.numel()),
-        )
-    else:
-        raise ValueError(
-            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
-        )
-    post_call(prev_device)
+    with _cuda_device_of(g):
+        is_on_gpu([g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2])
+        if g.dtype == torch.float32 and state1.dtype == torch.uint8:
+            str2optimizer8bit[optimizer_name][0](
+                get_ptr(p),
+                get_ptr(g),
+                get_ptr(state1),
+                get_ptr(state2),
+                get_ptr(unorm_vec),
+                ct.c_float(max_unorm),
+                ct.c_float(param_norm),
+                ct.c_float(beta1),
+                ct.c_float(beta2),
+                ct.c_float(eps),
+                ct.c_int32(step),
+                ct.c_float(lr),
+                get_ptr(qmap1),
+                get_ptr(qmap2),
+                get_ptr(max1),
+                get_ptr(max2),
+                get_ptr(new_max1),
+                get_ptr(new_max2),
+                ct.c_float(weight_decay),
+                ct.c_float(gnorm_scale),
+                ct.c_int32(g.numel()),
+            )
+        elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
+            str2optimizer8bit[optimizer_name][1](
+                get_ptr(p),
+                get_ptr(g),
+                get_ptr(state1),
+                get_ptr(state2),
+                get_ptr(unorm_vec),
+                ct.c_float(max_unorm),
+                ct.c_float(param_norm),
+                ct.c_float(beta1),
+                ct.c_float(beta2),
+                ct.c_float(eps),
+                ct.c_int32(step),
+                ct.c_float(lr),
+                get_ptr(qmap1),
+                get_ptr(qmap2),
+                get_ptr(max1),
+                get_ptr(max2),
+                get_ptr(new_max1),
+                get_ptr(new_max2),
+                ct.c_float(weight_decay),
+                ct.c_float(gnorm_scale),
+                ct.c_int32(g.numel()),
+            )
+        else:
+            raise ValueError(
+                f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
+            )
 
 
 def optimizer_update_8bit_blockwise(
@@ -1577,25 +1561,24 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
15771561 The current optimization steps (number of past gradient norms).
15781562
15791563 """
1580- prev_device = pre_call (grad .device )
1581- is_on_gpu ([grad , gnorm_vec ])
1582- if grad .dtype == torch .float32 :
1583- lib .cpercentile_clipping_g32 (
1584- get_ptr (grad ),
1585- get_ptr (gnorm_vec ),
1586- ct .c_int32 (step ),
1587- ct .c_int32 (grad .numel ()),
1588- )
1589- elif grad .dtype == torch .float16 :
1590- lib .cpercentile_clipping_g16 (
1591- get_ptr (grad ),
1592- get_ptr (gnorm_vec ),
1593- ct .c_int32 (step ),
1594- ct .c_int32 (grad .numel ()),
1595- )
1596- else :
1597- raise ValueError (f"Gradient type { grad .dtype } not supported!" )
1598- post_call (prev_device )
1564+ with _cuda_device_of (grad ):
1565+ is_on_gpu ([grad , gnorm_vec ])
1566+ if grad .dtype == torch .float32 :
1567+ lib .cpercentile_clipping_g32 (
1568+ get_ptr (grad ),
1569+ get_ptr (gnorm_vec ),
1570+ ct .c_int32 (step ),
1571+ ct .c_int32 (grad .numel ()),
1572+ )
1573+ elif grad .dtype == torch .float16 :
1574+ lib .cpercentile_clipping_g16 (
1575+ get_ptr (grad ),
1576+ get_ptr (gnorm_vec ),
1577+ ct .c_int32 (step ),
1578+ ct .c_int32 (grad .numel ()),
1579+ )
1580+ else :
1581+ raise ValueError (f"Gradient type { grad .dtype } not supported!" )
15991582
16001583 current_gnorm = torch .sqrt (gnorm_vec [step % 100 ])
16011584 vals , idx = torch .sort (gnorm_vec )
@@ -2333,7 +2316,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
     if out is None:
         out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype)
     nnz = cooA.nnz
-    prev_device = pre_call(B.device)
+
     assert cooA.rowidx.numel() == nnz
     assert cooA.colidx.numel() == nnz
     assert cooA.values.numel() == nnz
@@ -2370,43 +2353,43 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
     cldb = ct.c_int32(ldb)
     cldc = ct.c_int32(ldc)
 
-    is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])
-    if B.dtype == torch.float16:
-        lib.cspmm_coo_very_sparse_naive_fp16(
-            ptrMaxCount,
-            ptrMaxIdx,
-            ptrOffset,
-            ptrRowidx,
-            ptrColidx,
-            ptrValues,
-            ptrB,
-            ptrC,
-            ptrDequantStats,
-            cnnz_rows,
-            cnnz,
-            crowsA,
-            crowsB,
-            ccolsB,
-        )
-    elif B.dtype == torch.int8:
-        lib.cspmm_coo_very_sparse_naive_int8(
-            ptrMaxCount,
-            ptrMaxIdx,
-            ptrOffset,
-            ptrRowidx,
-            ptrColidx,
-            ptrValues,
-            ptrB,
-            ptrC,
-            ptrDequantStats,
-            cnnz_rows,
-            cnnz,
-            crowsA,
-            crowsB,
-            ccolsB,
-        )
-    # else: assertion error
-    post_call(prev_device)
+    with _cuda_device_of(B):
+        is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])
+        if B.dtype == torch.float16:
+            lib.cspmm_coo_very_sparse_naive_fp16(
+                ptrMaxCount,
+                ptrMaxIdx,
+                ptrOffset,
+                ptrRowidx,
+                ptrColidx,
+                ptrValues,
+                ptrB,
+                ptrC,
+                ptrDequantStats,
+                cnnz_rows,
+                cnnz,
+                crowsA,
+                crowsB,
+                ccolsB,
+            )
+        elif B.dtype == torch.int8:
+            lib.cspmm_coo_very_sparse_naive_int8(
+                ptrMaxCount,
+                ptrMaxIdx,
+                ptrOffset,
+                ptrRowidx,
+                ptrColidx,
+                ptrValues,
+                ptrB,
+                ptrC,
+                ptrDequantStats,
+                cnnz_rows,
+                cnnz,
+                crowsA,
+                crowsB,
+                ccolsB,
+            )
+        # else: assertion error
 
     return out
 
@@ -2463,18 +2446,6 @@ def vectorwise_quant(x, dim=1, quant_type="vector"):
         return None
 
 
-@deprecated(
-    "This function is deprecated and will be removed in a future release.",
-    category=FutureWarning,
-)
-def vectorwise_dequant(xq, max1, quant_type="vector"):
-    if quant_type == "vector":
-        x = (xq / C * max1).to(torch.float32)
-        return x
-    else:
-        return None
-
-
 @deprecated(
     "This function is deprecated and will be removed in a future release.",
     category=FutureWarning,
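Every hunk above makes the same mechanical change: the deprecated `pre_call`/`post_call` pair, which imperatively saved and restored the active CUDA device around each ctypes kernel launch, is replaced with a single `with _cuda_device_of(tensor):` block. The context manager restores the previous device even if the wrapped call raises, which the paired calls could not guarantee (an exception between `pre_call` and `post_call` leaked the device switch). The helper's definition is not part of these hunks; below is a minimal sketch of what such a context manager might look like, with names and structure assumed rather than taken from this diff.

```python
from contextlib import contextmanager

import torch


@contextmanager
def _cuda_device_of(a: torch.Tensor):
    # Hypothetical sketch: scope the current CUDA device to the device of
    # `a`, restoring the previous device on exit even if the body raises.
    if a.device.type == "cuda":
        prev = torch.cuda.current_device()
        torch.cuda.set_device(a.device)
        try:
            yield
        finally:
            torch.cuda.set_device(prev)
    else:
        # Non-CUDA tensors need no device switch.
        yield
```

PyTorch's built-in `torch.cuda.device_of(a)` context manager offers equivalent behavior, so the real helper may simply delegate to it.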