Skip to content

Commit ddb9575

Browse files
authored
Fix moe bug. (kvcache-ai#1783)
* [fix]: fix moe.hpp load from file bug. * [fix]: fix all moe hpp init bug. * [fix]: fix moe & awq-moe bug.
1 parent dc6394e commit ddb9575

File tree

2 files changed

+4
-14
lines changed

2 files changed

+4
-14
lines changed

kt-kernel/operators/amx/awq-moe.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@ class AMX_AWQ_MOE_TP : public AMX_MOE_BASE<T, AMX_AWQ_MOE_TP<T>> {
4242
using Base::up_bb_;
4343
using Base::up_bc_;
4444

45-
std::filesystem::path prefix;
46-
4745
#ifdef CHECK
4846
char verify_bb[100000000];
4947
char check_bb[100000000];
@@ -406,7 +404,7 @@ class AMX_AWQ_MOE_TP : public AMX_MOE_BASE<T, AMX_AWQ_MOE_TP<T>> {
406404
auto& load = config_.load;
407405
auto& save = config_.save;
408406

409-
prefix = config_.path;
407+
std::filesystem::path prefix = config_.path;
410408
prefix = prefix / ("_layer_" + std::to_string(config_.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx));
411409
if (save) {
412410
std::cout << "Creating " << prefix << std::endl;
@@ -498,6 +496,9 @@ class AMX_AWQ_MOE_TP : public AMX_MOE_BASE<T, AMX_AWQ_MOE_TP<T>> {
498496
throw std::runtime_error("AMX load weights from gate_projs is not supported");
499497
} else {
500498
int nth = T::recommended_nth(config_.intermediate_size);
499+
std::filesystem::path prefix = config_.path;
500+
prefix = prefix / ("_layer_" + std::to_string(config_.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx));
501+
501502
if (config_.load) {
502503
throw std::runtime_error("AMX load weights from file is not supported");
503504
}

kt-kernel/operators/amx/moe.hpp

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,6 @@ class AMX_MOE_TP : public AMX_MOE_BASE<T, AMX_MOE_TP<T>> {
3232
using Base::up_bb_;
3333
using Base::up_bc_;
3434

35-
void* gate_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if
36-
// quantized)]
37-
void* up_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if
38-
// quantized)]
39-
void* down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if
40-
// quantized)]
41-
4235
#ifdef CHECK
4336
char verify_bb[100000000];
4437
char check_bb[100000000];
@@ -159,10 +152,6 @@ class AMX_MOE_TP : public AMX_MOE_BASE<T, AMX_MOE_TP<T>> {
159152
throw std::runtime_error("Path not found: " + prefix.string());
160153
}
161154
}
162-
163-
gate_proj_ = config_.gate_proj;
164-
up_proj_ = config_.up_proj;
165-
down_proj_ = config_.down_proj;
166155
}
167156

168157
~AMX_MOE_TP() = default;

0 commit comments

Comments
 (0)