
Commit 12dde6b

refactor: fix header redefinition and use DEFINE_WEIGHT to define weights. (#291)
Signed-off-by: Tao Peng <[email protected]>
1 parent 0b6c9e3 commit 12dde6b

File tree

9 files changed: +44 -29 lines

xllm/core/common/interruption_bus.h

Lines changed: 17 additions & 0 deletions

@@ -1,3 +1,20 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
 #include <functional>
 #include <vector>
 
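For context on the "header redefinition" half of the commit title: interruption_bus.h previously had neither the license banner nor an include guard, so a translation unit that pulled it in twice (directly or transitively) redefined everything it declares. #pragma once makes repeat inclusion a no-op, matching the guard style every other xLLM header in this diff already uses. A minimal illustration of the failure mode, with a hypothetical header name:

// bus_types.h -- hypothetical stand-in for an unguarded header.
#pragma once  // remove this line and the double include below breaks

struct Interruption {
  int reason;
};

// consumer.cpp
#include "bus_types.h"
#include "bus_types.h"  // second include, typically via another header
// without the guard: error: redefinition of 'struct Interruption'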

xllm/core/common/mspti_helper.h

Lines changed: 2 additions & 0 deletions

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#pragma once
+
 #include <cstdint>
 #include <iostream>
 

xllm/core/distributed_runtime/remote_worker.h

Lines changed: 1 addition & 1 deletion

@@ -142,7 +142,7 @@ class RemoteWorker : public WorkerClient {
   ThreadPool threadpool_;
   // general working thread
   // do some overlap work with model execute
-  ThreadPool general_threadpool_{5};
+  ThreadPool general_threadpool_{4};
   const torch::Device device_;
 };
 }  // namespace xllm

xllm/core/distributed_runtime/worker_service.h

Lines changed: 1 addition & 1 deletion

@@ -149,7 +149,7 @@ class WorkerService : public proto::DistributeWorker {
 
   std::unique_ptr<std::thread> polling_thread_;
 
-  ThreadPool threadpool_{5};
+  ThreadPool threadpool_{4};
 };
 
 }  // namespace xllm

xllm/core/framework/xtensor/xtensor_manager_service.h

Lines changed: 2 additions & 2 deletions

@@ -70,8 +70,8 @@ class XTensorManagerService : public proto::DistributeXTensorManager {
   int32_t global_rank_;
   int32_t world_size_;
   torch::Device device_;
-  ThreadPool threadpool_{5};
+  ThreadPool threadpool_{4};
   std::unique_ptr<XTensorManager> xtensor_manager_;
 };
 
-}  // namespace xllm
\ No newline at end of file
+}  // namespace xllm
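All three services above shrink a fixed-size pool from five worker threads to four. For readers unfamiliar with the member being tuned: ThreadPool threadpool_{4} brace-initializes a pool with a fixed thread count. xLLM's ThreadPool is defined elsewhere in the tree; the sketch below is a generic illustration of the idiom, not the project's implementation (the schedule method name is an assumption):

#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Minimal fixed-size thread pool illustrating the `ThreadPool pool_{4}` idiom.
class ThreadPool {
 public:
  explicit ThreadPool(size_t num_threads) {
    for (size_t i = 0; i < num_threads; ++i) {
      workers_.emplace_back([this] {
        for (;;) {
          std::function<void()> task;
          {
            std::unique_lock<std::mutex> lock(mu_);
            cv_.wait(lock, [this] { return stop_ || !tasks_.empty(); });
            if (stop_ && tasks_.empty()) return;
            task = std::move(tasks_.front());
            tasks_.pop();
          }
          task();  // run outside the lock so other workers can dequeue
        }
      });
    }
  }

  ~ThreadPool() {
    {
      std::lock_guard<std::mutex> lock(mu_);
      stop_ = true;
    }
    cv_.notify_all();
    for (auto& w : workers_) w.join();
  }

  // Enqueue a task for any idle worker to pick up.
  void schedule(std::function<void()> task) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      tasks_.push(std::move(task));
    }
    cv_.notify_one();
  }

 private:
  std::vector<std::thread> workers_;
  std::queue<std::function<void()>> tasks_;
  std::mutex mu_;
  std::condition_variable cv_;
  bool stop_ = false;
};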

xllm/core/layers/common/fuse_norm.h

Lines changed: 2 additions & 1 deletion

@@ -18,6 +18,7 @@ limitations under the License.
 #include <torch/torch.h>
 
 #include "framework/state_dict/state_dict.h"
+#include "framework/state_dict/utils.h"
 
 namespace xllm {
 namespace layer {
@@ -33,7 +34,7 @@ class FusedRMSNormImpl : public torch::nn::Module {
   void load_state_dict(const StateDict& state_dict);
 
  private:
-  torch::Tensor weight_;
+  DEFINE_WEIGHT(weight);
   int64_t norm_dim_;
   double eps_;
 };
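This commit does not show DEFINE_WEIGHT itself (it lives in the newly included framework/state_dict/utils.h), but the renames elsewhere in the diff pin down its shape: DEFINE_WEIGHT(weight) must yield a weight_ tensor and a weight_is_loaded_ flag. A plausible minimal expansion, offered as a reconstruction rather than the project's actual definition:

// Hypothetical reconstruction of DEFINE_WEIGHT, inferred from the members the
// rewritten call sites use (weight_, weight_is_loaded_, in_proj_weight_, ...).
#define DEFINE_WEIGHT(name)        \
  torch::Tensor name##_{nullptr};  \
  bool name##_is_loaded_ = false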

xllm/core/layers/common/word_embedding_impl.h

Lines changed: 5 additions & 6 deletions

@@ -72,13 +72,14 @@ class WordEmbeddingImpl : public torch::nn::Module {
       CHECK_EQ(weight_.sizes(), weight.sizes())
           << "weight size mismatch for " << name();
       weight_.copy_(weight);
-      is_loaded_ = true;
+      weight_is_loaded_ = true;
     }
   }
 
   // whether the weight is loaded
   void verify_loaded_weights(const std::string& prefix) const {
-    CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight";
+    CHECK(weight_is_loaded_)
+        << "weight is not loaded for " << prefix + "weight";
   }
 
   void pretty_print(std::ostream& stream) const override {
@@ -94,11 +95,9 @@ class WordEmbeddingImpl : public torch::nn::Module {
 
   // world size
   PROPERTY(int32_t, world_size) = 0;
-  // parameter members, must be registered
-  torch::Tensor weight_{nullptr};
 
-  // whether the weight is loaded
-  bool is_loaded_ = false;
+  // parameter members, must be registered
+  DEFINE_WEIGHT(weight);
 
   // parallel args
   ParallelArgs parallel_args_;
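Putting the pieces together, a module declared with the macro keeps exactly the load/verify pattern this hunk rewrites. A condensed sketch assuming the DEFINE_WEIGHT reconstruction above (EmbeddingLike and load are hypothetical names, not xLLM's API):

#include <glog/logging.h>
#include <torch/torch.h>

class EmbeddingLike : public torch::nn::Module {
 public:
  EmbeddingLike(int64_t num_embeddings, int64_t dim) {
    weight_ = register_parameter(
        "weight", torch::empty({num_embeddings, dim}), /*requires_grad=*/false);
  }

  void load(const torch::Tensor& weight) {
    weight_.copy_(weight);
    weight_is_loaded_ = true;  // flag generated by DEFINE_WEIGHT
  }

  void verify_loaded_weights(const std::string& prefix) const {
    CHECK(weight_is_loaded_)
        << "weight is not loaded for " << prefix + "weight";
  }

 private:
  DEFINE_WEIGHT(weight);  // expands to weight_ and weight_is_loaded_
};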

xllm/core/layers/multi_head_attention.cpp

Lines changed: 8 additions & 8 deletions

@@ -71,37 +71,37 @@ void MultiheadAttentionImpl::load_state_dict(const StateDict& state_dict) {
   const auto in_proj_weight = state_dict.get_tensor("in_proj_weight");
   if (in_proj_weight.defined()) {
     in_proj_weight_ = in_proj_weight.to(options_);
-    is_in_proj_weight_loaded_ = true;
+    in_proj_weight_is_loaded_ = true;
   }
 
   const auto in_proj_bias = state_dict.get_tensor("in_proj_bias");
   if (in_proj_bias.defined()) {
     in_proj_bias_ = in_proj_bias.to(options_);
-    is_in_proj_bias_loaded_ = true;
+    in_proj_bias_is_loaded_ = true;
   }
 
   const auto out_proj_weight = state_dict.get_tensor("out_proj.weight");
   if (out_proj_weight.defined()) {
     out_proj_weight_ = out_proj_weight.to(options_);
-    is_out_proj_weight_loaded_ = true;
+    out_proj_weight_is_loaded_ = true;
   }
 
   const auto out_proj_bias = state_dict.get_tensor("out_proj.bias");
   if (out_proj_bias.defined()) {
     out_proj_bias_ = out_proj_bias.to(options_);
-    is_out_proj_bias_loaded_ = true;
+    out_proj_bias_is_loaded_ = true;
   }
 }
 
 void MultiheadAttentionImpl::verify_loaded_weights(
     const std::string& prefix) const {
-  CHECK(is_in_proj_weight_loaded_)
+  CHECK(in_proj_weight_is_loaded_)
       << "in_proj_weight is not loaded for " << prefix + "in_proj_weight";
-  CHECK(is_in_proj_bias_loaded_)
+  CHECK(in_proj_bias_is_loaded_)
       << "in_proj_bias is not loaded for " << prefix + "in_proj_bias";
-  CHECK(is_out_proj_weight_loaded_)
+  CHECK(out_proj_weight_is_loaded_)
      << "out_proj.weight is not loaded for " << prefix + "out_proj.weight";
-  CHECK(is_out_proj_bias_loaded_)
+  CHECK(out_proj_bias_is_loaded_)
      << "out_proj.bias is not loaded for " << prefix + "out_proj.bias";
 }
 
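Note the direction of every rename in this file: is_X_loaded_ becomes X_is_loaded_. Assuming a token-pasting expansion like the sketch above, the macro can only append tokens to the weight's name, so DEFINE_WEIGHT(X) generates X_is_loaded_, and each call site in the .cpp must be renamed to match.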

xllm/core/layers/multi_head_attention.h

Lines changed: 6 additions & 10 deletions

@@ -19,6 +19,7 @@ limitations under the License.
 
 #include "framework/model_context.h"
 #include "framework/state_dict/state_dict.h"
+#include "framework/state_dict/utils.h"
 
 namespace xllm {
 namespace layer {
@@ -42,18 +43,13 @@ class MultiheadAttentionImpl : public torch::nn::Module {
   int64_t hidden_size_;
   torch::TensorOptions options_;
 
-  torch::Tensor in_proj_weight_;
-  torch::Tensor in_proj_bias_;
-  torch::Tensor out_proj_weight_;
-  torch::Tensor out_proj_bias_;
-
-  bool is_in_proj_weight_loaded_;
-  bool is_in_proj_bias_loaded_;
-  bool is_out_proj_weight_loaded_;
-  bool is_out_proj_bias_loaded_;
+  DEFINE_WEIGHT(in_proj_weight);
+  DEFINE_WEIGHT(in_proj_bias);
+  DEFINE_WEIGHT(out_proj_weight);
+  DEFINE_WEIGHT(out_proj_bias);
 };
 
 TORCH_MODULE(MultiheadAttention);
 
 }  // namespace layer
-}  // namespace xllm
\ No newline at end of file
+}  // namespace xllm
