File tree Expand file tree Collapse file tree 2 files changed +4
-14
lines changed
Expand file tree Collapse file tree 2 files changed +4
-14
lines changed Original file line number Diff line number Diff line change @@ -42,8 +42,6 @@ class AMX_AWQ_MOE_TP : public AMX_MOE_BASE<T, AMX_AWQ_MOE_TP<T>> {
4242 using Base::up_bb_;
4343 using Base::up_bc_;
4444
45- std::filesystem::path prefix;
46-
4745#ifdef CHECK
4846 char verify_bb[100000000 ];
4947 char check_bb[100000000 ];
@@ -406,7 +404,7 @@ class AMX_AWQ_MOE_TP : public AMX_MOE_BASE<T, AMX_AWQ_MOE_TP<T>> {
406404 auto & load = config_.load ;
407405 auto & save = config_.save ;
408406
409- prefix = config_.path ;
407+ std::filesystem::path prefix = config_.path ;
410408 prefix = prefix / (" _layer_" + std::to_string (config_.layer_idx )) / (" _numa_" + std::to_string (tp_part_idx));
411409 if (save) {
412410 std::cout << " Creating " << prefix << std::endl;
@@ -498,6 +496,9 @@ class AMX_AWQ_MOE_TP : public AMX_MOE_BASE<T, AMX_AWQ_MOE_TP<T>> {
498496 throw std::runtime_error (" AMX load weights from gate_projs is not supported" );
499497 } else {
500498 int nth = T::recommended_nth (config_.intermediate_size );
499+ std::filesystem::path prefix = config_.path ;
500+ prefix = prefix / (" _layer_" + std::to_string (config_.layer_idx )) / (" _numa_" + std::to_string (tp_part_idx));
501+
501502 if (config_.load ) {
502503 throw std::runtime_error (" AMX load weights from file is not supported" );
503504 }
Original file line number Diff line number Diff line change @@ -32,13 +32,6 @@ class AMX_MOE_TP : public AMX_MOE_BASE<T, AMX_MOE_TP<T>> {
3232 using Base::up_bb_;
3333 using Base::up_bc_;
3434
35- void * gate_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if
36- // quantized)]
37- void * up_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if
38- // quantized)]
39- void * down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if
40- // quantized)]
41-
4235#ifdef CHECK
4336 char verify_bb[100000000 ];
4437 char check_bb[100000000 ];
@@ -159,10 +152,6 @@ class AMX_MOE_TP : public AMX_MOE_BASE<T, AMX_MOE_TP<T>> {
159152 throw std::runtime_error (" Path not found: " + prefix.string ());
160153 }
161154 }
162-
163- gate_proj_ = config_.gate_proj ;
164- up_proj_ = config_.up_proj ;
165- down_proj_ = config_.down_proj ;
166155 }
167156
168157 ~AMX_MOE_TP () = default ;
You can’t perform that action at this time.
0 commit comments