Skip to content

Commit 5f5adaf

Browse files
authored
feat: qwen3 moe support. (#253)
1 parent b9ec995 commit 5f5adaf

21 files changed

+729
-304
lines changed

converter/convert-hf.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
class ArchType:
99
LLAMA = 0xABCD00
1010
QWEN3 = 0xABCD01
11+
QWEN3_MOE = 0xABCD02
1112

1213
def permute(tensor, nHeads: int, nKvHeads: int):
1314
if nHeads != nKvHeads:
@@ -71,22 +72,23 @@ def __preparePlan(self):
7172
f'model.layers.{l}.self_attn.o_proj.weight'])
7273

7374
if (self.config['n_experts'] > 0):
75+
p.append([FloatType.F32, f'model.layers.{l}.mlp.gate.weight'])
7476
for e in range(self.config['n_experts']):
7577
p.append([wt,
76-
f'model.layers.{l}.block_sparse_moe.experts.{e}.w3.weight']) # up
78+
f'model.layers.{l}.mlp.experts.{e}.gate_proj.weight'])
7779
p.append([wt,
78-
f'model.layers.{l}.block_sparse_moe.experts.{e}.w1.weight']) # gate
80+
f'model.layers.{l}.mlp.experts.{e}.down_proj.weight'])
7981
p.append([wt,
80-
f'model.layers.{l}.block_sparse_moe.experts.{e}.w2.weight']) # down
82+
f'model.layers.{l}.mlp.experts.{e}.up_proj.weight'])
8183
else:
8284
p.append([wt,
83-
f'model.layers.{l}.mlp.gate_proj.weight']) # gate
85+
f'model.layers.{l}.mlp.gate_proj.weight'])
8486
p.append([wt,
85-
f'model.layers.{l}.mlp.down_proj.weight']) # down
87+
f'model.layers.{l}.mlp.down_proj.weight'])
8688
p.append([wt,
87-
f'model.layers.{l}.mlp.up_proj.weight']) # up
89+
f'model.layers.{l}.mlp.up_proj.weight'])
8890

89-
if (self.archType == ArchType.QWEN3):
91+
if (self.archType == ArchType.QWEN3 or self.archType == ArchType.QWEN3_MOE):
9092
p.append([FloatType.F32,
9193
f'model.layers.{l}.self_attn.q_norm.weight'])
9294
p.append([FloatType.F32,
@@ -146,6 +148,7 @@ def parseArchType(type: str):
146148
'llama': ArchType.LLAMA,
147149
'mistral': ArchType.LLAMA,
148150
'qwen3': ArchType.QWEN3,
151+
'qwen3_moe': ArchType.QWEN3_MOE,
149152
}.get(type)
150153
if (archType is None):
151154
raise Exception(f'Unsupported arch type: {type}')
@@ -202,8 +205,8 @@ def loadConfig(folderPath: str, weightsFloatType: int):
202205
'files': files,
203206
}
204207

205-
nExperts = config.get('num_local_experts')
206-
nActiveExperts = config.get('num_active_local_experts') or config.get('num_experts_per_tok')
208+
nExperts = config.get('num_experts')
209+
nActiveExperts = config.get('num_experts_per_tok')
207210
result['n_experts'] = int(nExperts) if nExperts is not None else 0
208211
result['n_active_experts'] = int(nActiveExperts) if nActiveExperts is not None else 0
209212

@@ -226,6 +229,10 @@ def loadConfig(folderPath: str, weightsFloatType: int):
226229
rmsNormEps = config.get('rms_norm_eps')
227230
if (rmsNormEps is not None):
228231
result['norm_epsilon'] = parseRmsNormEpsilon(rmsNormEps)
232+
233+
moeHiddenDim = config.get('moe_intermediate_size')
234+
if (moeHiddenDim is not None):
235+
result['moe_hidden_dim'] = int(moeHiddenDim)
229236
return result
230237

231238
def printUsage():

converter/writer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ def writeHeader(file, params):
128128
'rope_scaling_orig_max_seq_len': 17,
129129
'rope_type': 18,
130130
'head_dim': 19,
131-
'norm_epsilon': 20
131+
'norm_epsilon': 20,
132+
'moe_hidden_dim': 21,
132133
}
133134
header = struct.pack('i', 0xA00ABCD)
134135

launch.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ def parts(length):
6565
'https://huggingface.co/b4rtaz/Qwen3-14B-Q40-Distributed-Llama/resolve/main/dllama_tokenizer_qwen3_14b.t?download=true',
6666
'q40', 'q80', 'chat', '--max-seq-len 4096'
6767
],
68+
'qwen3_30b_a3b_q40': [
69+
list(map(lambda suffix : f'https://huggingface.co/b4rtaz/Qwen3-30B-A3B-Q40-Distributed-Llama/resolve/main/dllama_model_qwen3_30b_a3b_{suffix}?download=true', parts(5))),
70+
'https://huggingface.co/b4rtaz/Qwen3-30B-A3B-Q40-Distributed-Llama/resolve/main/dllama_tokenizer_qwen3_30b_a3b.t?download=true',
71+
'q40', 'q80', 'chat', '--max-seq-len 4096'
72+
],
6873
}
6974

7075
def confirm(message: str):

src/llm.cpp

Lines changed: 197 additions & 72 deletions
Large diffs are not rendered by default.

src/llm.hpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ enum LlmHeaderKey {
2626
ROPE_SCALING_ORIG_MAX_SEQ_LEN = 17,
2727
ROPE_TYPE = 18,
2828
HEAD_DIM = 19,
29-
NORM_EPSILON = 20
29+
NORM_EPSILON = 20,
30+
MOE_HIDDEN_DIM = 21,
3031
};
3132

3233
enum LlmHiddenAct {
@@ -36,7 +37,8 @@ enum LlmHiddenAct {
3637

3738
enum LlmArchType {
3839
LLAMA = 0xABCD00,
39-
QWEN3 = 0xABCD01
40+
QWEN3 = 0xABCD01,
41+
QWEN3_MOE = 0xABCD02,
4042
};
4143

4244
typedef struct {
@@ -54,6 +56,7 @@ typedef struct {
5456
NnUint origSeqLen; // Original model context length
5557
NnUint seqLen; // Limited context length by the `--max-seq-len` argument
5658
NnUint hiddenDim;
59+
NnUint moeHiddenDim;
5760
LlmHiddenAct hiddenAct;
5861
NnUint qDim;
5962
NnUint kvDim;
@@ -86,9 +89,10 @@ typedef struct {
8689
NnUint tokenPipeIndex;
8790
NnUint xPipeIndex;
8891
NnUint logitsPipeIndex;
89-
NnSize2D tokenEmbeddingSize;
90-
NnSize2D rmsNormSize;
91-
NnSize2D qkRmsNormSize;
92+
NnSize3D tokenEmbeddingSize;
93+
NnSize3D rmsNormSize;
94+
NnSize3D qkRmsNormSize;
95+
NnSize3D moeGateSize;
9296
} LlmNet;
9397

9498
LlmHeader loadLlmHeader(const char* path, const unsigned int maxSeqLen, NnFloatType syncType);

src/nn/nn-config-builder.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class NnNetConfigBuilder {
2424
this->nBatches = nBatches;
2525
}
2626

27-
NnUint addPipe(const char *name, NnSize2D size) {
27+
NnUint addPipe(const char *name, NnSize3D size) {
2828
NnUint pipeIndex = pipes.size();
2929
pipes.push_back({ cloneString(name), size });
3030
return pipeIndex;
@@ -62,7 +62,7 @@ class NnNodeConfigBuilder {
6262
this->nodeIndex = nodeIndex;
6363
}
6464

65-
NnUint addBuffer(const char *name, NnSize2D size) {
65+
NnUint addBuffer(const char *name, NnSize3D size) {
6666
NnUint bufferIndex = buffers.size();
6767
buffers.push_back({ cloneString(name), size });
6868
return bufferIndex;
@@ -98,7 +98,7 @@ class NnSegmentConfigBuilder {
9898

9999
public:
100100
template <typename T>
101-
void addOp(NnOpCode code, const char *name, NnUint index, NnPointerConfig input, NnPointerConfig output, NnSize2D weightSize, T config) {
101+
void addOp(NnOpCode code, const char *name, NnUint index, NnPointerConfig input, NnPointerConfig output, NnSize3D weightSize, T config) {
102102
NnUint configSize = sizeof(T);
103103
NnByte *configCopy = new NnByte[configSize];
104104
std::memcpy(configCopy, &config, configSize);

src/nn/nn-core.cpp

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ NnOpQuantType getOpQuantType(NnFloatType input, NnFloatType weight, NnFloatType
7272

7373
const char *opCodeToString(NnOpCode code) {
7474
if (code == OP_MERGE_ADD) return "MERGE_ADD";
75+
if (code == OP_MERGE_SUM) return "MERGE_SUM";
7576
if (code == OP_EMBEDDING) return "EMBEDDING";
7677
if (code == OP_INV_RMS) return "INV_RMS";
7778
if (code == OP_RMS_NORM) return "RMS_NORM";
@@ -81,7 +82,11 @@ const char *opCodeToString(NnOpCode code) {
8182
if (code == OP_GELU) return "GELU";
8283
if (code == OP_SILU) return "SILU";
8384
if (code == OP_MUL) return "MUL";
85+
if (code == OP_SCALE) return "SCALE";
8486
if (code == OP_CAST) return "CAST";
87+
if (code == OP_REPEAT_Z) return "REPEAT_Z";
88+
if (code == OP_SHIFT) return "SHIFT";
89+
if (code == OP_MOE_GATE) return "MOE_GATE";
8590
throw std::invalid_argument("Unknown op code");
8691
}
8792

@@ -97,17 +102,22 @@ const char *opQuantTypeToString(NnOpQuantType type) {
97102
throw std::invalid_argument("Unknown op quant type");
98103
}
99104

100-
NnSize2D size0() {
101-
return { F_UNK, 0, 0, 0, 0 };
105+
NnSize3D size0() {
106+
return { F_UNK, 0, 0, 0, 0, 0 };
102107
}
103108

104-
NnSize2D size1D(NnFloatType floatType, NnUint x) {
105-
return size2D(floatType, 1, x);
109+
NnSize3D size1D(NnFloatType floatType, NnUint x) {
110+
return size3D(floatType, 1, 1, x);
106111
}
107112

108-
NnSize2D size2D(NnFloatType floatType, NnUint y, NnUint x) {
109-
NnSize length = y * x;
110-
return { floatType, y, x, length, getBytes(floatType, length) };
113+
NnSize3D size2D(NnFloatType floatType, NnUint y, NnUint x) {
114+
return size3D(floatType, 1, y, x);
115+
}
116+
117+
// Builds the size descriptor of a z * y * x tensor stored as `floatType`.
//
// Fills, in struct order: floatType, z, y, x, the total element count,
// the byte size of the whole tensor, and the byte size of a single XY plane
// (y * x elements). Byte sizes are delegated to getBytes().
NnSize3D size3D(NnFloatType floatType, NnUint z, NnUint y, NnUint x) {
    // Widen each operand to NnSize *before* multiplying. Multiplying NnUint
    // values first would perform the arithmetic in NnUint and only widen the
    // (possibly wrapped) result.
    // NOTE(review): assumes NnSize is at least as wide as NnUint - confirm
    // against the typedefs; if they are the same width the casts are no-ops.
    const NnSize lenXY = static_cast<NnSize>(y) * static_cast<NnSize>(x);
    const NnSize len = static_cast<NnSize>(z) * lenXY;
    return { floatType, z, y, x, len, getBytes(floatType, len), getBytes(floatType, lenXY) };
}
112122

113123
NnPointerConfig pointerBatchConfig(NnPointerSource source, NnUint index) {

src/nn/nn-core.hpp

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,29 @@
1111

1212
typedef struct {
1313
NnFloatType floatType;
14+
NnUint z;
1415
NnUint y;
1516
NnUint x;
1617
NnSize length;
1718
NnSize nBytes;
18-
} NnSize2D;
19+
NnSize nBytesXY;
20+
} NnSize3D;
1921

2022
// slices
2123

2224
typedef struct {
2325
NnUint kvDim0;
24-
NnSize2D keySize;
25-
NnSize2D valueSize;
26+
NnSize3D keySize;
27+
NnSize3D valueSize;
2628
} NnKvCacheSlice;
2729

2830
typedef struct {
2931
NnFloatType type;
3032
NnUint nNodes;
3133
NnUint d0;
3234
NnUint n;
33-
NnSize2D size;
34-
NnSize2D sliceSize;
35+
NnSize3D size;
36+
NnSize3D sliceSize;
3537
} NnRowMatmulSlice;
3638

3739
typedef struct {
@@ -40,8 +42,8 @@ typedef struct {
4042
NnUint n;
4143
NnUint n0;
4244
NnUint d;
43-
NnSize2D size;
44-
NnSize2D sliceSize;
45+
NnSize3D size;
46+
NnSize3D sliceSize;
4547
} NnColMatmulSlice;
4648

4749
typedef struct {
@@ -57,19 +59,20 @@ typedef struct {
5759
NnUint headDim;
5860
NnUint nKvHeads;
5961
float ropeTheta;
60-
NnSize2D cacheSize;
62+
NnSize3D cacheSize;
6163
} NnRopeSlice;
6264

6365
typedef struct {
6466
NnUint nHeads;
6567
NnUint nHeads0;
66-
NnSize2D attSize;
68+
NnSize3D attSize;
6769
} NnMultiHeadAttSlice;
6870

6971
// base enums
7072

7173
enum NnOpCode {
7274
OP_MERGE_ADD,
75+
OP_MERGE_SUM,
7376
OP_EMBEDDING,
7477
OP_INV_RMS,
7578
OP_RMS_NORM,
@@ -79,8 +82,12 @@ enum NnOpCode {
7982
OP_GELU,
8083
OP_SILU,
8184
OP_MUL,
85+
OP_SCALE,
8286
OP_CAST,
87+
OP_REPEAT_Z,
8388
OP_SHIFT,
89+
OP_SOFTMAX,
90+
OP_MOE_GATE,
8491
};
8592

8693
enum NnOpQuantType {
@@ -125,12 +132,12 @@ enum NnRopeType {
125132

126133
typedef struct {
127134
char *name;
128-
NnSize2D size;
135+
NnSize3D size;
129136
} NnPipeConfig;
130137

131138
typedef struct {
132139
char *name;
133-
NnSize2D size;
140+
NnSize3D size;
134141
} NnBufferConfig;
135142

136143
typedef struct {
@@ -145,7 +152,7 @@ typedef struct {
145152
NnUint index;
146153
NnPointerConfig input;
147154
NnPointerConfig output;
148-
NnSize2D weightSize;
155+
NnSize3D weightSize;
149156
NnByte *config;
150157
NnUint configSize;
151158
} NnOpConfig;
@@ -200,7 +207,9 @@ typedef struct {
200207
} NnRmsNormOpConfig;
201208

202209
typedef struct {
203-
// empty
210+
NnUint nExperts;
211+
NnUint nActiveExperts;
212+
NnUint activeExpertIndexesBufferIndex;
204213
} NnMatmulOpConfig;
205214

206215
typedef struct {
@@ -234,6 +243,10 @@ typedef struct {
234243
// empty
235244
} NnMergeAddOpCodeConfig;
236245

246+
typedef struct {
247+
// empty
248+
} NnMergeSumOpCodeConfig;
249+
237250
typedef struct {
238251
// empty
239252
} NnSiluOpCodeConfig;
@@ -242,14 +255,32 @@ typedef struct {
242255
NnUint multiplierBufferIndex;
243256
} NnMulOpCodeConfig;
244257

258+
typedef struct {
259+
NnUint scaleBufferIndex;
260+
} NnScaleOpCodeConfig;
261+
245262
typedef struct {
246263
// empty
247264
} NnCastOpCodeConfig;
248265

266+
typedef struct {
267+
// empty
268+
} NnRepeatZOpCodeConfig;
269+
249270
typedef struct {
250271
NnUint indexPipeIndex;
251272
} NnShiftOpCodeConfig;
252273

274+
typedef struct {
275+
// empty
276+
} NnSoftmaxOpCodeConfig;
277+
278+
typedef struct {
279+
NnUint k;
280+
NnUint normTopk;
281+
NnUint indexesBufferIndex;
282+
} NnMoeGateOpCodeConfig;
283+
253284
// utility functions
254285

255286
const char *opCodeToString(NnOpCode code);
@@ -258,9 +289,10 @@ const char *opQuantTypeToString(NnOpQuantType type);
258289
NnSize getBytes(NnFloatType floatType, NnSize n);
259290
NnSize getBlockSize(NnFloatType floatType);
260291
NnOpQuantType getOpQuantType(NnFloatType input, NnFloatType weight, NnFloatType output);
261-
NnSize2D size0();
262-
NnSize2D size1D(NnFloatType floatType, NnUint x);
263-
NnSize2D size2D(NnFloatType floatType, NnUint y, NnUint x);
292+
NnSize3D size0();
293+
NnSize3D size1D(NnFloatType floatType, NnUint x);
294+
NnSize3D size2D(NnFloatType floatType, NnUint y, NnUint x);
295+
NnSize3D size3D(NnFloatType floatType, NnUint z, NnUint y, NnUint x);
264296
NnPointerConfig pointerBatchConfig(NnPointerSource source, NnUint index);
265297
NnPointerConfig pointerBatchedSliceConfig(NnPointerSource source, NnUint index);
266298
NnPointerConfig pointerRawConfig(NnPointerSource source, NnUint index);

0 commit comments

Comments
 (0)