
Commit 51841e3

gchalump authored and facebook-github-bot committed
Add support for rowwise_adagrad_with_counter on CPU (#5146)
Summary:
X-link: facebookresearch/FBGEMM#2145

Initial support has been added in D81998586.

Differential Revision: D87104079
1 parent 903002a commit 51841e3

File tree

1 file changed: +119 additions, -10 deletions

fbgemm_gpu/codegen/genscript/optimizers.py

Lines changed: 119 additions & 10 deletions
@@ -600,23 +600,132 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
         exp_reg_correction = SHFL_SYNC(exp_reg_correction, 0);
     """
     split_weight_update_cpu = """
+        auto offset_idx = momentum1_offsets_data[feature_begin] + idx;
+
+        // Counter update logic with halflife decay
+        at::acc_type<grad_t, true> freq = 1.0;
+        at::acc_type<grad_t, true> tail_id_threshold_val = tail_id_threshold;
+        if (max_counter != 0.0) {
+            if (is_tail_id_thresh_ratio == 1) {
+                tail_id_threshold_val = std::floor(tail_id_threshold * max_counter);
+            }
+
+            if (counter_halflife > 0) {
+                // Decay based on counter_halflife
+                const auto iter_delta = prev_iter_host[offset_idx] == 0 ? 1.0 : iter * 1.0 - prev_iter_host[offset_idx];
+                const auto counter_log_rho = std::log(2.0) / counter_halflife;
+                row_counter_host[offset_idx] = 1.0 + std::exp(-iter_delta * counter_log_rho) * row_counter_host[offset_idx];
+            } else if (counter_halflife == 0) {
+                // Count only 1 (appear or not)
+                row_counter_host[offset_idx] = 1.0;
+            } else {
+                // Count raw appearance without decaying
+                row_counter_host[offset_idx] += 1.0;
+            }
+        }
+        freq = counter_halflife / row_counter_host[offset_idx];
+
+        // Compute gradient statistics
         at::acc_type<grad_t, true> g_local_sum_square = 0.0;
+        at::acc_type<grad_t, true> w_local_sum_square = 0.0;
+
         for (int64_t d = 0; d < D; ++d) {
-            g_local_sum_square += grad_buffer[d] * grad_buffer[d];
+            auto grad = grad_buffer[d];
+            // For L2 regularization (weight_decay_mode=1), add weight_decay to gradient before other computation
+            if (weight_decay_mode == 1) {
+                grad += weight_decay * host_weights_data[embedding_begin + d];
+            }
+            g_local_sum_square += grad * grad;
+
+            // COW-clip (regularization_mode=4) requires weight norm
+            if (regularization_mode == 4) {
+                const auto weight = host_weights_data[embedding_begin + d];
+                w_local_sum_square += weight * weight;
+            }
         }
-        auto g_avg_square = g_local_sum_square / D;
-        auto offset_idx = momentum1_offsets_data[feature_begin] + idx;
+
+        const auto g_sum_square = g_local_sum_square;
+        const auto g_avg_square = g_sum_square / D;
+        const auto w_sum_square = w_local_sum_square;
+
+        // Update momentum
         at::acc_type<grad_t, true> new_sum_square_grads = momentum1_host[offset_idx] + g_avg_square;
         momentum1_host[offset_idx] = new_sum_square_grads;
-        at::acc_type<grad_t, true> multiplier;
-        multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
-        const auto iter_delta = iter * 1.0 - prev_iter_host[offset_idx];
+        const auto multiplier = learning_rate / (std::sqrt(new_sum_square_grads) + eps);
+        const auto adjustment_enabled = adjustment_iter <= 0 || (adjustment_iter > 0 && iter > adjustment_iter);
+
+        // Compute adjusted multiplier and regularization correction
+        at::acc_type<grad_t, true> adjusted_multiplier = 0.0;
+        at::acc_type<grad_t, true> exp_reg_correction = 0.0;
+
+        if (regularization_mode == 3) {
+            // Counter-based regularization (regularization_mode=3)
+            adjusted_multiplier = multiplier;
+            if (learning_rate_mode >= 0 && adjustment_enabled) {
+                if (row_counter_host[offset_idx] > tail_id_threshold_val) {
+                    if (learning_rate_mode == 0) {
+                        adjusted_multiplier = multiplier * std::max(std::min(std::pow(max_counter / (row_counter_host[offset_idx] + 1.0), adjustment_ub), 10.0), 1.0);
+                    } else if (learning_rate_mode == 1) {
+                        adjusted_multiplier = multiplier * std::min(std::max(std::pow((row_counter_host[offset_idx] + 1.0) / max_counter, adjustment_ub), 0.1), 1.0);
+                    } else if (learning_rate_mode == 2) {
+                        adjusted_multiplier = learning_rate / (std::sqrt(adjustment_ub * row_counter_host[offset_idx]) + eps);
+                    }
+                }
+            }
+        } else if (regularization_mode == 4) {
+            // COW-clip (regularization_mode=4)
+            const auto clip_thresh = row_counter_host[offset_idx] * std::max(weight_norm_coefficient * std::sqrt(w_sum_square), lower_bound);
+            adjusted_multiplier = std::min(1.0f, static_cast<float>(clip_thresh / std::sqrt(g_sum_square))) * multiplier;
+        } else {
+            // Default: no special regularization
+            adjusted_multiplier = multiplier;
+        }
+
+        // Compute regularization correction
+        exp_reg_correction = 1.0;
+        if (regularization_mode == 3) {
+            // Counter-based regularization (regularization_mode=3)
+            if (adjustment_enabled) {
+                if (weight_decay_mode == 3) {
+                    // AdagradW (weight_decay_mode=3)
+                    if (counter_halflife == -1) {
+                        adjusted_multiplier = multiplier * std::sqrt(row_counter_host[offset_idx] * 1.0);
+                    } else if (counter_halflife == -2) {
+                        adjusted_multiplier = std::min(static_cast<float>(learning_rate * std::pow(row_counter_host[offset_idx] * 1.0, 1.0)), adjustment_ub) / (std::sqrt(new_sum_square_grads) + eps);
+                    }
+                    exp_reg_correction = 1.0 - weight_decay * learning_rate;
+                    const auto lazy_delta = prev_iter_host[offset_idx] == 0 ? 1.0 : iter * 1.0 - prev_iter_host[offset_idx];
+                    const auto lazy_multiplier = std::pow(exp_reg_correction, std::min(lazy_delta, iter * 1.0 - adjustment_iter) - 1.0);
+                    adjusted_multiplier *= lazy_multiplier;
+                    exp_reg_correction *= lazy_multiplier;
+                } else if (weight_decay_mode == 2) {
+                    // Decoupled weight decay (weight_decay_mode=2)
+                    exp_reg_correction = 1.0 - freq * weight_decay * learning_rate;
+                } else if (weight_decay_mode == 1) {
+                    // L2 regularization (coupled wd)
+                    exp_reg_correction = 1.0 - freq * weight_decay * multiplier;
+                }
+            }
+        } else if (regularization_mode == 4) {
+            // COW-clip (regularization_mode=4)
+            if (weight_decay_mode == 2) {
+                // Decoupled weight decay (weight_decay_mode=2)
+                exp_reg_correction = 1.0 - weight_decay * learning_rate;
+            } else if (weight_decay_mode == 1) {
+                // L2 regularization (coupled wd)
+                exp_reg_correction = 1.0 - weight_decay * adjusted_multiplier;
+            }
+        } else {
+            // Default regularization
+            exp_reg_correction = 1.0;
+        }
+
+        // Update prev_iter
         prev_iter_host[offset_idx] = iter * 1.0;
-        const auto exp_reg = 1.0 / (weight_decay * multiplier + 1.0);
-        const auto exp_reg_correction = powf(exp_reg, iter_delta);
+
+        // Apply weight updates
         for (int64_t d = 0; d < D; ++d) {
-            const auto weight = host_weights_data[embedding_begin + d];
-            host_weights_data[embedding_begin + d] = exp_reg_correction * weight - exp_reg * multiplier * grad_buffer[d];
+            host_weights_data[embedding_begin + d] = exp_reg_correction * host_weights_data[embedding_begin + d] - adjusted_multiplier * grad_buffer[d];
         }
     """
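
The counter update at the top of the hunk implements an exponential halflife decay: each time a row is touched, its stored count is scaled by exp(-iter_delta * log(2) / counter_halflife) and incremented by 1, so a count recorded counter_halflife iterations ago contributes exactly half its value. A minimal standalone C++ sketch of that arithmetic (hypothetical driver values, not FBGEMM code):

#include <cmath>
#include <cstdio>

int main() {
  const double counter_halflife = 100.0;
  const double counter_log_rho = std::log(2.0) / counter_halflife;

  double row_counter = 8.0;      // accumulated count for one row
  const double prev_iter = 50.0; // iteration of the row's previous update
  const double iter = 150.0;     // current iteration: delta == counter_halflife

  // Same decay-then-increment step as the template code above.
  const double iter_delta = prev_iter == 0 ? 1.0 : iter - prev_iter;
  row_counter = 1.0 + std::exp(-iter_delta * counter_log_rho) * row_counter;

  std::printf("decayed counter: %f\n", row_counter); // 1.0 + 0.5 * 8.0 = 5.0
  return 0;
}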

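For regularization_mode == 3, the three learning_rate_mode branches rescale the Adagrad multiplier based on how often a row has been seen relative to max_counter. A hypothetical standalone sketch of that logic (the function name and driver values are illustrative, not the FBGEMM API):

#include <algorithm>
#include <cmath>
#include <cstdio>

// Counter-based learning-rate adjustment, mirroring the branches above.
double adjust_multiplier(int learning_rate_mode, double multiplier,
                         double learning_rate, double row_counter,
                         double max_counter, double adjustment_ub, double eps) {
  if (learning_rate_mode == 0) {
    // Boost rarely seen rows: ratio > 1 for small counters, clamped to [1, 10].
    return multiplier * std::max(
        std::min(std::pow(max_counter / (row_counter + 1.0), adjustment_ub),
                 10.0), 1.0);
  }
  if (learning_rate_mode == 1) {
    // Damp rarely seen rows: ratio < 1 for small counters, clamped to [0.1, 1].
    return multiplier * std::min(
        std::max(std::pow((row_counter + 1.0) / max_counter, adjustment_ub),
                 0.1), 1.0);
  }
  if (learning_rate_mode == 2) {
    // Replace the Adagrad denominator with a counter-based one.
    return learning_rate / (std::sqrt(adjustment_ub * row_counter) + eps);
  }
  return multiplier;  // learning_rate_mode < 0: no adjustment
}

int main() {
  const double lr = 0.1, eps = 1e-8, max_counter = 1000.0, ub = 1.0;
  const double base = 0.05;  // lr / (sqrt(G) + eps) from the Adagrad step
  std::printf("rare row, mode 0: %f\n",
              adjust_multiplier(0, base, lr, 9.0, max_counter, ub, eps));
  std::printf("freq row, mode 0: %f\n",
              adjust_multiplier(0, base, lr, 999.0, max_counter, ub, eps));
  return 0;
}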

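The COW-clip branch (regularization_mode == 4) scales the step down whenever the row's gradient norm exceeds a threshold proportional to the row's count and to max(weight_norm_coefficient * weight norm, lower_bound). A hypothetical sketch of just that multiplier (not the FBGEMM API):

#include <algorithm>
#include <cmath>
#include <cstdio>

// COW-clip step size, mirroring the regularization_mode == 4 branch above.
double cowclip_multiplier(double multiplier, double row_counter,
                          double weight_norm_coefficient, double lower_bound,
                          double w_sum_square, double g_sum_square) {
  const double clip_thresh = row_counter *
      std::max(weight_norm_coefficient * std::sqrt(w_sum_square), lower_bound);
  // min(1, thresh / ||g||): leaves small gradients untouched, clips large ones.
  return std::min(1.0, clip_thresh / std::sqrt(g_sum_square)) * multiplier;
}

int main() {
  // A row seen twice, unit weight norm, gradient norm 10:
  // clip_thresh = 2 * max(0.5 * 1, 0.01) = 1, so the step is scaled by 1/10.
  std::printf("clipped multiplier: %f\n",
              cowclip_multiplier(0.05, 2.0, 0.5, 0.01, 1.0, 100.0));
  return 0;
}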
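
The final loop applies both corrections at once: weights are first shrunk by exp_reg_correction, then stepped by adjusted_multiplier * grad. A hypothetical end-to-end sketch of one row update under regularization_mode == 3 with decoupled weight decay (weight_decay_mode == 2), where the decay strength is scaled by freq = counter_halflife / row_counter (illustrative values, not FBGEMM code):

#include <cmath>
#include <cstdio>

int main() {
  const int D = 4;
  double weights[D] = {0.5, -0.2, 0.1, 0.3};
  const double grad[D] = {0.01, -0.02, 0.03, 0.0};

  const double learning_rate = 0.1, weight_decay = 0.01;
  const double counter_halflife = 100.0, row_counter = 400.0;
  const double adjusted_multiplier = 0.05;  // from the Adagrad/adjustment step

  // Rows seen often (large row_counter) get a smaller effective decay.
  const double freq = counter_halflife / row_counter;  // 0.25
  const double exp_reg_correction = 1.0 - freq * weight_decay * learning_rate;

  // Same decay-then-step form as the template's weight-update loop.
  for (int d = 0; d < D; ++d) {
    weights[d] = exp_reg_correction * weights[d] - adjusted_multiplier * grad[d];
  }
  for (int d = 0; d < D; ++d) {
    std::printf("w[%d] = %f\n", d, weights[d]);
  }
  return 0;
}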