
Commit a6fe056

feat: add support for qwen3
1 parent 525978f commit a6fe056

4 files changed: +261 −46 lines

llama.cu/src/exec/group.rs

Lines changed: 2 additions & 0 deletions

@@ -243,6 +243,8 @@ fn builder() -> GraphBuilder {
         .register_op("swiglu", op::activation::SwiGLU)
         .register_op("concat", op::concat::Concat)
         .register_op("split", op::split::Split)
+        .register_op("tile", op::tile::Tile)
+        .register_op("merge", op::merge::Merge)
         .register_op("all-reduce", op::all_reduce::AllReduce);
     ans
 }

llama.cu/src/model/llama.rs

Lines changed: 53 additions & 6 deletions

@@ -14,6 +14,7 @@ impl GGufModel<'_> {
         let dt_bias = match arch {
             "llama" => None,
             "qwen2" => Some(self.tensors["blk.0.attn_qkv.bias"].dt()),
+            "qwen3" => None,
             arch => panic!("unsupported arch {arch}"),
         };

@@ -23,7 +24,13 @@ impl GGufModel<'_> {
         let d = meta![self => llm_embedding_length];
         let nh = meta![self => llm_attention_head_count];
         let nkvh = meta![self => llm_attention_head_count_kv; nh];
-        let dh = meta![self => llm_rope_dimension_count; d / nh];
+        let dh = match arch {
+            "qwen3" => self.tensors["blk.0.attn_qkv.weight"].shape()[0]
+                .checked_div(nh + nkvh + nkvh)
+                .unwrap(),
+            _ => meta![self => llm_rope_dimension_count; d / nh],
+        };
+        println!("dh: {dh}");
         let di = meta![self => llm_feed_forward_length];
         let epsilon = meta![self => llm_attention_layer_norm_rms_epsilon; 1e-5];
         let dt_linear = self.tensors["blk.0.attn_qkv.weight"].dt();

@@ -68,8 +75,36 @@ impl GGufModel<'_> {
                     get(&format!("blk.{iblk}.attn_qkv.weight")),
                     dt_bias.map(|dt| (dt, get(&format!("blk.{iblk}.attn_qkv.bias")))),
                 ),
-                q_norm: None,
-                k_norm: None,
+                q_norm: if self
+                    .tensors
+                    .contains_key(format!("blk.{iblk}.attn_q_norm.weight").as_str())
+                {
+                    Some(Normalization {
+                        d: dh,
+                        epsilon: epsilon as _,
+                        items: NormType::RmsNorm {
+                            dt: out_norm.dt(),
+                            scale: get(&format!("blk.{iblk}.attn_q_norm.weight")),
+                        },
+                    })
+                } else {
+                    None
+                },
+                k_norm: if self
+                    .tensors
+                    .contains_key(format!("blk.{iblk}.attn_k_norm.weight").as_str())
+                {
+                    Some(Normalization {
+                        d: dh,
+                        epsilon: epsilon as _,
+                        items: NormType::RmsNorm {
+                            dt: out_norm.dt(),
+                            scale: get(&format!("blk.{iblk}.attn_k_norm.weight")),
+                        },
+                    })
+                } else {
+                    None
+                },
                 rope: Some(RoPE {
                     multimodal: false,
                     nctx,

@@ -125,13 +160,19 @@ impl GGufModel<'_> {

     /// Inserts the sin/cos table tensors used for RoPE
     pub fn insert_rope_sin_cos(&mut self) {
+        let arch = meta![self => general_architecture];
         let nctx = meta![self => llm_context_length];
         let d = meta![self => llm_embedding_length];
         let nh = meta![self => llm_attention_head_count];
-        let dh = meta![self => llm_rope_dimension_count; d / nh];
+        let nkvh = meta![self => llm_attention_head_count_kv; nh];
+        let dh = match arch {
+            "qwen3" => self.tensors["blk.0.attn_qkv.weight"].shape()[0]
+                .checked_div(nh + nkvh + nkvh)
+                .unwrap(),
+            _ => meta![self => llm_rope_dimension_count; d / nh],
+        };
         let theta = meta![self => llm_rope_freq_base; 1e4];

-        let arch = meta![self => general_architecture];
         let [sin, cos] = match self.get_str(&format!("{arch}.rope.scaling.type")) {
             Ok("longrope") => {
                 let ctx_scale = 1.;

@@ -159,13 +200,19 @@ impl GGufModel<'_> {

     /// Builds the kv cache tensor for the language model
     pub fn lm_kv_cache<const N: usize>(&self) -> Tensor<usize, N> {
+        let arch = meta![self => general_architecture];
         let dt = self.tensors["token_embd.weight"].dt();
         let nblk = meta![self => llm_block_count];
         let nctx = meta![self => llm_context_length];
         let d = meta![self => llm_embedding_length];
         let nh = meta![self => llm_attention_head_count];
         let nkvh = meta![self => llm_attention_head_count_kv; nh];
-        let dh = meta![self => llm_rope_dimension_count; d / nh];
+        let dh = match arch {
+            "qwen3" => self.tensors["blk.0.attn_qkv.weight"].shape()[0]
+                .checked_div(nh + nkvh + nkvh)
+                .unwrap(),
+            _ => meta![self => llm_rope_dimension_count; d / nh],
+        };
         Tensor::from_dim_slice(dt, [nctx, nblk, 2, nkvh, dh])
     }
 }
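Note on the dh change above: for qwen3 the head width is not taken as d / nh; it is recovered from the fused attn_qkv weight shape, since the qkv output dimension stacks nh query heads plus 2 * nkvh key/value heads, each dh wide. The same derivation is repeated in insert_rope_sin_cos and lm_kv_cache so all three sites agree. A minimal standalone sketch of that arithmetic (the helper name and the example figures are illustrative, not taken from this repository):

    // Sketch: recover the per-head width from the fused qkv weight's row count.
    // Unlike the diff's checked_div (which only guards against a zero divisor),
    // this sketch also rejects a row count that is not an exact multiple.
    fn recover_head_dim(qkv_rows: usize, nh: usize, nkvh: usize) -> Option<usize> {
        let heads = nh + 2 * nkvh;
        if heads == 0 || qkv_rows % heads != 0 {
            return None;
        }
        Some(qkv_rows / heads)
    }

    fn main() {
        // e.g. 32 query heads, 8 kv heads, head width 128:
        // the fused qkv weight then has (32 + 8 + 8) * 128 = 6144 rows.
        assert_eq!(recover_head_dim(6144, 32, 8), Some(128));
    }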

llama.cu/src/op/rms_norm.cuh

Lines changed: 86 additions & 0 deletions

@@ -29,6 +29,36 @@ static __device__ void padding(
     *y = Ta(rms * x * w);
 }

+// padding function for a 3-D tensor
+template <unsigned int BLOCK_SIZE, class Tw, class Ta>
+static __device__ void padding_3d(
+    Ta *__restrict__ y_,
+    int const stride_y_batch,
+    int const stride_y_seq,
+    Ta const *__restrict__ x_,
+    int const stride_x_batch,
+    int const stride_x_seq,
+    Tw const *__restrict__ w_,
+    float const epsilon) {
+
+    // blockIdx.x = batch index, blockIdx.y = seq index
+    auto y = y_ + blockIdx.x * stride_y_batch + blockIdx.y * stride_y_seq + threadIdx.x;
+    float const x = x_[blockIdx.x * stride_x_batch + blockIdx.y * stride_x_seq + threadIdx.x];
+    float const w = w_[threadIdx.x];
+
+    using BlockOp = cub::BlockReduce<float, BLOCK_SIZE>;
+    __shared__ typename BlockOp::TempStorage temp_storage;
+    auto acc = BlockOp(temp_storage).Reduce(x * x, cub::Sum());
+
+    __shared__ float rms;
+    if (threadIdx.x == 0) {
+        rms = rsqrtf(acc / float(blockDim.x) + epsilon);
+    }
+    __syncthreads();
+
+    *y = Ta(rms * x * w);
+}
+
 template <unsigned int BLOCK_SIZE, unsigned int NUM_ITEMS_THREAD, class Tw, class Ta>
 static __device__ void folding(
     Ta *__restrict__ y,

@@ -79,3 +109,59 @@ static __device__ void folding(
         BlockOp(temp_storage).Store(y, data, items_size);
     }
 }
+
+// folding function for a 3-D tensor
+template <unsigned int BLOCK_SIZE, unsigned int NUM_ITEMS_THREAD, class Tw, class Ta>
+static __device__ void folding_3d(
+    Ta *__restrict__ y,
+    int const stride_y_batch,
+    int const stride_y_seq,
+    Ta const *__restrict__ x,
+    int const stride_x_batch,
+    int const stride_x_seq,
+    Tw const *__restrict__ w,
+    float const epsilon,
+    unsigned int const items_size) {
+
+    // blockIdx.x = batch index, blockIdx.y = seq index
+    y += blockIdx.x * stride_y_batch + blockIdx.y * stride_y_seq;
+    x += blockIdx.x * stride_x_batch + blockIdx.y * stride_x_seq;
+
+    float data[NUM_ITEMS_THREAD], weight[NUM_ITEMS_THREAD];
+    {
+        using BlockOp = cub::BlockLoad<float, BLOCK_SIZE, NUM_ITEMS_THREAD>;
+        __shared__ typename BlockOp::TempStorage temp_storage;
+        BlockOp(temp_storage).Load(x, data, items_size, 0.f);
+        BlockOp(temp_storage).Load(w, weight, items_size, 0.f);
+    }
+
+    float squared = 0;
+#pragma unroll
+    for (unsigned int i = 0; i < NUM_ITEMS_THREAD; ++i) {
+        squared += data[i] * data[i];
+    }
+
+    float acc;
+    {
+        using BlockOp = cub::BlockReduce<float, BLOCK_SIZE>;
+        __shared__ typename BlockOp::TempStorage temp_storage;
+        acc = BlockOp(temp_storage).Reduce(squared, cub::Sum());
+    }
+
+    __shared__ float rms;
+    if (threadIdx.x == 0) {
+        rms = rsqrtf(acc / float(items_size) + epsilon);
+    }
+    __syncthreads();
+
+#pragma unroll
+    for (unsigned int i = 0; i < NUM_ITEMS_THREAD; ++i) {
+        data[i] = rms * data[i] * weight[i];
+    }
+
+    {
+        using BlockOp = cub::BlockStore<float, BLOCK_SIZE, NUM_ITEMS_THREAD>;
+        __shared__ typename BlockOp::TempStorage temp_storage;
+        BlockOp(temp_storage).Store(y, data, items_size);
+    }
+}
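For reference, each (blockIdx.x, blockIdx.y) block of the two new device functions applies a plain RMS norm to one strided row and scales it by the shared weight vector; padding_3d covers the row with one element per thread, while folding_3d gives each thread NUM_ITEMS_THREAD elements via cub block load/store, mirroring the existing padding / folding pair. That per-row layout is what the per-head q_norm / k_norm added in llama.rs needs, where each dh-wide head slice is normalized independently (presumably the intended caller; the wiring is not part of this diff). A small host-side Rust oracle of the per-row math, illustrative only and not code from the repository:

    // Reference for one row of length d (one dh-wide head in the q/k-norm case):
    // y[i] = x[i] * w[i] / sqrt(mean(x^2) + epsilon)
    fn rms_norm_row(x: &[f32], w: &[f32], epsilon: f32) -> Vec<f32> {
        assert_eq!(x.len(), w.len());
        let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
        let rms = (mean_sq + epsilon).sqrt().recip();
        x.iter().zip(w).map(|(v, s)| rms * v * s).collect()
    }

    fn main() {
        // One (batch, seq) position with head width 4; unit weights leave only the 1/rms scaling.
        let y = rms_norm_row(&[1.0, 2.0, 3.0, 4.0], &[1.0; 4], 1e-5);
        println!("{y:?}");
    }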
