Commit 2d743b6

wip
1 parent 6dca237 commit 2d743b6

6 files changed: +376 −70 lines


examples/tts/convert_csm_to_gguf.py

Lines changed: 15 additions & 17 deletions
@@ -89,8 +89,6 @@ class CSMModelConverter:
     fname_out: Path
     ftype: gguf.LlamaFileType
 
-    projection_tensor: Tensor # projecting from n_embd_backbone (2048) to n_embd_decoder (1024)
-
     def __init__(self,
                  safetensors_path: Union[Path, str],
                  path_to_vocab_gguf: Path,
@@ -110,24 +108,18 @@ def __init__(self,
         # backbone
         self.gguf_writer_backbone = gguf.GGUFWriter(
             path=None,
-            arch="llama",
+            arch="llama-csm",
             endianess=endianess)
 
         # decoder
         self.gguf_writer_decoder = gguf.GGUFWriter(
             path=None,
-            arch="llama",
+            arch="llama-csm",
             endianess=endianess)
 
         Llama_3_2_1B().write_gguf_metadata(self.gguf_writer_backbone, self.gguf_reader_vocab)
         Llama_3_2_100M().write_gguf_metadata(self.gguf_writer_decoder, self.gguf_reader_vocab)
 
-        # get projection tensor)
-        for name, data_torch in self.state_dict.items():
-            if name == "projection.weight":
-                self.projection_tensor = data_torch
-                break
-
         # load tensors
         for component in ("backbone", "decoder"):
             print()
@@ -165,10 +157,7 @@ def rename_transformer(name: str) -> str:
 
             if "audio_embeddings." in name:
                 is_decoder = True
-                if component == "decoder":
-                    name = name.replace("audio_embeddings.", "token_embd.")
-                    data_torch = torch.mm(data_torch, self.projection_tensor.T)
-                    print("Applied projection to audio_embeddings", data_torch.shape)
+                name = name.replace("audio_embeddings.", "audio_embd.")
 
             elif "text_embeddings." in name:
                 is_backbone = True
@@ -189,11 +178,18 @@ def rename_transformer(name: str) -> str:
             elif name == "audio_head":
                 is_decoder = True
                 name = "audio_head.weight"
+                if component == "decoder":
+                    # add padding at the beginning so that build_lora_mm_id can be used
+                    zero_tensor = torch.zeros(1, 1024, 2051)
+                    data_torch = torch.cat([zero_tensor, data_torch], dim=0)
+                    assert data_torch.shape == (32, 1024, 2051)
+                    # then, transpose it
+                    data_torch = data_torch.transpose(1, 2)
 
             elif name == "projection.weight":
                 is_decoder = True
-                name = "inp_proj.weight"
-                self.projection_tensor = data_torch
+                is_backbone = True
+                name = "csm_proj.weight"
 
             if can_quantize:
                 if self.ftype == gguf.LlamaFileType.ALL_F32:
@@ -203,7 +199,9 @@ def rename_transformer(name: str) -> str:
                 elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
                     data_qtype = gguf.GGMLQuantizationType.BF16
                 elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
-                    data_qtype = gguf.GGMLQuantizationType.Q8_0
+                    # decoder is very sensitive to quantization, do not quantize it lower than F16
+                    data_qtype = gguf.GGMLQuantizationType.Q8_0 if component != "decoder" \
+                        else gguf.GGMLQuantizationType.F16
                 else:
                     raise ValueError(f"Unsupported file type: {self.ftype}")
 

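The audio_head handling above can be sanity-checked in isolation. Below is a minimal, standalone Python sketch (using a random stand-in tensor rather than the real checkpoint weight, and assuming the pre-padding shape (31, 1024, 2051) implied by the assert); it reproduces only the zero-slice padding, which the diff adds so that build_lora_mm_id can index the tensor per codebook, and the final transpose:

    import torch

    # hypothetical stand-in for the checkpoint's audio_head weight:
    # 31 codebook slices of shape (n_embd_decoder=1024, n_codes=2051)
    audio_head = torch.randn(31, 1024, 2051)

    # prepend one all-zero slice so the tensor has 32 slices, one per codebook
    # (mirrors "add padding at the beginning so that build_lora_mm_id can be used")
    zero_tensor = torch.zeros(1, 1024, 2051)
    padded = torch.cat([zero_tensor, audio_head], dim=0)
    assert padded.shape == (32, 1024, 2051)

    # then transpose the last two dims so each slice is (n_codes, n_embd), i.e. (2051, 1024)
    padded = padded.transpose(1, 2)
    assert padded.shape == (32, 2051, 1024)
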
examples/tts/tts-csm.cpp

Lines changed: 95 additions & 28 deletions
@@ -30,13 +30,12 @@ static llama_token sample_greedy(const float * logits, int n_vocab) {
 static bool ggml_callback(struct ggml_tensor * t, bool ask, void * user_data) {
     std::vector<float> * embd = (std::vector<float> *) user_data;
 
-    if (t && strcmp(t->name, "result_norm") == 0) {
+    if (t && (strcmp(t->name, "output_csm_proj") == 0 || strcmp(t->name, "output_audio_embd") == 0)) {
         if (ask) return true;
 
-        auto n_bytes = ggml_nbytes(t);
-        embd->resize(n_bytes);
-        ggml_backend_tensor_get(t, embd->data(), 0, n_bytes);
-        printf("result_norm\n");
+        embd->resize(ggml_nelements(t));
+        ggml_backend_tensor_get(t, embd->data(), 0, ggml_nbytes(t));
+        // printf("%s tensor size: %lld, %lld\n", t->name, t->ne[0], t->ne[1]);
         return true;
     }
 
@@ -54,34 +53,37 @@ int main(int argc, char ** argv) {
     params.n_batch = 8192;
     params.n_ctx = 8192;
 
-    params.sampling.top_k = 4;
-    params.sampling.samplers = { COMMON_SAMPLER_TYPE_TOP_K, };
-
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_TTS, print_usage)) {
         return 1;
     }
 
     llama_backend_init();
     llama_numa_init(params.numa);
 
-    common_params params_decoder(params); // duplicate the params
-    string_replace_all(params_decoder.model, "-backbone", "-decoder");
-
     std::vector<float> embd;
     params.cb_eval = ggml_callback;
     params.cb_eval_user_data = &embd;
+    params.warmup = false;
+
+    common_params params_decoder(params); // duplicate the params
+    string_replace_all(params_decoder.model, "-backbone", "-decoder");
+
     common_init_result llama_backbone = common_init_from_params(params);
     llama_model * model_bb = llama_backbone.model.get();
     llama_context * ctx_bb = llama_backbone.context.get();
 
-    //common_init_result llama_decoder = common_init_from_params(params_decoder);
-    //llama_model * model_dc = llama_decoder.model.get();
-    //llama_context * ctx_dc = llama_decoder.context.get();
+    common_init_result llama_decoder = common_init_from_params(params_decoder);
+    llama_model * model_dc = llama_decoder.model.get();
+    llama_context * ctx_dc = llama_decoder.context.get();
 
     if (model_bb == nullptr || ctx_bb == nullptr) {
         return ENOENT;
     }
 
+    if (model_dc == nullptr || ctx_dc == nullptr) {
+        return ENOENT;
+    }
+
     const llama_vocab * vocab = llama_model_get_vocab(model_bb);
     llama_tokens prompt_tokens = common_tokenize(vocab, params.prompt, false, true);
     prompt_tokens.insert(prompt_tokens.begin(), llama_vocab_bos(vocab));
@@ -93,27 +95,92 @@ int main(int argc, char ** argv) {
     }
     printf("\n");
 
+    llama_pos n_past_bb = 0;
     llama_batch batch = llama_batch_init(params.n_batch, 0, 1);
+    common_batch_clear(batch);
     for (size_t i = 0; i < prompt_tokens.size(); ++i) {
-        common_batch_add(batch, prompt_tokens[i], i, { 0 }, false);
+        common_batch_add(batch, prompt_tokens[i], n_past_bb++, { 0 }, false);
     }
     batch.logits[batch.n_tokens - 1] = true;
 
-    if (llama_decode(ctx_bb, batch) != 0) {
-        LOG_ERR("%s: llama_decode() failed\n", __func__);
-        return 1;
-    }
+    std::vector<float> inp_past_embd(2048, 0.0f);
+    llama_batch batch_past_embd = llama_batch_init(1, inp_past_embd.size(), 1);
 
-    //auto vocab_dc = llama_model_get_vocab(model_dc);
-    auto logits = llama_get_logits_ith(ctx_bb, batch.n_tokens - 1);
-    //printf("next tok: %d\n", sample_greedy(logits, llama_vocab_n_tokens(vocab_dc)));
-    for (size_t i = 0; i < 10; ++i) {
-        printf("%4.2f, ", logits[i]);
-    }
-    printf("next tok: %d\n", sample_greedy(logits, 65632));
+    for (int k = 0; k < 4; ++k) {
+        if (llama_decode(ctx_bb, k == 0 ? batch : batch_past_embd) != 0) {
+            LOG_ERR("%s: llama_decode() failed\n", __func__);
+            return 1;
+        }
+
+        auto vocab_dc = llama_model_get_vocab(model_dc);
+        auto logits = llama_get_logits_ith(ctx_bb, k == 0 ? (batch.n_tokens - 1) : 0);
+        // for (size_t i = 0; i < 10; ++i) {
+        //     printf("%4.2f, ", logits[i]);
+        // }
+        // printf("\n");
+
+        llama_token latent_token = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc));
+        // printf("latent_token: %d\n", latent_token);
+        printf("%5d, ", latent_token);
+
+        // for (size_t i = 0; i < 10; ++i) {
+        //     printf("%4.2f, ", embd[i]);
+        // }
+        // printf("\n");
+
+
+
+        // decode
+        prompt_tokens.clear();
+        prompt_tokens.push_back(latent_token);
+        inp_past_embd = std::vector<float>(inp_past_embd.size(), 0.0f);
+        {
+            llama_kv_self_clear(ctx_dc);
+            llama_batch batch_embd = llama_batch_init(1, embd.size(), 1);
+            llama_batch batch_token = llama_batch_init(1, 0, 1);
+            {
+                batch_embd.n_tokens = 1;
+                batch_embd.pos[0] = 0;
+                batch_embd.seq_id[0][0] = 0;
+                batch_embd.n_seq_id[0] = 1;
+                batch_embd.logits[0] = false;
+                memcpy(batch_embd.embd, embd.data(), embd.size() * sizeof(float));
+            }
+            llama_decode(ctx_dc, batch_embd);
+
+            llama_token audio_token = latent_token;
+            for (int i = 0; i < 31; ++i) {
+                common_batch_clear(batch_token);
+                // encoder vocab is further divided into 32 codebooks, each with 2051 entries
+                llama_token inp_tok = audio_token + 2051*i;
+                common_batch_add(batch_token, inp_tok, i+1, { 0 }, true);
+                llama_decode(ctx_dc, batch_token);
+                auto logits = llama_get_logits_ith(ctx_dc, 0);
+                audio_token = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc));
+                printf("%d,", audio_token);
+                prompt_tokens.push_back(audio_token);
+
+                GGML_ASSERT(inp_past_embd.size() == embd.size());
+                for (size_t i = 0; i < inp_past_embd.size(); ++i) {
+                    inp_past_embd[i] += embd[i];
+                }
+            }
+            printf("\n");
+
+            llama_batch_free(batch_embd);
+            llama_batch_free(batch_token);
+        }
 
-    for (size_t i = 0; i < 10; ++i) {
-        printf("%4.2f, ", embd[i]);
+        // prepare for the next iteration
+        {
+            batch_past_embd.n_tokens = 1;
+            batch_past_embd.pos[0] = n_past_bb;
+            batch_past_embd.seq_id[0][0] = 0;
+            batch_past_embd.n_seq_id[0] = 1;
+            batch_past_embd.logits[0] = true;
+            memcpy(batch_past_embd.embd, inp_past_embd.data(), inp_past_embd.size() * sizeof(float));
+        }
+        n_past_bb++;
     }
 
     return 0;

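One detail worth spelling out from the inner loop: the decoder treats the 32 codebooks as one flat vocabulary of 32 × 2051 = 65632 entries (the same constant that appears in the removed single-shot decode path), and `inp_tok = audio_token + 2051*i` picks the slice belonging to codebook i. A small Python sketch of that mapping, with illustrative helper names:

    N_CODEBOOKS = 32
    CODEBOOK_SIZE = 2051  # entries per codebook, matching the converter's audio_head shape

    def to_flat_token(codebook: int, audio_token: int) -> int:
        # mirrors `llama_token inp_tok = audio_token + 2051*i;` in tts-csm.cpp
        assert 0 <= codebook < N_CODEBOOKS and 0 <= audio_token < CODEBOOK_SIZE
        return audio_token + CODEBOOK_SIZE * codebook

    def from_flat_token(tok: int) -> tuple[int, int]:
        # inverse mapping: recover (codebook, audio_token) from a flat decoder-vocab id
        return tok // CODEBOOK_SIZE, tok % CODEBOOK_SIZE

    assert N_CODEBOOKS * CODEBOOK_SIZE == 65632   # total decoder vocab size
    assert from_flat_token(to_flat_token(3, 100)) == (3, 100)
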
src/llama-arch.cpp

Lines changed: 31 additions & 2 deletions
@@ -6,6 +6,7 @@
 
 static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_LLAMA, "llama" },
+    { LLM_ARCH_LLAMA_CSM, "llama-csm" },
     { LLM_ARCH_DECI, "deci" },
     { LLM_ARCH_FALCON, "falcon" },
     { LLM_ARCH_GROK, "grok" },
@@ -229,9 +230,36 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
             { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
             { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+        },
+    },
+    {
+        LLM_ARCH_LLAMA_CSM, // like LLM_ARCH_LLAMA, but with extra tensors for Sesame CSM
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
+            { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
+            { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_CSM_AUDIO_EMBD, "audio_embd" },
             { LLM_TENSOR_CSM_CBOOK_OUTPUT, "codebook0_head" },
             { LLM_TENSOR_CSM_AUDIO_OUTPUT, "audio_head" },
-            { LLM_TENSOR_CSM_INP_PROJ, "inp_proj" },
+            { LLM_TENSOR_CSM_PROJ, "csm_proj" },
         },
     },
     {
@@ -1573,9 +1601,10 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_CSM_AUDIO_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
     {LLM_TENSOR_CSM_CBOOK_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_CSM_AUDIO_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_CSM_INP_PROJ, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CSM_PROJ, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL_MAT}},
 };
 
 LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}

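For context on the csm_proj tensor registered here as an input-layer MUL_MAT: per the comment removed from the converter, it projects the backbone's hidden state from n_embd_backbone = 2048 down to n_embd_decoder = 1024 before the decoder consumes it. A rough Python sketch of that projection, assuming the usual nn.Linear weight layout of (out_features, in_features) and using random stand-in values:

    import torch

    n_embd_backbone = 2048
    n_embd_decoder = 1024

    # hypothetical stand-ins for projection.weight (csm_proj) and one backbone hidden state
    csm_proj = torch.randn(n_embd_decoder, n_embd_backbone)
    h_backbone = torch.randn(n_embd_backbone)

    # what the MUL_MAT on LLM_TENSOR_CSM_PROJ conceptually computes
    h_decoder = csm_proj @ h_backbone
    assert h_decoder.shape == (n_embd_decoder,)
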
src/llama-arch.h

Lines changed: 3 additions & 1 deletion
@@ -10,6 +10,7 @@
 
 enum llm_arch {
     LLM_ARCH_LLAMA,
+    LLM_ARCH_LLAMA_CSM,
     LLM_ARCH_DECI,
     LLM_ARCH_FALCON,
     LLM_ARCH_BAICHUAN,
@@ -347,9 +348,10 @@ enum llm_tensor {
     LLM_TENSOR_POS_NET_ATTN_K,
     LLM_TENSOR_POS_NET_ATTN_V,
     LLM_TENSOR_POS_NET_ATTN_OUT,
+    LLM_TENSOR_CSM_AUDIO_EMBD,
     LLM_TENSOR_CSM_CBOOK_OUTPUT,
     LLM_TENSOR_CSM_AUDIO_OUTPUT,
-    LLM_TENSOR_CSM_INP_PROJ,
+    LLM_TENSOR_CSM_PROJ,
 };
 
 enum llm_tensor_layer {
