@@ -770,6 +770,8 @@ def optimizer_update_32bit(
            f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}'
        )

+    prev_device = pre_call(g.device)
+    is_on_gpu([g, p, state1, state2, unorm_vec])
    if g.dtype == torch.float32 and state1.dtype == torch.float32:
        str2optimizer32bit[optimizer_name][0](
            get_ptr(g),
@@ -812,6 +814,7 @@ def optimizer_update_32bit(
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
        )
+    post_call(prev_device)


def optimizer_update_8bit(
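
Note on the pattern: every hunk in this change wraps the raw kernel launch in the same guard — switch to the device the operand tensors live on (pre_call), verify the operands are actually CUDA tensors (is_on_gpu), launch, then restore the caller's device (post_call). The helper bodies are not part of this diff, so here is only a minimal sketch of the semantics the pattern assumes:

import torch

def pre_call(device):
    # Remember the caller's active CUDA device, then switch to the
    # device the kernel's tensors live on.
    prev_device = torch.cuda.current_device()
    torch.cuda.set_device(device)
    return prev_device

def post_call(prev_device):
    # Restore whatever device was active before pre_call.
    torch.cuda.set_device(prev_device)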
@@ -890,6 +893,8 @@ def optimizer_update_8bit(
    if max_unorm > 0.0:
        param_norm = torch.norm(p.data.float())

+    prev_device = pre_call(g.device)
+    is_on_gpu([g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2])
    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
        str2optimizer8bit[optimizer_name][0](
            get_ptr(p),
@@ -942,6 +947,7 @@ def optimizer_update_8bit(
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
        )
+    post_call(prev_device)


def optimizer_update_8bit_blockwise(
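
is_on_gpu receives every tensor the kernel will touch, including optional ones (state2, qmap2, max2 are None for single-state optimizers). A plausible sketch of such a check — assumed here, its body is not shown in the diff — skips None entries and validates the rest:

import torch

def is_on_gpu(tensors):
    # Sketch: every non-None tensor must live on a CUDA device,
    # otherwise get_ptr() would hand the kernel an invalid pointer.
    on_gpu = True
    for t in tensors:
        if t is None:
            continue  # optional state is passed as None / NULL
        on_gpu &= t.device.type == "cuda"
    if not on_gpu:
        raise TypeError("All input tensors need to be on a CUDA device.")
    return on_gpu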
@@ -964,6 +970,8 @@ def optimizer_update_8bit_blockwise(
    skip_zeros=False,
) -> None:

+    prev_device = pre_call(g.device)
+    is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2])
    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
        str2optimizer8bit_blockwise[optimizer_name][0](
            get_ptr(p),
@@ -1008,6 +1016,7 @@ def optimizer_update_8bit_blockwise(
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
        )
+    post_call(prev_device)


def percentile_clipping(
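
For context, the failure mode these guards address: an optimizer whose tensors live on a non-default GPU, so the kernel would otherwise launch on the wrong device. A hypothetical repro (assumes a machine with at least two GPUs and the bnb.optim API):

import torch
import bitsandbytes as bnb

# Parameters on cuda:1 while cuda:0 is the current device; without the
# pre_call/post_call guard, the C kernel launches on cuda:0 against
# pointers that live on cuda:1.
p = torch.nn.Parameter(torch.randn(64, 64, device="cuda:1"))
opt = bnb.optim.Adam8bit([p], lr=1e-3)
p.grad = torch.randn_like(p)
opt.step()  # reaches optimizer_update_8bit_blockwise, now guarded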
@@ -1023,6 +1032,7 @@ def percentile_clipping(
        The current optimization steps (number of past gradient norms).

    """
+    prev_device = pre_call(grad.device)
    is_on_gpu([grad, gnorm_vec])
    if grad.dtype == torch.float32:
        lib.cpercentile_clipping_g32(
@@ -1040,6 +1050,7 @@ def percentile_clipping(
        )
    else:
        raise ValueError(f"Gradient type {grad.dtype} not supported!")
+    post_call(prev_device)

    current_gnorm = torch.sqrt(gnorm_vec[step % 100])
    vals, idx = torch.sort(gnorm_vec)
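
One design note: the explicit pre_call/post_call pair is not exception-safe. In the unsupported-dtype branches above, the ValueError is raised before post_call runs, so the caller is left on the switched device. A context-manager variant (hypothetical, not part of this change) would restore it unconditionally:

from contextlib import contextmanager
import torch

@contextmanager
def cuda_device_of(tensor):
    # Hypothetical context-manager form of the pre_call/post_call pair;
    # the finally block restores the device even when the body raises.
    prev_device = torch.cuda.current_device()
    torch.cuda.set_device(tensor.device)
    try:
        yield
    finally:
        torch.cuda.set_device(prev_device)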
@@ -1796,6 +1807,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
            (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype
        )
    nnz = cooA.nnz
+    prev_device = pre_call(B.device)
    assert cooA.rowidx.numel() == nnz
    assert cooA.colidx.numel() == nnz
    assert cooA.values.numel() == nnz
@@ -1872,6 +1884,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
            ccolsB,
        )
    # else: assertion error
+    post_call(prev_device)

    return out
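
A quick sanity check of the save/restore behavior the whole change relies on (assumes pre_call/post_call are importable from bitsandbytes.functional and at least two GPUs are visible):

import torch
from bitsandbytes.functional import pre_call, post_call

torch.cuda.set_device(0)
t = torch.empty(1, device="cuda:1")

prev_device = pre_call(t.device)   # switch to t's device
assert torch.cuda.current_device() == 1
post_call(prev_device)             # restore the caller's device
assert torch.cuda.current_device() == 0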