Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 223 additions & 2 deletions kt-kernel/operators/amx/k2-moe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -518,8 +518,229 @@ class AMX_K2_MOE_TP {

void forward_prefill(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input,
void* output) {
for (int i = 0; i < qlen; i ++)
forward_decode(k, expert_ids + i * k, weights + i * k, (ggml_bf16_t*)input + i * config_.hidden_size, (float*)output + i * config_.hidden_size);
auto pool = config_.pool->get_subpool(tp_part_idx);
auto& quant_config = config_.quant_config;
int& group_size = quant_config.group_size;
#ifdef FORWARD_TIME_PROFILE
auto start_time = std::chrono::high_resolution_clock::now();
auto last = start_time;
// 用于保存各阶段耗时(单位:微秒)
long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0;
long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0;
int max_local_num = 0; // 记录最大的 local num
#endif

int activated_expert = 0;
for (int i = 0; i < config_.expert_num; i++) {
m_local_num_[i] = 0;
}
for (int i = 0; i < qlen; i++) {
for (int j = 0; j < k; j++) {
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
continue;
}
m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;
}
}

for (int i = 0; i < config_.expert_num; i++) {
if (m_local_num_[i] > 0) {
#ifdef FORWARD_TIME_PROFILE
max_local_num = std::max(max_local_num, m_local_num_[i]);
#endif
m_expert_id_map_[activated_expert] = i;
activated_expert++;
}
}

// activated_expert 已经统计完成

size_t offset = 0;
for (int i = 0; i < config_.expert_num; i++) {
m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;
m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;
m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;
m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
offset += m_local_num_[i];
}
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
prepare_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif

DIRECT_OR_POOL_BY_QLEN(qlen, [&](int i) {
for (int j = 0; j < k; j++) {
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
continue;
}
memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,
(ggml_bf16_t*)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);
}
});

#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
cpy_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif

DIRECT_OR_POOL_BY_QLEN(activated_expert, [this](int task_id) {
int expert_idx = m_expert_id_map_[task_id];
gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1);
});

#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
q_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif

int nth = T::recommended_nth(config_.intermediate_size);
pool->do_work_stealing_job(
nth * activated_expert * 2, [](int _) { T::config(); },
[this, nth, qlen](int task_id2) {
int& group_size = config_.quant_config.group_size;
int task_id = task_id2 / 2;
bool do_up = task_id2 % 2;
int expert_idx = m_expert_id_map_[task_id / nth];

int ith = task_id % nth;
if (do_up) {
MATMUL_OR_VECMUL_KGROUP_BY_QLEN(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
group_size, gate_up_ba_[expert_idx], up_bb_[expert_idx], up_bc_[expert_idx],
ith, nth);
up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);
} else {
MATMUL_OR_VECMUL_KGROUP_BY_QLEN(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
group_size, gate_up_ba_[expert_idx], gate_bb_[expert_idx],
gate_bc_[expert_idx], ith, nth);
gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);
}
},
nullptr);

#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
up_gate_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif

auto up_gate_fn = [this, nth](int task_id) {
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
for (int i = 0; i < m_local_num_[expert_idx]; i++) {
ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
for (int j = n_start; j < n_end; j += 32) {
__m512 gate_val0, gate_val1, up_val0, up_val1;
avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1);
avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1);
__m512 result0 = amx::act_fn(gate_val0, up_val0);
__m512 result1 = amx::act_fn(gate_val1, up_val1);
avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j));
}
}
};
DIRECT_OR_POOL_BY_QLEN(nth * activated_expert, up_gate_fn);

#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
act_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif

pool->do_work_stealing_job(
activated_expert, nullptr,
[this](int task_id) {
int expert_idx = m_expert_id_map_[task_id];
down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1);
},
nullptr);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
q_down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif

nth = T::recommended_nth(config_.hidden_size);
pool->do_work_stealing_job(
nth * activated_expert, [](int _) { T::config(); },
[this, nth, qlen](int task_id) {
int& group_size = config_.quant_config.group_size;
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
MATMUL_OR_VECMUL_KGROUP_BY_QLEN(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size,
group_size, down_ba_[expert_idx], down_bb_[expert_idx], down_bc_[expert_idx],
ith, nth);
down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth);
},
nullptr);

#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif

pool->do_work_stealing_job(
qlen, nullptr,
[this, nth, output, k, expert_ids, weights](int i) {
for (int e = 0; e < config_.hidden_size; e += 32) {
__m512 x0 = _mm512_setzero_ps();
__m512 x1 = _mm512_setzero_ps();
for (int j = 0; j < k; j++) {
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
continue;
}
__m512 weight = _mm512_set1_ps(weights[i * k + j]);
__m512 down_output0, down_output1;
avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[i * k + j]] +
m_local_pos_[i][j] * config_.hidden_size + e),
&down_output0, &down_output1);
x0 = _mm512_fmadd_ps(down_output0, weight, x0);
x1 = _mm512_fmadd_ps(down_output1, weight, x1);
}
auto f32out = (__m512*)((float*)output + i * config_.hidden_size + e);
f32out[0] = x0;
f32out[1] = x1;
Comment on lines +715 to +717
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The direct cast to (__m512*) and subsequent write operations assume that the output buffer is 64-byte aligned. If a non-aligned buffer is provided by the caller, this will result in a segmentation fault. The warm_up function in this file, for instance, uses a std::vector for the output buffer, which does not guarantee alignment, highlighting a scenario where this could fail. To enhance robustness, it is recommended to use unaligned store intrinsics.

            auto f32out = (float*)output + i * config_.hidden_size + e;
            _mm512_storeu_ps(f32out, x0);
            _mm512_storeu_ps(f32out + 16, x1);

}
},
nullptr);

#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
weight_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
auto end_time = std::chrono::high_resolution_clock::now();
auto forward_total_time = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
// 在函数末尾一次性打印所有阶段的耗时,并附带 max_local_num 和 qlen
printf(
"Profiling Results (numa[%d]): activated_expert: %d, prepare: %ld us, cpy_input: %ld us, q_input: %ld us, "
"up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us, max_local_num: "
"%d, qlen: %d\n",
tp_part_idx, activated_expert, prepare_time, cpy_input_time, q_input_time, up_gate_time, act_time, q_down_time,
down_time, weight_time, forward_total_time, max_local_num, qlen);
#endif
// for (int i = 0; i < qlen; i ++)
// forward_decode(k, expert_ids + i * k, weights + i * k, (ggml_bf16_t*)input + i * config_.hidden_size, (float*)output + i * config_.hidden_size);
}

void forward_decode(int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
Expand Down
4 changes: 2 additions & 2 deletions kt-kernel/operators/amx/la/amx_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2905,7 +2905,7 @@ struct GemmKernel224Int4SmallKGroup {
static inline void integer_mat_vec_kgroup(int m, int n, int k, int k_group_size, BufferA* ba, BufferB *bb, BufferC* bc, int ith, int nth) {
auto [n_start, n_end] = split_range_n(n, ith, nth);
for (int m_begin = 0; m_begin < m; m_begin ++) {
float* c = bc->get_submat(m, n, m_begin, 0);
float* c = bc->get_submat(m, n, m_begin, n_start);
__m512i* a512 = (__m512i*)ba->get_submat(m, k, m_begin, 0);

for (int n_block_begin = n_start; n_block_begin < n_end; n_block_begin ++) {
Expand All @@ -2929,7 +2929,7 @@ struct GemmKernel224Int4SmallKGroup {
WORK_K_BLOCK(k_block + 1);
}

c[n_block_begin] = _mm512_reduce_add_ps(sum) / 16;
c[n_block_begin - n_start] = _mm512_reduce_add_ps(sum) / 16;
}
}
}
Expand Down