Commit 6dca237

tts : add sesame csm
1 parent 3714c3e commit 6dca237

File tree

7 files changed: +513, -23 lines


examples/tts/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
@@ -3,3 +3,9 @@ add_executable(${TARGET} tts.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
+
+set(TARGET llama-tts-csm)
+add_executable(${TARGET} tts-csm.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
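
This registers a second TTS example binary, llama-tts-csm, built from tts-csm.cpp next to the existing llama-tts target. A minimal build sketch, assuming the standard llama.cpp CMake workflow from the repository root (the target name is taken from the diff above):

    cmake -B build
    cmake --build build --config Release --target llama-tts-csm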
Lines changed: 330 additions & 0 deletions
@@ -0,0 +1,330 @@
import os
import sys
import argparse
import logging
import torch
from safetensors.torch import load_file
from typing import Union, Any, Dict
from pathlib import Path
from torch import Tensor
from huggingface_hub import hf_hub_download

cur_path = sys.path
if 'NO_LOCAL_GGUF' not in os.environ:
    sys.path.insert(1, str(Path(__file__).parent.parent.parent / 'gguf-py'))
import gguf

sys.path = cur_path

logger = logging.getLogger("csm")


# This converts one safetensors file directly into 2 GGUFs
# It is easier to do it this way, rather than convert to 2 smaller HF models and then convert those to GGUF
# This is because the Sesame model does not have a built-in tokenizer

def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
    field = reader.get_field(key)
    return field.contents() if field else None

# copied from https://github.com/SesameAILabs/csm/blob/main/models.py
class Llama_3_2_1B:
    vocab_size=128_256
    num_layers=16
    num_heads=32
    num_kv_heads=8
    embed_dim=2048
    max_seq_len=2048
    intermediate_dim=8192
    attn_dropout=0.0
    norm_eps=1e-5
    rope_base=500_000
    scale_factor=32

    def write_gguf_metadata(self, fout: gguf.GGUFWriter, fvocab: gguf.GGUFReader):
        arch = get_field_data(fvocab, gguf.Keys.General.ARCHITECTURE)
        assert arch == "llama"
        fout.add_type("model")
        fout.add_block_count(self.num_layers)
        fout.add_context_length(self.max_seq_len)
        fout.add_feed_forward_length(self.intermediate_dim)
        fout.add_embedding_length(self.embed_dim)
        # attn
        fout.add_head_count(self.num_heads)
        fout.add_head_count_kv(self.num_kv_heads)
        fout.add_rope_freq_base(self.rope_base)
        # fout.add_rope_scaling_factor(self.scale_factor) # breaks if this is added
        fout.add_rope_dimension_count(self.embed_dim // self.num_heads)
        fout.add_layer_norm_rms_eps(self.norm_eps)
        fout.add_key_length(self.embed_dim // self.num_heads)
        fout.add_value_length(self.embed_dim // self.num_heads)
        # vocab
        fout.add_vocab_size(self.vocab_size)
        fout.add_tokenizer_model(get_field_data(fvocab, gguf.Keys.Tokenizer.MODEL))
        fout.add_tokenizer_pre(get_field_data(fvocab, gguf.Keys.Tokenizer.PRE))
        fout.add_token_list(get_field_data(fvocab, gguf.Keys.Tokenizer.LIST)[:self.vocab_size])
        fout.add_token_types(get_field_data(fvocab, gguf.Keys.Tokenizer.TOKEN_TYPE)[:self.vocab_size])
        fout.add_token_merges(get_field_data(fvocab, gguf.Keys.Tokenizer.MERGES))
        fout.add_bos_token_id(get_field_data(fvocab, gguf.Keys.Tokenizer.BOS_ID))
        fout.add_eos_token_id(get_field_data(fvocab, gguf.Keys.Tokenizer.EOS_ID))

class Llama_3_2_100M(Llama_3_2_1B):
    vocab_size=65_632 #128_256
    num_layers=4
    num_heads=8
    num_kv_heads=2
    embed_dim=1024
    max_seq_len=2048
    intermediate_dim=8192
    attn_dropout=0.0
    norm_eps=1e-5
    rope_base=500_000
    scale_factor=32

class CSMModelConverter:
    state_dict: Dict[str, Tensor]
    gguf_writer_backbone: gguf.GGUFWriter
    gguf_writer_decoder: gguf.GGUFWriter
    gguf_reader_vocab: gguf.GGUFReader
    fname_out: Path
    ftype: gguf.LlamaFileType

    projection_tensor: Tensor # projecting from n_embd_backbone (2048) to n_embd_decoder (1024)

    def __init__(self,
                 safetensors_path: Union[Path, str],
                 path_to_vocab_gguf: Path,
                 fname_out: Path,
                 ftype: gguf.LlamaFileType,
                 is_big_endian: bool,):

        if "<component>" not in fname_out.name:
            raise ValueError("Output file name must contain '<component>' placeholder, for example: 'sesame-csm-<component>.gguf'")

        self.state_dict = load_file(safetensors_path, device="cpu")
        self.fname_out = fname_out
        self.ftype = ftype
        self.gguf_reader_vocab = gguf.GGUFReader(path_to_vocab_gguf)
        endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE

        # backbone
        self.gguf_writer_backbone = gguf.GGUFWriter(
            path=None,
            arch="llama",
            endianess=endianess)

        # decoder
        self.gguf_writer_decoder = gguf.GGUFWriter(
            path=None,
            arch="llama",
            endianess=endianess)

        Llama_3_2_1B().write_gguf_metadata(self.gguf_writer_backbone, self.gguf_reader_vocab)
        Llama_3_2_100M().write_gguf_metadata(self.gguf_writer_decoder, self.gguf_reader_vocab)

        # get projection tensor
        for name, data_torch in self.state_dict.items():
            if name == "projection.weight":
                self.projection_tensor = data_torch
                break

        # load tensors
        for component in ("backbone", "decoder"):
            print()
            print(f"Converting {component}...")
            print()
            for name, data_torch in self.state_dict.items():
                # convert any unsupported data types to float32
                old_dtype = data_torch.dtype
                if data_torch.dtype not in (torch.float16, torch.float32):
                    data_torch = data_torch.to(torch.float32)
                self.add_tensor(name, data_torch, old_dtype, component)

    def add_tensor(self, name: str, data_torch: Tensor, old_dtype: torch.dtype, component: str):
        is_1d = len(data_torch.shape) == 1
        #is_embd = "_embeddings" in name
        can_quantize = not is_1d #and not is_embd
        data_qtype = gguf.GGMLQuantizationType.F32

        is_backbone = False
        is_decoder = False

        def rename_transformer(name: str) -> str:
            # transformer
            name = name.replace(".scale", ".weight")
            name = name.replace("attn.k_proj", "attn_k")
            name = name.replace("attn.q_proj", "attn_q")
            name = name.replace("attn.v_proj", "attn_v")
            name = name.replace("attn.output_proj", "attn_output")
            name = name.replace("sa_norm", "attn_norm")
            name = name.replace("mlp.w1", "ffn_gate")
            name = name.replace("mlp.w2", "ffn_down")
            name = name.replace("mlp.w3", "ffn_up")
            name = name.replace("mlp_norm", "ffn_norm")
            return name

        if "audio_embeddings." in name:
            is_decoder = True
            if component == "decoder":
                name = name.replace("audio_embeddings.", "token_embd.")
                data_torch = torch.mm(data_torch, self.projection_tensor.T)
                print("Applied projection to audio_embeddings", data_torch.shape)

        elif "text_embeddings." in name:
            is_backbone = True
            name = name.replace("text_embeddings.", "token_embd.")

        elif "backbone." in name or "codebook0_head." in name:
            is_backbone = True
            name = name.replace("backbone.layers.", "blk.")
            name = name.replace("backbone.norm.scale", "output_norm.weight")
            name = rename_transformer(name)

        elif "decoder." in name:
            is_decoder = True
            name = name.replace("decoder.layers.", "blk.")
            name = name.replace("decoder.norm.scale", "output_norm.weight")
            name = rename_transformer(name)

        elif name == "audio_head":
            is_decoder = True
            name = "audio_head.weight"

        elif name == "projection.weight":
            is_decoder = True
            name = "inp_proj.weight"
            self.projection_tensor = data_torch

        if can_quantize:
            if self.ftype == gguf.LlamaFileType.ALL_F32:
                data_qtype = gguf.GGMLQuantizationType.F32
            elif self.ftype == gguf.LlamaFileType.MOSTLY_F16:
                data_qtype = gguf.GGMLQuantizationType.F16
            elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
                data_qtype = gguf.GGMLQuantizationType.BF16
            elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
                data_qtype = gguf.GGMLQuantizationType.Q8_0
            else:
                raise ValueError(f"Unsupported file type: {self.ftype}")

        data = data_torch.numpy()

        try:
            data = gguf.quants.quantize(data, data_qtype)
        except Exception as e:
            logger.error(f"Error quantizing tensor '{name}': {e}, fallback to F16")
            data_qtype = gguf.GGMLQuantizationType.F16
            data = gguf.quants.quantize(data, data_qtype)

        if (is_backbone and component == "backbone") or (is_decoder and component == "decoder"):
            # reverse shape to make it similar to the internal ggml dimension order
            shape_str = f"{{{', '.join(str(n) for n in reversed(data_torch.shape))}}}"
            logger.info(f"{f'%-32s' % f'{name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")

            if component == "backbone":
                self.gguf_writer_backbone.add_tensor(name, data, raw_dtype=data_qtype)
            elif component == "decoder":
                self.gguf_writer_decoder.add_tensor(name, data, raw_dtype=data_qtype)

    def write(self):
        self._write_single(self.gguf_writer_backbone, "backbone")
        self._write_single(self.gguf_writer_decoder, "decoder")

    def _write_single(self, gguf_writer: gguf.GGUFWriter, component: str):
        output_path = str(self.fname_out).replace("<component>", component)
        gguf_writer.write_header_to_file(path=Path(output_path))
        gguf_writer.write_kv_data_to_file()
        gguf_writer.write_tensors_to_file(progress=True)
        gguf_writer.close()

    @staticmethod
    def undo_permute(weights: Tensor, n_head: int, n_head_kv: int):
        if n_head_kv is not None and n_head != n_head_kv:
            n_head = n_head_kv
        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
                .swapaxes(1, 2)
                .reshape(weights.shape))

def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Convert Sesame model to GGUFs (multiple files)",)
    parser.add_argument(
        "--outfile", type=Path, default="sesame-csm-<component>.gguf",
        help="path to write to; the '<component>' placeholder is required and will be replaced with 'backbone' and 'decoder'",
    )
    parser.add_argument(
        "--vocab", type=Path, default="models/ggml-vocab-llama-bpe.gguf",
        help="path to vocab GGUF",
    )
    parser.add_argument(
        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16",
        help="output format",
    )
    parser.add_argument(
        "--bigendian", action="store_true",
        help="model is executed on big endian machine",
    )
    parser.add_argument(
        "model", type=Path,
        help="path to safetensors or model ID containing model file (if a model ID is specified, download it from the Hugging Face hub)",
        nargs="?",
        default="sesame/csm-1b:model.safetensors",
    )
    parser.add_argument(
        "--verbose", action="store_true",
        help="increase output verbosity",
    )

    args = parser.parse_args()
    if args.model is None:
        parser.error("the following arguments are required: model")
    return args

def main() -> None:
    args = parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)

    dir_model = args.model
    path_vocab = args.vocab

    dir_parts = str(dir_model).split(":")
    if len(dir_parts) == 2:
        try:
            dir_model = Path(hf_hub_download(dir_parts[0], dir_parts[1]))
        except Exception as e:
            print("Error downloading model from Hugging Face hub:", e)
            print()
            print("Please make sure you have access to the model")
            print("Hint: you may need to set HF_TOKEN by running: huggingface-cli login")

    if not path_vocab.exists():
        raise FileNotFoundError(f"Vocab file not found: {path_vocab} ; Hint: download it from https://github.com/ggml-org/llama.cpp/blob/master/models/ggml-vocab-llama-bpe.gguf")

    ftype_map: dict[str, gguf.LlamaFileType] = {
        "f32": gguf.LlamaFileType.ALL_F32,
        "f16": gguf.LlamaFileType.MOSTLY_F16,
        "bf16": gguf.LlamaFileType.MOSTLY_BF16,
        "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
    }

    logger.info(f"Loading model: {dir_model}")

    with torch.inference_mode():
        converter = CSMModelConverter(
            safetensors_path=dir_model,
            fname_out=args.outfile,
            path_to_vocab_gguf=path_vocab,
            ftype=ftype_map[args.outtype],
            is_big_endian=args.bigendian,
        )
        converter.write()


if __name__ == '__main__':
    main()
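
For reference, a hedged sketch of how the conversion script would be invoked (the script's file name is not shown in this view, so convert-sesame-csm.py below is only a placeholder; the arguments and defaults are taken from parse_args above):

    # placeholder script name; needs Hugging Face access to sesame/csm-1b and the llama BPE vocab GGUF
    python convert-sesame-csm.py sesame/csm-1b:model.safetensors \
        --vocab models/ggml-vocab-llama-bpe.gguf \
        --outtype f16 \
        --outfile "sesame-csm-<component>.gguf"

Given the '<component>' placeholder handling in _write_single, this should produce one GGUF per component: sesame-csm-backbone.gguf and sesame-csm-decoder.gguf.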
