129 changes: 119 additions & 10 deletions fbgemm_gpu/codegen/genscript/optimizers.py
@@ -600,23 +600,132 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
exp_reg_correction = SHFL_SYNC(exp_reg_correction, 0);
"""
split_weight_update_cpu = """
auto offset_idx = momentum1_offsets_data[feature_begin] + idx;

// Counter update logic with halflife decay
at::acc_type<grad_t, true> freq = 1.0;
at::acc_type<grad_t, true> tail_id_threshold_val = tail_id_threshold;
if (max_counter != 0.0) {
if (is_tail_id_thresh_ratio == 1) {
tail_id_threshold_val = std::floor(tail_id_threshold * max_counter);
}

if (counter_halflife > 0) {
// Decay based on counter_halflife
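// Each elapsed iteration scales the previous count by 2^(-1/counter_halflife),
// so a row's accumulated count halves for every counter_halflife iterations it goes unseen.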
const auto iter_delta = prev_iter_host[offset_idx] == 0 ? 1.0 : iter * 1.0 - prev_iter_host[offset_idx];
const auto counter_log_rho = std::log(2.0) / counter_halflife;
row_counter_host[offset_idx] = 1.0 + std::exp(-iter_delta * counter_log_rho) * row_counter_host[offset_idx];
} else if (counter_halflife == 0) {
// Count only 1 (appear or not)
row_counter_host[offset_idx] = 1.0;
} else {
// Count raw appearance without decaying
row_counter_host[offset_idx] += 1.0;
}
}
freq = counter_halflife / row_counter_host[offset_idx];
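// freq rescales weight decay below: it is large for rows with a small decayed count
// (rarely seen rows) and shrinks toward zero for frequently updated rows.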

// Compute gradient statistics
at::acc_type<grad_t, true> g_local_sum_square = 0.0;
at::acc_type<grad_t, true> w_local_sum_square = 0.0;

for (int64_t d = 0; d < D; ++d) {
auto grad = grad_buffer[d];
// For L2 regularization (weight_decay_mode=1), add weight_decay to gradient before other computation
if (weight_decay_mode == 1) {
grad += weight_decay * host_weights_data[embedding_begin + d];
}
g_local_sum_square += grad * grad;

// COW-clip (regularization_mode=4) requires weight norm
if (regularization_mode == 4) {
const auto weight = host_weights_data[embedding_begin + d];
w_local_sum_square += weight * weight;
}
}

const auto g_sum_square = g_local_sum_square;
const auto g_avg_square = g_sum_square / D;
const auto w_sum_square = w_local_sum_square;

// Update momentum
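// Rowwise Adagrad keeps one accumulator per embedding row, fed with the mean
// squared gradient across the D elements of the row.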
at::acc_type<grad_t, true> new_sum_square_grads = momentum1_host[offset_idx] + g_avg_square;
momentum1_host[offset_idx] = new_sum_square_grads;
const auto multiplier = learning_rate / (std::sqrt(new_sum_square_grads) + eps);
const auto adjustment_enabled = adjustment_iter <= 0 || (adjustment_iter > 0 && iter > adjustment_iter);
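// Counter-based adjustments only apply once iter has passed adjustment_iter
// (or from the start, when adjustment_iter <= 0).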

// Compute adjusted multiplier and regularization correction
at::acc_type<grad_t, true> adjusted_multiplier = 0.0;
at::acc_type<grad_t, true> exp_reg_correction = 0.0;

if (regularization_mode == 3) {
// Counter-based regularization (regularization_mode=3)
adjusted_multiplier = multiplier;
if (learning_rate_mode >= 0 && adjustment_enabled) {
if (row_counter_host[offset_idx] > tail_id_threshold_val) {
if (learning_rate_mode == 0) {
adjusted_multiplier = multiplier * std::max(std::min(std::pow(max_counter / (row_counter_host[offset_idx] + 1.0), adjustment_ub), 10.0), 1.0);
} else if (learning_rate_mode == 1) {
adjusted_multiplier = multiplier * std::min(std::max(std::pow((row_counter_host[offset_idx] + 1.0) / max_counter, adjustment_ub), 0.1), 1.0);
} else if (learning_rate_mode == 2) {
adjusted_multiplier = learning_rate / (std::sqrt(adjustment_ub * row_counter_host[offset_idx]) + eps);
}
}
}
} else if (regularization_mode == 4) {
// COW-clip (regularization_mode=4)
const auto clip_thresh = row_counter_host[offset_idx] * std::max(weight_norm_coefficient * std::sqrt(w_sum_square), lower_bound);
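// Shrink the step so the effective gradient norm is capped at clip_thresh: the per-row
// count times the larger of the weight-norm-based bound and lower_bound.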
adjusted_multiplier = std::min(1.0f, static_cast<float>(clip_thresh / std::sqrt(g_sum_square))) * multiplier;
} else {
// Default: no special regularization
adjusted_multiplier = multiplier;
}

// Compute regularization correction
exp_reg_correction = 1.0;
if (regularization_mode == 3) {
// Counter-based regularization (regularization_mode=3)
if (adjustment_enabled) {
if (weight_decay_mode == 3) {
// AdagradW (weight_decay_mode=3)
if (counter_halflife == -1) {
adjusted_multiplier = multiplier * std::sqrt(row_counter_host[offset_idx] * 1.0);
} else if (counter_halflife == -2) {
adjusted_multiplier = std::min(static_cast<float>(learning_rate * std::pow(row_counter_host[offset_idx] * 1.0, 1.0)), adjustment_ub) / (std::sqrt(new_sum_square_grads) + eps);
}
exp_reg_correction = 1.0 - weight_decay * learning_rate;
const auto lazy_delta = prev_iter_host[offset_idx] == 0 ? 1.0 : iter * 1.0 - prev_iter_host[offset_idx];
const auto lazy_multiplier = std::pow(exp_reg_correction, std::min(lazy_delta, iter * 1.0 - adjustment_iter) - 1.0);
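// Lazily apply the weight-decay factors skipped while this row went untouched,
// capped at the number of iterations since adjustment started.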
adjusted_multiplier *= lazy_multiplier;
exp_reg_correction *= lazy_multiplier;
} else if (weight_decay_mode == 2) {
// Decoupled weight decay (weight_decay_mode=2)
exp_reg_correction = 1.0 - freq * weight_decay * learning_rate;
} else if (weight_decay_mode == 1) {
// L2 regularization (coupled wd)
exp_reg_correction = 1.0 - freq * weight_decay * multiplier;
}
}
} else if (regularization_mode == 4) {
// COW-clip (regularization_mode=4)
if (weight_decay_mode == 2) {
// Decoupled weight decay (weight_decay_mode=2)
exp_reg_correction = 1.0 - weight_decay * learning_rate;
} else if (weight_decay_mode == 1) {
// L2 regularization (coupled wd)
exp_reg_correction = 1.0 - weight_decay * adjusted_multiplier;
}
} else {
// Default regularization
exp_reg_correction = 1.0;
}

// Update prev_iter
prev_iter_host[offset_idx] = iter * 1.0;

// Apply weight updates
for (int64_t d = 0; d < D; ++d) {
host_weights_data[embedding_begin + d] = exp_reg_correction * host_weights_data[embedding_begin + d] - adjusted_multiplier * grad_buffer[d];
}
"""
