@@ -1011,7 +1011,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
10111011
10121012template <typename T, int OPTIMIZER>
10131013__launch_bounds__ (TH, 1 )
1014- __global__ void kOptimizer32bit1State(T *g, T *p,
1014+ __global__ void kOptimizer32bit1State(T *g, T *p, T *return_updates,
10151015 float *state1, float *unorm, const float max_unorm, const float param_norm,
10161016 const float beta1, const float beta2, const float eps, const float weight_decay,
10171017 const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
@@ -1057,13 +1057,13 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
10571057 __syncthreads ();
10581058 LoadFloat (temp_storage.loadf ).Load (&(state1[i]), s1_vals, valid_items);
10591059 __syncthreads ();
1060- Load (temp_storage.load ).Load (&(p[i]), p_vals, valid_items);
1060+ Load (temp_storage.load ).Load (return_updates == nullptr ? &(p[i]) : &(return_updates [i]), p_vals, valid_items);
10611061
10621062 # pragma unroll 4
10631063 for (unsigned int j = 0 ; j < NUM_PER_THREAD; j++)
10641064 {
10651065 g_vals[j] = gnorm_scale*((float )g_vals[j]);
1066- if (weight_decay > 0 .0f )
1066+ if (weight_decay > 0 .0f && return_updates == nullptr )
10671067 g_vals[j] = (float )g_vals[j] + (((float )p_vals[j])*weight_decay);
10681068 }
10691069
@@ -1080,26 +1080,26 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
10801080 else
10811081 s1_vals[j] = s1_vals[j]*beta1 + ((float )g_vals[j]);
10821082
1083- p_vals[j] = ((float )p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
1083+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) + update_scale*(-lr*(s1_vals[j]));
10841084 break ;
10851085 case LION:
1086- p_vals[j] = ((float )p_vals[j]) - update_scale*(lr*sgn (((float )s1_vals[j])*beta1 + ((1 .0f -beta1)*((float )g_vals[j]))));
1086+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - update_scale*(lr*sgn (((float )s1_vals[j])*beta1 + ((1 .0f -beta1)*((float )g_vals[j]))));
10871087 s1_vals[j] = s1_vals[j]*beta2 + ((1 .0f -beta2)*((float )g_vals[j]));
10881088 break ;
10891089 case RMSPROP:
10901090 s1_vals[j] = s1_vals[j]*beta1 + ((1 .0f -beta1)*((float )g_vals[j])*((float )g_vals[j]));
1091- p_vals[j] = ((float )p_vals[j]) - update_scale*(lr*__fdividef ((float )g_vals[j],sqrtf ((float )s1_vals[j])+eps));
1091+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - update_scale*(lr*__fdividef ((float )g_vals[j],sqrtf ((float )s1_vals[j])+eps));
10921092 break ;
10931093 case ADAGRAD:
10941094 s1_vals[j] = s1_vals[j] + ((float )g_vals[j])*((float )g_vals[j]);
1095- p_vals[j] = ((float )p_vals[j]) - lr*__fdividef ((float )g_vals[j],sqrtf ((float )s1_vals[j])+eps);
1095+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - lr*__fdividef ((float )g_vals[j],sqrtf ((float )s1_vals[j])+eps);
10961096 break ;
10971097 }
10981098 }
10991099 }
11001100
11011101 __syncthreads ();
1102- Store (temp_storage.store ).Store (&(p[i]), p_vals, valid_items);
1102+ Store (temp_storage.store ).Store (return_updates == nullptr ? &(p[i]) : &(return_updates [i]), p_vals, valid_items);
11031103 __syncthreads ();
11041104 StoreFloat (temp_storage.storef ).Store (&(state1[i]), s1_vals, valid_items);
11051105 }
@@ -1447,7 +1447,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
14471447template <typename T, int OPTIMIZER>
14481448__global__ void
14491449__launch_bounds__ (1024 , 1 )
1450- kOptimizerStatic8bit1State(T* p, T* const g, unsigned char * state1,
1450+ kOptimizerStatic8bit1State(T* p, T* const g, T* return_updates, unsigned char * state1,
14511451 const float *unorm, const float max_unorm, const float param_norm,
14521452 const float beta1, const float beta2,
14531453 const float eps, const int step, const float lr,
@@ -1503,7 +1503,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
15031503 __syncthreads ();
15041504 LoadChar (temp_storage.loadc ).Load (&(state1[i]), c1s, valid_items, 128 );
15051505 __syncthreads ();
1506- LoadT (temp_storage.loadh ).Load (&(p[i]), p_vals, valid_items);
1506+ LoadT (temp_storage.loadh ).Load (return_updates == nullptr ? &(p[i]) : &(return_updates [i]), p_vals, valid_items);
15071507
15081508 if ((i + (threadIdx .x *NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue ; }
15091509
@@ -1513,7 +1513,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
15131513 g_val = float (g_vals[j]);
15141514 g_val *= gnorm_scale;
15151515
1516- if (weight_decay > 0 .0f ) {
1516+ if (weight_decay > 0 .0f && return_updates == nullptr ) {
15171517 switch (OPTIMIZER) {
15181518 case ADAGRAD:
15191519 case MOMENTUM:
@@ -1536,15 +1536,15 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
15361536 else
15371537 s1_vals[j] = s1_vals[j]*beta1 + ((float )g_vals[j]);
15381538
1539- p_vals[j] = ((float )p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
1539+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) + (-lr*update_scale*(s1_vals[j]));
15401540 break ;
15411541 case LION:
1542- p_vals[j] = ((float )p_vals[j]) - (lr*sgn (((float )s1_vals[j])*beta1 + ((1 .0f -beta1)*((float )g_val))));
1542+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - (lr*sgn (((float )s1_vals[j])*beta1 + ((1 .0f -beta1)*((float )g_val))));
15431543 s1_vals[j] = s1_vals[j]*beta2 + ((1 .0f -beta2)*g_val);
15441544 break ;
15451545 case RMSPROP:
15461546 s1_vals[j] = s1_vals[j]*beta1 + ((1 .0f -beta1)*(g_val*g_val));
1547- p_vals[j] = ((float )p_vals[j]) - (lr*__fdividef (g_val,sqrtf (s1_vals[j])+eps));
1547+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - (lr*__fdividef (g_val,sqrtf (s1_vals[j])+eps));
15481548 break ;
15491549 }
15501550
@@ -1560,7 +1560,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
15601560 }
15611561 }
15621562
1563- StoreT (temp_storage.storeh ).Store (&(p[i]), p_vals, valid_items);
1563+ StoreT (temp_storage.storeh ).Store (return_updates == nullptr ? &(p[i]) : &(return_updates [i]), p_vals, valid_items);
15641564 __syncthreads ();
15651565 StoreChar (temp_storage.storec ).Store (&(state1[i]), c1s, valid_items);
15661566 __syncthreads ();
@@ -1893,7 +1893,7 @@ kOptimizerStatic8bit2StateBlockwise(
18931893template <typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
18941894__launch_bounds__ (256 , 3 )
18951895__global__ void
1896- kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char * state1,
1896+ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, T* return_updates, unsigned char * state1,
18971897 const float beta1, const float beta2,
18981898 const float eps, const int step, const float lr,
18991899 float * __restrict__ const quantiles1,
@@ -1957,7 +1957,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
19571957 __syncthreads ();
19581958 LoadChar (temp_storage.loadc ).Load (&(state1[i]), c1s, valid_items, 128 );
19591959 __syncthreads ();
1960- LoadT (temp_storage.loadh ).Load (&(p[i]), p_vals, valid_items, (T)0 .0f );
1960+ LoadT (temp_storage.loadh ).Load (return_updates == nullptr ? &(p[i]) : &(return_updates [i]), p_vals, valid_items, (T)0 .0f );
19611961
19621962 new_local_abs_max1 = -FLT_MAX;
19631963
@@ -1969,7 +1969,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
19691969 g_val *= gnorm_scale;
19701970 if (!skip_zeros || (skip_zeros && ((float )g_vals[j] != 0 .0f )))
19711971 {
1972- if (weight_decay > 0 .0f ) {
1972+ if (weight_decay > 0 .0f && return_updates == nullptr ) {
19731973 switch (OPTIMIZER) {
19741974 case MOMENTUM:
19751975 case ADAGRAD:
@@ -2032,18 +2032,18 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
20322032 switch (OPTIMIZER)
20332033 {
20342034 case MOMENTUM:
2035- p_vals[j] = ((float )p_vals[j]) - lr*(s1_vals[j]);
2035+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - lr*(s1_vals[j]);
20362036 break ;
20372037 case LION:
2038- p_vals[j] = ((float )p_vals[j]) - ((float )g_vals[j]);
2038+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - ((float )g_vals[j]);
20392039 break ;
20402040 case RMSPROP:
20412041 g_val = g_vals[j];
2042- p_vals[j] = ((float )p_vals[j]) - lr*(__fdividef (g_val, sqrtf (s1_vals[j])+eps));
2042+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - lr*(__fdividef (g_val, sqrtf (s1_vals[j])+eps));
20432043 break ;
20442044 case ADAGRAD:
20452045 g_val = g_vals[j];
2046- p_vals[j] = ((float )p_vals[j]) - lr*(__fdividef (g_val, sqrtf (s1_vals[j])+eps));
2046+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - lr*(__fdividef (g_val, sqrtf (s1_vals[j])+eps));
20472047 break ;
20482048 }
20492049 }
@@ -3782,7 +3782,7 @@ MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
37823782MAKE_PreconditionOptimizer32bit1State(ADAGRAD, __nv_bfloat16)
37833783
37843784#define MAKE_Optimizer32bit1State (oname, gtype ) \
3785- template __global__ void kOptimizer32bit1State <gtype, oname>(gtype* g, gtype* p, float * state1, float *unorm, const float max_unorm, const float param_norm, \
3785+ template __global__ void kOptimizer32bit1State <gtype, oname>(gtype* g, gtype* p, gtype* return_updates, float * state1, float *unorm, const float max_unorm, const float param_norm, \
37863786 const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
37873787
37883788MAKE_Optimizer32bit1State (MOMENTUM, half)
@@ -3847,7 +3847,7 @@ MAKE_PreconditionStatic8bit1State(ADAGRAD, half)
38473847MAKE_PreconditionStatic8bit1State(ADAGRAD, float )
38483848
38493849#define MAKE_optimizerStatic8bit1State (oname, gtype ) \
3850- template __global__ void kOptimizerStatic8bit1State <gtype, oname>(gtype* p, gtype* const g, unsigned char * state1, \
3850+ template __global__ void kOptimizerStatic8bit1State <gtype, oname>(gtype* p, gtype* const g, gtype* return_updates, unsigned char * state1, \
38513851 const float *unorm, const float max_unorm, const float param_norm, \
38523852 const float beta1, \
38533853 const float beta2, \
@@ -4002,7 +4002,7 @@ MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 256, 1)
40024002
40034003#define MAKE_OptimizerStatic8bit1StateBlockwise (oname, gtype, block_size, num_per_thread ) \
40044004template __global__ void kOptimizerStatic8bit1StateBlockwise <gtype, oname, block_size, num_per_thread>( \
4005- gtype* p, gtype* __restrict__ const g, unsigned char * state1, \
4005+ gtype* p, gtype* __restrict__ const g, gtype* return_updates, unsigned char * state1, \
40064006 const float beta1, const float beta2, \
40074007 const float eps, const int step, const float lr, \
40084008 float * __restrict__ const quantiles1, \
0 commit comments