
Commit 4d7a998

feat: refactor layer module to support multiple platforms.
1 parent 808942b commit 4d7a998


52 files changed: +1465 -1133 lines changed

xllm/core/layers/CMakeLists.txt

Lines changed: 33 additions & 0 deletions
@@ -19,14 +19,47 @@ cc_library(
     torch
 )
 
+cc_library(
+  NAME
+    base_layer
+  HDRS
+    attention_mask.h
+    base_layer.h
+  SRCS
+    attention_mask.cpp
+    base_layer.cpp
+  DEPS
+    :state_dict
+    :block
+    :kv_cache
+    :prefix_cache
+    glog::glog
+    gflags::gflags
+    torch
+)
+
 cc_library(
   NAME
     layers
+  HDRS
+    column_parallel_linear.h
+    deepseek_v2_decoder_layer.h
+    llama_decoder_layer.h
+    multi_head_attention.h
+    qwen2_decoder_layer.h
+    qwen2dot5_vision_decode_layer.h
+    qwen3_decoder_layer.h
+    qwen3_moe_decoder_layer.h
+    rms_norm.h
+    siglip_encoder_layer.h
+  SRCS
+    multi_head_attention.cpp
   DEPS
     :state_dict
     :kv_cache
     :prefix_cache
     :block
+    :base_layer
     :rotary_embedding
     glog::glog
     gflags::gflags

xllm/core/layers/npu/attn_mask.cpp renamed to xllm/core/layers/attention_mask.cpp

Lines changed: 19 additions & 20 deletions
@@ -1,10 +1,10 @@
-#include "attn_mask.h"
+#include "attention_mask.h"
 
-namespace xllm::hf {
+namespace xllm::layer {
 
-AttentionMaskImpl::AttentionMaskImpl(at::Device device,
-                                     torch::Dtype dtype,
-                                     float mask_value) {
+AttentionMask::AttentionMask(at::Device device,
+                             torch::Dtype dtype,
+                             float mask_value) {
   int max_seq_len = 128;
   seq_len_cached_ = max_seq_len;
   auto bias_cache =
@@ -21,25 +21,24 @@ AttentionMaskImpl::AttentionMaskImpl(at::Device device,
       .to(device);
 }
 
-torch::Tensor AttentionMaskImpl::get_decode_attn_mask(
-    torch::Tensor input_lengths,
-    int64_t max_s,
-    torch::Dtype dtype,
-    torch::Device device) {
+torch::Tensor AttentionMask::get_decode_attn_mask(torch::Tensor input_lengths,
+                                                  int64_t max_s,
+                                                  torch::Dtype dtype,
+                                                  torch::Device device) {
   update_attn_cache(dtype, device, max_s);
   return atten_mask_cache_.index_select(0, input_lengths).view({-1, 1, max_s});
 }
 
-torch::Tensor AttentionMaskImpl::get_attn_mask(int64_t max_s,
-                                               torch::Dtype dtype,
-                                               torch::Device device) {
+torch::Tensor AttentionMask::get_attn_mask(int64_t max_s,
+                                           torch::Dtype dtype,
+                                           torch::Device device) {
   update_attn_cache(dtype, device, max_s);
   return atten_mask_cache_.slice(0, 0, max_s).slice(1, 0, max_s);
 }
 
-torch::Tensor AttentionMaskImpl::gen_free_mask(int32_t q_len,
-                                               torch::Dtype dtype,
-                                               torch::Device device) {
+torch::Tensor AttentionMask::gen_free_mask(int32_t q_len,
+                                           torch::Dtype dtype,
+                                           torch::Device device) {
   float pre_mask_factor = -10000.0f;
   if (dtype == torch::kBFloat16) {
     pre_mask_factor = 1.0f;
@@ -52,9 +51,9 @@ torch::Tensor AttentionMaskImpl::gen_free_mask(int32_t q_len,
   return mask_free;
 }
 
-void AttentionMaskImpl::update_attn_cache(torch::Dtype dtype,
-                                          torch::Device device,
-                                          int64_t seqlen) {
+void AttentionMask::update_attn_cache(torch::Dtype dtype,
+                                      torch::Device device,
+                                      int64_t seqlen) {
   if (seqlen > seq_len_cached_ || atten_mask_cache_.dtype() != dtype) {
     seq_len_cached_ = seqlen;
 
@@ -69,4 +68,4 @@ void AttentionMaskImpl::update_attn_cache(torch::Dtype dtype,
   }
 }
 
-}  // namespace xllm::hf
+}  // namespace xllm::layer
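
Note: the rename from AttentionMaskImpl (old xllm::hf namespace) to AttentionMask in xllm::layer keeps the same mask-building API shown above. A minimal usage sketch based on those signatures; the device, dtype, and length values below are illustrative only and not part of this commit:

#include <torch/torch.h>

#include "attention_mask.h"

// Illustrative caller: in real code the device, dtype, and sequence lengths
// come from the model context rather than being hard-coded.
void build_masks_example() {
  at::Device device(torch::kCPU);        // placeholder device for this sketch
  torch::Dtype dtype = torch::kFloat32;

  // Construct once; the cached mask grows lazily via update_attn_cache().
  xllm::layer::AttentionMask mask(device, dtype);

  // Square causal mask for a prefill step of up to 256 tokens.
  torch::Tensor prefill_mask = mask.get_attn_mask(/*max_s=*/256, dtype, device);

  // Per-request decode mask: one row per sequence, viewed as [-1, 1, max_s].
  torch::Tensor input_lengths = torch::tensor({17, 42, 96}, torch::kLong);
  torch::Tensor decode_mask =
      mask.get_decode_attn_mask(input_lengths, /*max_s=*/128, dtype, device);
}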

xllm/core/layers/npu/attn_mask.h renamed to xllm/core/layers/attention_mask.h

Lines changed: 7 additions & 9 deletions
@@ -1,18 +1,16 @@
 #pragma once
 #include <torch/torch.h>
 
-#include "atb/atb_infer.h"
-
 namespace xllm {
-namespace hf {
+namespace layer {
 
-class AttentionMaskImpl : public torch::nn::Module {
+class AttentionMask : public torch::nn::Module {
  public:
-  AttentionMaskImpl() = default;
+  AttentionMask() = default;
 
-  explicit AttentionMaskImpl(at::Device device,
-                             torch::Dtype dtype,
-                             float mask_value = -9984);
+  explicit AttentionMask(at::Device device,
+                         torch::Dtype dtype,
+                         float mask_value = -9984);
 
   torch::Tensor get_decode_attn_mask(torch::Tensor input_lengths,
                                      int64_t max_s,
@@ -37,5 +35,5 @@ class AttentionMaskImpl : public torch::nn::Module {
   at::Tensor atten_mask_cache_;
 };
 
-}  // namespace hf
+}  // namespace layer
 }  // namespace xllm
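
Note: besides the rename, the header no longer includes the ATB-specific atb/atb_infer.h, so it can be consumed from platform-neutral code. A sketch of what the migration looks like at a call site; the old include path shown in the comment and the surrounding function are illustrative, not taken from this commit:

#include "attention_mask.h"  // previously something like "npu/attn_mask.h"

// Before: xllm::hf::AttentionMaskImpl mask(device, dtype);
// After:
void make_mask_example(at::Device device, torch::Dtype dtype) {
  xllm::layer::AttentionMask mask(device, dtype);  // mask_value defaults to -9984
}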

xllm/core/layers/base_layer.cpp

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
+#include "base_layer.h"
+
+namespace xllm {
+namespace layer {
+
+BaseLayer::BaseLayer(const Context& context)
+    : device_(context.get_tensor_options().device()),
+      name_(""),
+      parallel_args_(context.get_parallel_args()) {
+  auto quant_args = context.get_quant_args();
+  if (!quant_args.quantize_type().empty()) {
+    quantize_type_ = quant_args.quantize_type();
+  }
+
+  if (!quant_args.torch_dtype().empty()) {
+    torch_dtype_ = quant_args.torch_dtype();
+  }
+
+  dp_size_ = parallel_args_.dp_size();
+  dp_local_tp_size_ = parallel_args_.world_size() / dp_size_;
+  dp_rank_ = parallel_args_.rank() / dp_local_tp_size_;
+  CHECK_EQ(parallel_args_.world_size(), dp_size_ * dp_local_tp_size_);
+  dp_local_tp_rank_ = parallel_args_.rank() % dp_local_tp_size_;
+
+  run_task_func_ = std::bind(
+      &BaseLayer::run_task, this, std::placeholders::_1, std::placeholders::_2);
+}
+
+torch::Dtype BaseLayer::string2dtype(const std::string& dtype_str) {
+  if (dtype_str.compare("float16") == 0) {
+    return torch::kFloat16;
+  } else if (dtype_str.compare("bfloat16") == 0) {
+    return torch::kBFloat16;
+  } else if (dtype_str.compare("float32") == 0) {
+    return torch::kFloat32;
+  } else if (dtype_str.compare("float64") == 0) {
+    return torch::kFloat64;
+  } else if (dtype_str.compare("int8") == 0) {
+    return torch::kInt8;
+  } else if (dtype_str.compare("int16") == 0) {
+    return torch::kInt16;
+  } else if (dtype_str.compare("int32") == 0) {
+    return torch::kInt32;
+  } else if (dtype_str.compare("int64") == 0) {
+    return torch::kInt64;
+  } else if (dtype_str.compare("uint8") == 0) {
+    return torch::kUInt8;
+  } else if (dtype_str.compare("bool") == 0) {
+    return torch::kBool;
+  }
+
+  throw std::runtime_error("Unsupported dtype string");
+}
+
+void BaseLayer::correct_tensor_dtype(torch::Tensor& tensor,
+                                     const std::string& tensorName) {
+  if (absl::EndsWith(tensorName, "deq_scale") &&
+      (torch_dtype_.compare("bfloat16") == 0)) {
+    return;
+  }
+
+  if (tensor.dtype() != torch::kInt8 && tensor.dtype() != torch::kInt32 &&
+      tensor.dtype() != torch::kInt64) {
+    torch::Dtype dtype = string2dtype(torch_dtype_);
+    tensor = tensor.to(dtype);
+  }
+}
+
+void BaseLayer::set_weight(const StateDict& state_dict,
+                           const std::string& tensor_name,
+                           int weight_position) {
+  for (const auto& [name, tensor] : state_dict) {
+    if (absl::EndsWith(name, tensor_name)) {
+      at::Tensor mutable_tensor = tensor;
+      correct_tensor_dtype(mutable_tensor, tensor_name);
+      at_weight_tensors_[weight_position] = mutable_tensor.to(device_);
+    }
+  }
+}
+
+void BaseLayer::set_weight(const StateDict& state_dict,
+                           const std::string& tensor_name,
+                           int weight_position,
+                           int dim) {
+  for (const auto& [name, tensor] : state_dict) {
+    if (absl::EndsWith(name, tensor_name)) {
+      if (parallel_args_.world_size() <= 1) {
+        at::Tensor mutable_tensor = tensor;
+        correct_tensor_dtype(mutable_tensor, tensor_name);
+        at_weight_tensors_[weight_position] = mutable_tensor.to(device_);
+      } else {
+        at_weight_tensors_[weight_position] =
+            state_dict
+                .get_sharded_tensor(tensor_name,
+                                    /*dim=*/dim,
+                                    /*rank=*/parallel_args_.rank(),
+                                    /*world_size=*/parallel_args_.world_size())
+                .to(device_);
+        correct_tensor_dtype(at_weight_tensors_[weight_position], tensor_name);
+      }
+    }
+  }
+}
+
+void BaseLayer::set_weight(const StateDict& state_dict,
+                           const std::string& tensor_name,
+                           int weight_position,
+                           int dim,
+                           int rank,
+                           int world_size) {
+  for (const auto& [name, tensor] : state_dict) {
+    if (absl::EndsWith(name, tensor_name)) {
+      if (world_size <= 1) {
+        at::Tensor mutable_tensor = tensor;
+        correct_tensor_dtype(mutable_tensor, tensor_name);
+        at_weight_tensors_[weight_position] = mutable_tensor.to(device_);
+      } else {
+        at_weight_tensors_[weight_position] =
+            state_dict
+                .get_sharded_tensor(tensor_name,
+                                    /*dim=*/dim,
+                                    /*rank=*/rank,
+                                    /*world_size=*/world_size)
+                .to(device_);
+        correct_tensor_dtype(at_weight_tensors_[weight_position], tensor_name);
+      }
+    }
+  }
+}
+
+}  // namespace layer
+}  // namespace xllm
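
Note: BaseLayer centralizes the data-parallel/tensor-parallel bookkeeping (dp_size_, dp_rank_, local TP rank) and the weight loading that every platform-specific layer needs, so derived layers only declare which tensors go in which slot. A minimal sketch of how a derived layer might use the set_weight overloads; the class name, tensor names, method name, and the resize of at_weight_tensors_ are illustrative assumptions, not code from this commit:

#include "base_layer.h"

namespace xllm::layer {

// Hypothetical layer built on BaseLayer; the real decoder layers in this
// commit follow the same pattern with more weight slots.
class ExampleLinearLayer : public BaseLayer {
 public:
  explicit ExampleLinearLayer(const Context& context) : BaseLayer(context) {
    at_weight_tensors_.resize(2);  // assumes the member is a resizable container
  }

  void load_weights(const StateDict& state_dict) {
    // Column-parallel weight: sharded along dim 0 across the local TP group.
    set_weight(state_dict, "weight", /*weight_position=*/0, /*dim=*/0);
    // The bias is small, so every rank keeps a full copy.
    set_weight(state_dict, "bias", /*weight_position=*/1);
  }
};

}  // namespace xllm::layer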
