@@ -600,23 +600,132 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
600600 exp_reg_correction = SHFL_SYNC(exp_reg_correction, 0);
601601 """
602602 split_weight_update_cpu = """
603+ auto offset_idx = momentum1_offsets_data[feature_begin] + idx;
604+
605+ // Counter update logic with halflife decay
606+ at::acc_type<grad_t, true> freq = 1.0;
607+ at::acc_type<grad_t, true> tail_id_threshold_val = tail_id_threshold;
608+ if (max_counter != 0.0) {
609+ if (is_tail_id_thresh_ratio == 1) {
610+ tail_id_threshold_val = std::floor(tail_id_threshold * max_counter);
611+ }
612+
613+ if (counter_halflife > 0) {
614+ // Decay based on counter_halflife
615+ const auto iter_delta = prev_iter_host[offset_idx] == 0 ? 1.0 : iter * 1.0 - prev_iter_host[offset_idx];
616+ const auto counter_log_rho = std::log(2.0) / counter_halflife;
617+ row_counter_host[offset_idx] = 1.0 + std::exp(-iter_delta * counter_log_rho) * row_counter_host[offset_idx];
618+ } else if (counter_halflife == 0) {
619+ // Count only 1 (appear or not)
620+ row_counter_host[offset_idx] = 1.0;
621+ } else {
622+ // Count raw appearance without decaying
623+ row_counter_host[offset_idx] += 1.0;
624+ }
625+ }
626+ freq = counter_halflife / row_counter_host[offset_idx];
627+
628+ // Compute gradient statistics
603629 at::acc_type<grad_t, true> g_local_sum_square = 0.0;
630+ at::acc_type<grad_t, true> w_local_sum_square = 0.0;
631+
604632 for (int64_t d = 0; d < D; ++d) {
605- g_local_sum_square += grad_buffer[d] * grad_buffer[d];
633+ auto grad = grad_buffer[d];
634+ // For L2 regularization (weight_decay_mode=1), add weight_decay to gradient before other computation
635+ if (weight_decay_mode == 1) {
636+ grad += weight_decay * host_weights_data[embedding_begin + d];
637+ }
638+ g_local_sum_square += grad * grad;
639+
640+ // COW-clip (regularization_mode=4) requires weight norm
641+ if (regularization_mode == 4) {
642+ const auto weight = host_weights_data[embedding_begin + d];
643+ w_local_sum_square += weight * weight;
644+ }
606645 }
607- auto g_avg_square = g_local_sum_square / D;
608- auto offset_idx = momentum1_offsets_data[feature_begin] + idx;
646+
647+ const auto g_sum_square = g_local_sum_square;
648+ const auto g_avg_square = g_sum_square / D;
649+ const auto w_sum_square = w_local_sum_square;
650+
651+ // Update momentum
609652 at::acc_type<grad_t, true> new_sum_square_grads = momentum1_host[offset_idx] + g_avg_square;
610653 momentum1_host[offset_idx] = new_sum_square_grads;
611- at::acc_type<grad_t, true> multiplier;
612- multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
613- const auto iter_delta = iter * 1.0 - prev_iter_host[offset_idx];
654+ const auto multiplier = learning_rate / (std::sqrt(new_sum_square_grads) + eps);
655+ const auto adjustment_enabled = adjustment_iter <= 0 || (adjustment_iter > 0 && iter > adjustment_iter);
656+
657+ // Compute adjusted multiplier and regularization correction
658+ at::acc_type<grad_t, true> adjusted_multiplier = 0.0;
659+ at::acc_type<grad_t, true> exp_reg_correction = 0.0;
660+
661+ if (regularization_mode == 3) {
662+ // Counter-based regularization (regularization_mode=3)
663+ adjusted_multiplier = multiplier;
664+ if (learning_rate_mode >= 0 && adjustment_enabled) {
665+ if (row_counter_host[offset_idx] > tail_id_threshold_val) {
666+ if (learning_rate_mode == 0) {
667+ adjusted_multiplier = multiplier * std::max(std::min(std::pow(max_counter / (row_counter_host[offset_idx] + 1.0), adjustment_ub), 10.0), 1.0);
668+ } else if (learning_rate_mode == 1) {
669+ adjusted_multiplier = multiplier * std::min(std::max(std::pow((row_counter_host[offset_idx] + 1.0) / max_counter, adjustment_ub), 0.1), 1.0);
670+ } else if (learning_rate_mode == 2) {
671+ adjusted_multiplier = learning_rate / (std::sqrt(adjustment_ub * row_counter_host[offset_idx]) + eps);
672+ }
673+ }
674+ }
675+ } else if (regularization_mode == 4) {
676+ // COW-clip (regularization_mode=4)
677+ const auto clip_thresh = row_counter_host[offset_idx] * std::max(weight_norm_coefficient * std::sqrt(w_sum_square), lower_bound);
678+ adjusted_multiplier = std::min(1.0f, static_cast<float>(clip_thresh / std::sqrt(g_sum_square))) * multiplier;
679+ } else {
680+ // Default: no special regularization
681+ adjusted_multiplier = multiplier;
682+ }
683+
684+ // Compute regularization correction
685+ exp_reg_correction = 1.0;
686+ if (regularization_mode == 3) {
687+ // Counter-based regularization (regularization_mode=3)
688+ if (adjustment_enabled) {
689+ if (weight_decay_mode == 3) {
690+ // AdagradW (weight_decay_mode=3)
691+ if (counter_halflife == -1) {
692+ adjusted_multiplier = multiplier * std::sqrt(row_counter_host[offset_idx] * 1.0);
693+ } else if (counter_halflife == -2) {
694+ adjusted_multiplier = std::min(static_cast<float>(learning_rate * std::pow(row_counter_host[offset_idx] * 1.0, 1.0)), adjustment_ub) / (std::sqrt(new_sum_square_grads) + eps);
695+ }
696+ exp_reg_correction = 1.0 - weight_decay * learning_rate;
697+ const auto lazy_delta = prev_iter_host[offset_idx] == 0 ? 1.0 : iter * 1.0 - prev_iter_host[offset_idx];
698+ const auto lazy_multiplier = std::pow(exp_reg_correction, std::min(lazy_delta, iter * 1.0 - adjustment_iter) - 1.0);
699+ adjusted_multiplier *= lazy_multiplier;
700+ exp_reg_correction *= lazy_multiplier;
701+ } else if (weight_decay_mode == 2) {
702+ // Decoupled weight decay (weight_decay_mode=2)
703+ exp_reg_correction = 1.0 - freq * weight_decay * learning_rate;
704+ } else if (weight_decay_mode == 1) {
705+ // L2 regularization (coupled wd)
706+ exp_reg_correction = 1.0 - freq * weight_decay * multiplier;
707+ }
708+ }
709+ } else if (regularization_mode == 4) {
710+ // COW-clip (regularization_mode=4)
711+ if (weight_decay_mode == 2) {
712+ // Decoupled weight decay (weight_decay_mode=2)
713+ exp_reg_correction = 1.0 - weight_decay * learning_rate;
714+ } else if (weight_decay_mode == 1) {
715+ // L2 regularization (coupled wd)
716+ exp_reg_correction = 1.0 - weight_decay * adjusted_multiplier;
717+ }
718+ } else {
719+ // Default regularization
720+ exp_reg_correction = 1.0;
721+ }
722+
723+ // Update prev_iter
614724 prev_iter_host[offset_idx] = iter * 1.0;
615- const auto exp_reg = 1.0 / (weight_decay * multiplier + 1.0);
616- const auto exp_reg_correction = powf(exp_reg, iter_delta);
725+
726+ // Apply weight updates
617727 for (int64_t d = 0; d < D; ++d) {
618- const auto weight = host_weights_data[embedding_begin + d];
619- host_weights_data[embedding_begin + d] = exp_reg_correction * weight - exp_reg * multiplier * grad_buffer[d];
728+ host_weights_data[embedding_begin + d] = exp_reg_correction * host_weights_data[embedding_begin + d] - adjusted_multiplier * grad_buffer[d];
620729 }
621730 """
622731
0 commit comments