@@ -21,18 +21,16 @@ class AMX_MOE_TP : public AMX_MOE_BASE<T, AMX_MOE_TP<T>> {
2121 private:
2222 using Base = AMX_MOE_BASE<T, AMX_MOE_TP<T>>;
2323 using Base::config_;
24- using Base::tp_part_idx;
25- using Base::gate_bb_;
26- using Base::up_bb_;
27- using Base::down_bb_;
28- using Base::gate_up_ba_;
29- using Base::gate_bc_;
30- using Base::up_bc_;
3124 using Base::down_ba_;
25+ using Base::down_bb_;
3226 using Base::down_bc_;
27+ using Base::gate_bb_;
28+ using Base::gate_bc_;
29+ using Base::gate_up_ba_;
3330 using Base::m_local_num_;
34-
35- std::filesystem::path prefix;
31+ using Base::tp_part_idx;
32+ using Base::up_bb_;
33+ using Base::up_bc_;
3634
3735 void * gate_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if
3836 // quantized)]
@@ -140,11 +138,15 @@ class AMX_MOE_TP : public AMX_MOE_BASE<T, AMX_MOE_TP<T>> {
140138 AMX_MOE_TP () = default ;
141139
140138 AMX_MOE_TP (GeneralMOEConfig config, int tp_part_idx = 0 ) : Base(config, tp_part_idx) {
141+ // Intentionally empty: config and tp_part_idx are forwarded to Base, and per-instance setup lives in derived_init().
// Per the original note, derived_init() is called by the base constructor (that call site is not visible in this section).
// NOTE(review): a base-class constructor runs before this class's own data members are initialized, so any
// default member initializer on a field that derived_init() writes would run afterwards and overwrite it -
// confirm the members touched by derived_init() have no such initializers.
142+ }
143+
144+ void derived_init () {
143145 printf (" Creating AMX_MOE_TP %d at numa %d\n " , tp_part_idx, numa_node_of_cpu (sched_getcpu ()));
144146 auto & load = config_.load ;
145147 auto & save = config_.save ;
146148
147- prefix = config_.path ;
149+ std::filesystem::path prefix = config_.path ;
148150 prefix = prefix / (" _layer_" + std::to_string (config_.layer_idx )) / (" _numa_" + std::to_string (tp_part_idx));
149151 if (save) {
150152 std::cout << " Creating " << prefix << std::endl;
@@ -169,15 +171,9 @@ class AMX_MOE_TP : public AMX_MOE_BASE<T, AMX_MOE_TP<T>> {
169171 // CRTP buffer creation - no group_size
170172 // ============================================================================
171173
172- size_t buffer_a_required_size_impl (size_t m, size_t k) const {
173- return T::BufferA::required_size (m, k);
174- }
175- size_t buffer_b_required_size_impl (size_t n, size_t k) const {
176- return T::BufferB::required_size (n, k);
177- }
178- size_t buffer_c_required_size_impl (size_t m, size_t n) const {
179- return T::BufferC::required_size (m, n);
180- }
// Size queries for the A/B/C buffers: each hook forwards its two dimension arguments unchanged to the
// matching T::BufferA / T::BufferB / T::BufferC ::required_size and returns that result as-is.
// No group_size argument is passed in this variant (see the section header above).
174+ size_t buffer_a_required_size_impl (size_t m, size_t k) const { return T::BufferA::required_size (m, k); }
175+ size_t buffer_b_required_size_impl (size_t n, size_t k) const { return T::BufferB::required_size (n, k); }
176+ size_t buffer_c_required_size_impl (size_t m, size_t n) const { return T::BufferC::required_size (m, n); }
181177
182178 std::shared_ptr<typename T::BufferA> make_buffer_a_impl (size_t m, size_t k, void * data) const {
183179 return std::make_shared<typename T::BufferA>(m, k, data);
@@ -260,6 +256,9 @@ class AMX_MOE_TP : public AMX_MOE_BASE<T, AMX_MOE_TP<T>> {
260256 } else {
261257 int nth = T::recommended_nth (config_.intermediate_size );
262258 static uint8_t mat_type_all = 3 , mat_split = 1 ;
259+ std::filesystem::path prefix = config_.path ;
260+ prefix = prefix / (" _layer_" + std::to_string (config_.layer_idx )) / (" _numa_" + std::to_string (tp_part_idx));
261+
263262 if (config_.load ) {
264263 std::cout << " Loading from " << prefix << std::endl;
265264 for (int task_id = 0 ; task_id < config_.expert_num * mat_type_all * mat_split; task_id++) {
@@ -335,7 +334,7 @@ class AMX_MOE_TP : public AMX_MOE_BASE<T, AMX_MOE_TP<T>> {
335334 if (config_.save ) {
336335 pool->do_work_stealing_job (
337336 config_.expert_num * mat_type_all, nullptr ,
338- [this , physical_to_logical_map](int task_id) {
337+ [this , physical_to_logical_map, prefix ](int task_id) {
339338 int64_t expert_idx = task_id / mat_type_all;
340339 expert_idx = expert_map (physical_to_logical_map, expert_idx);
341340 uint8_t mat_class = task_id % mat_type_all;
@@ -426,7 +425,7 @@ class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_MOE_TP<K>>> {
426425
427426 this ->weights_loaded = true ;
428427 } else if (config.path != " " ) {
429- printf (" TP Load from file\n " );
428+ printf (" TP Load from file %s \n " , config. path . c_str () );
430429 DO_TPS_LOAD_WEIGHTS (pool);
431430 this ->weights_loaded = true ;
432431 } else {
0 commit comments