@@ -191,6 +191,16 @@ def get_instance(cls):
191191
192192FIRST_CUDA_DEVICE = torch .device ("cuda" , index = 0 )
193193
194+ if torch .cuda .device_count () > 1 :
195+
196+ def _cuda_device_of (a : torch .Tensor ):
197+ return torch .cuda .device_of (a )
198+ else :
199+ import contextlib
200+
201+ def _cuda_device_of (a : torch .Tensor ):
202+ return contextlib .nullcontext ()
203+
194204
195205def get_paged (* shape , dtype = torch .float32 , device = FIRST_CUDA_DEVICE ):
196206 num_bytes = dtype2bytes [dtype ] * prod (shape )
@@ -881,7 +891,7 @@ def quantize_blockwise(
881891
882892 is_on_gpu ([A , out , absmax ])
883893
884- with torch . cuda . device_of (A ):
894+ with _cuda_device_of (A ):
885895 args = (
886896 get_ptr (code ),
887897 get_ptr (A ),
@@ -992,7 +1002,7 @@ def dequantize_blockwise(
9921002
9931003 is_on_gpu ([A , absmax , out ])
9941004
995- with torch . cuda . device_of (A ):
1005+ with _cuda_device_of (A ):
9961006 args = (
9971007 get_ptr (quant_state .code ),
9981008 get_ptr (A ),
@@ -1183,7 +1193,7 @@ def quantize_4bit(
11831193
11841194 is_on_gpu ([A , out , absmax ])
11851195
1186- with torch . cuda . device_of (A ):
1196+ with _cuda_device_of (A ):
11871197 args = (
11881198 get_ptr (None ),
11891199 get_ptr (A ),
@@ -1330,7 +1340,7 @@ def dequantize_4bit(
13301340 is_on_gpu ([A , absmax , out ])
13311341 stream = _get_tensor_stream (A )
13321342
1333- with torch . cuda . device_of (A ):
1343+ with _cuda_device_of (A ):
13341344 args = (
13351345 get_ptr (None ),
13361346 get_ptr (A ),
@@ -1547,28 +1557,28 @@ def optimizer_update_32bit(
15471557 )
15481558
15491559 is_on_gpu ([g , p , state1 , state2 , unorm_vec ])
1550- prev_device = pre_call ( g . device )
1551- optim_func (
1552- get_ptr ( g ),
1553- get_ptr (p ),
1554- get_ptr (state1 ),
1555- get_ptr (state2 ),
1556- get_ptr (unorm_vec ),
1557- ct . c_float ( max_unorm ),
1558- ct .c_float (param_norm ),
1559- ct .c_float (beta1 ),
1560- ct .c_float (beta2 ),
1561- ct .c_float (beta3 ),
1562- ct .c_float (alpha ),
1563- ct .c_float (eps ),
1564- ct .c_float (weight_decay ),
1565- ct .c_int32 ( step ),
1566- ct .c_float ( lr ),
1567- ct .c_float (gnorm_scale ),
1568- ct .c_bool ( skip_zeros ),
1569- ct .c_int32 ( g . numel () ),
1570- )
1571- post_call ( prev_device )
1560+
1561+ with _cuda_device_of ( g ):
1562+ optim_func (
1563+ get_ptr (g ),
1564+ get_ptr (p ),
1565+ get_ptr (state1 ),
1566+ get_ptr (state2 ),
1567+ get_ptr ( unorm_vec ),
1568+ ct .c_float (max_unorm ),
1569+ ct .c_float (param_norm ),
1570+ ct .c_float (beta1 ),
1571+ ct .c_float (beta2 ),
1572+ ct .c_float (beta3 ),
1573+ ct .c_float (alpha ),
1574+ ct .c_float (eps ),
1575+ ct .c_float ( weight_decay ),
1576+ ct .c_int32 ( step ),
1577+ ct .c_float (lr ),
1578+ ct .c_float ( gnorm_scale ),
1579+ ct .c_bool ( skip_zeros ),
1580+ ct . c_int32 ( g . numel ()),
1581+ )
15721582
15731583
15741584@deprecated (
@@ -1731,8 +1741,7 @@ def optimizer_update_8bit_blockwise(
17311741 skip_zeros = False ,
17321742) -> None :
17331743 optim_func = None
1734- prev_device = pre_call (g .device )
1735- is_on_gpu ([g , p , state1 , state2 , qmap1 , qmap2 , absmax1 , absmax2 ])
1744+
17361745 if g .dtype == torch .float32 and state1 .dtype == torch .uint8 :
17371746 optim_func = str2optimizer8bit_blockwise [optimizer_name ][0 ]
17381747 elif g .dtype == torch .float16 and state1 .dtype == torch .uint8 :
@@ -1747,33 +1756,31 @@ def optimizer_update_8bit_blockwise(
17471756 raise ValueError (
17481757 f"Gradient+optimizer bit data type combination not supported: grad { g .dtype } , optimizer { state1 .dtype } " ,
17491758 )
1750- post_call (prev_device )
17511759
17521760 is_on_gpu ([p , g , state1 , state2 , qmap1 , qmap2 , absmax1 , absmax2 ])
17531761
1754- prev_device = pre_call (g .device )
1755- optim_func (
1756- get_ptr (p ),
1757- get_ptr (g ),
1758- get_ptr (state1 ),
1759- get_ptr (state2 ),
1760- ct .c_float (beta1 ),
1761- ct .c_float (beta2 ),
1762- ct .c_float (beta3 ),
1763- ct .c_float (alpha ),
1764- ct .c_float (eps ),
1765- ct .c_int32 (step ),
1766- ct .c_float (lr ),
1767- get_ptr (qmap1 ),
1768- get_ptr (qmap2 ),
1769- get_ptr (absmax1 ),
1770- get_ptr (absmax2 ),
1771- ct .c_float (weight_decay ),
1772- ct .c_float (gnorm_scale ),
1773- ct .c_bool (skip_zeros ),
1774- ct .c_int32 (g .numel ()),
1775- )
1776- post_call (prev_device )
1762+ with _cuda_device_of (g ):
1763+ optim_func (
1764+ get_ptr (p ),
1765+ get_ptr (g ),
1766+ get_ptr (state1 ),
1767+ get_ptr (state2 ),
1768+ ct .c_float (beta1 ),
1769+ ct .c_float (beta2 ),
1770+ ct .c_float (beta3 ),
1771+ ct .c_float (alpha ),
1772+ ct .c_float (eps ),
1773+ ct .c_int32 (step ),
1774+ ct .c_float (lr ),
1775+ get_ptr (qmap1 ),
1776+ get_ptr (qmap2 ),
1777+ get_ptr (absmax1 ),
1778+ get_ptr (absmax2 ),
1779+ ct .c_float (weight_decay ),
1780+ ct .c_float (gnorm_scale ),
1781+ ct .c_bool (skip_zeros ),
1782+ ct .c_int32 (g .numel ()),
1783+ )
17771784
17781785
17791786@deprecated ("This function is deprecated and will be removed in a future release." , category = FutureWarning )
@@ -1966,7 +1973,7 @@ def gemv_4bit(
19661973 ldc = ct .c_int32 (ldc )
19671974 stream = _get_tensor_stream (A )
19681975
1969- with torch . cuda . device_of (A ):
1976+ with _cuda_device_of (A ):
19701977 if B .dtype in [torch .uint8 , torch .bfloat16 , torch .float16 , torch .float32 ]:
19711978 if A .dtype == torch .float16 :
19721979 lib .cgemm_4bit_inference_naive_fp16 (
@@ -2285,7 +2292,7 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten
22852292
22862293 is_on_gpu ([A , B , out ])
22872294
2288- with torch . cuda . device_of (A ):
2295+ with _cuda_device_of (A ):
22892296 ctx = CUBLAS_Context .get_instance ().get_context (A .device )
22902297 ptrA = get_ptr (A )
22912298 ptrB = get_ptr (B )
@@ -2343,7 +2350,7 @@ def int8_mm_dequant(
23432350
23442351 is_on_gpu ([A , row_stats , col_stats , out , bias ])
23452352
2346- with torch . cuda . device_of (A ):
2353+ with _cuda_device_of (A ):
23472354 lib .cdequant_mm_int32_fp16 (
23482355 ptrA , ptrRowStats , ptrColStats , ptrOut , ptrBias , numRows , numCols , _get_tensor_stream (A )
23492356 )
@@ -2407,7 +2414,7 @@ def get_row_absmax(A: torch.Tensor, threshold=0.0):
24072414
24082415 is_on_gpu ([A ])
24092416
2410- with torch . cuda . device_of (A ):
2417+ with _cuda_device_of (A ):
24112418 lib .cget_row_stats (
24122419 get_ptr (A ),
24132420 get_ptr (row_stats ),
@@ -2550,7 +2557,7 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
25502557 if outliers .any ():
25512558 outlier_cols = torch .argwhere (outliers .any (dim = 0 )).view (- 1 )
25522559
2553- with torch . cuda . device_of (A ):
2560+ with _cuda_device_of (A ):
25542561 lib .cint8_vector_quant (
25552562 get_ptr (A ),
25562563 get_ptr (out_row ),
0 commit comments