
Commit eb5185e

tts : add SNAC decoder architecture support for Orpheus TTS
- Add LLM_ARCH_SNAC_DEC architecture enum and name mapping
- Define 27 SNAC-specific tensor types for decoder and quantizer
- Add tensor name mappings in llama-arch.cpp
- Add SNAC_DEC to gguf constants with tensor enums and mappings
- Implement SnacDecModel class for model conversion
- Add comprehensive SNAC implementation documentation

This provides the foundational architecture support for SNAC audio codec. Remaining work includes model loading, forward pass, and TTS tool integration.

Addresses issue #208

Co-Authored-By: Jake Cosme <[email protected]>
1 parent 275947c commit eb5185e

5 files changed: +321 -0 lines changed


convert_hf_to_gguf.py

Lines changed: 26 additions & 0 deletions
@@ -3651,6 +3651,32 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_causal_attention(False)
 
 
+class SnacDecModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.SNAC_DEC
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid
+
+        if name.endswith("_g") or name.endswith("_v"):
+            logger.debug(f"Skipping weight_norm parameter {name!r}")
+            return []
+
+        logger.info(f"{self.map_tensor_name(name)} -> {data_torch.shape}")
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def set_vocab(self):
+        self._set_vocab_none()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_vocab_size(self.hparams.get("codebook_size", 4096))
+        self.gguf_writer.add_block_count(len(self.hparams.get("decoder_rates", [7, 7, 3, 3])))
+        self.gguf_writer.add_embedding_length(self.hparams.get("latent_dim", 1536))
+        self.gguf_writer.add_feed_forward_length(self.hparams.get("decoder_dim", 1536))
+        self.gguf_writer.add_causal_attention(False)
+
+
 @ModelBase.register("Qwen2MoeForCausalLM")
 class Qwen2MoeModel(TextModel):
     model_arch = gguf.MODEL_ARCH.QWEN2MOE

docs/SNAC_IMPLEMENTATION.md

Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
# SNAC Decoder Implementation for Orpheus TTS

## Overview

This document describes the implementation of SNAC (Multi-Scale Neural Audio Codec) decoder support in llama.cpp for Orpheus TTS models.

## Current Status

### ✅ Completed

1. **Architecture Infrastructure**
   - Added `LLM_ARCH_SNAC_DEC` architecture enum
   - Registered "snac-dec" architecture name
   - Defined 27 SNAC-specific tensor types
   - Added tensor name mappings for decoder and quantizer components

2. **GGUF Constants**
   - Added `MODEL_ARCH.SNAC_DEC` to gguf constants
   - Defined tensor enums for all SNAC components
   - Added tensor name format strings

3. **Model Conversion**
   - Implemented `SnacDecModel` class in `convert_hf_to_gguf.py`
   - Handles weight_norm parameters (skips `_g` and `_v` suffixes)
   - Configures SNAC-specific hyperparameters

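The converter skips tensors whose names end in `_g` or `_v`, which are the two halves of a PyTorch `weight_norm` reparameterization, so a later stage has to obtain the effective weight some other way. A minimal sketch of the standard fusion, `w = g * v / ||v||` with the norm taken per output channel, is shown below. The helper name `fuse_weight_norm` and the flattened list layout are illustrative assumptions and are not part of this commit.

```python
import math


def fuse_weight_norm(g, v):
    """Return effective weights w[o] = g[o] * v[o] / ||v[o]||.

    g: one gain per output channel.
    v: one flattened direction vector per output channel.
    Hypothetical helper; assumes the norm is taken per output channel.
    """
    if len(g) != len(v):
        raise ValueError("g and v must have one entry per output channel")
    fused = []
    for gain, direction in zip(g, v):
        norm = math.sqrt(sum(x * x for x in direction))
        if norm == 0.0:
            raise ValueError("direction vector has zero norm")
        fused.append([gain * x / norm for x in direction])
    return fused


# Two output channels, two weights each.
w = fuse_weight_norm([2.0, 1.0], [[3.0, 4.0], [0.0, 5.0]])
```

Fusing at conversion time would keep the GGUF file free of reparameterization details, at the cost of no longer being able to recover `g` and `v` separately.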
### 🚧 In Progress / TODO

1. **Model Loading (llama-model.cpp)**
   - Need to implement SNAC decoder model loading
   - Load decoder convolution layers
   - Load vector quantizer components (in_proj, out_proj, codebook)
   - Load attention layers if present
   - Handle Snake activation parameters

2. **Forward Pass Implementation (llama.cpp)**
   - Implement SNAC decoder forward pass
   - Vector quantization decoding (from_codes)
   - Decoder blocks with:
     - Transposed convolutions (upsampling)
     - Residual units with dilated convolutions
     - Snake activation function
     - Local multi-head attention (if present)
   - Output convolution and tanh activation

3. **TTS Tool Integration (tools/tts/tts.cpp)**
   - Add SNAC decoder option to TTS tool
   - Support for multi-scale code input
   - Audio generation from hierarchical codes
   - Integration with Orpheus TTS models

4. **Testing**
   - Download and convert SNAC models from HuggingFace
   - Test with Orpheus TTS models
   - Validate audio quality
   - Performance benchmarking

## SNAC Architecture

### Components

1. **Encoder** (not needed for TTS inference; it only turns audio into codes)
   - Input convolution
   - Encoder blocks with strided convolutions
   - Local attention (optional)
   - Output convolution

2. **Vector Quantizer** (needed for decoding)
   - 4 quantization levels with different strides [8, 4, 2, 1]
   - Each level has:
     - `in_proj`: Projects latent to codebook dimension
     - `codebook`: Embedding table (4096 x 8)
     - `out_proj`: Projects back to latent dimension
   - Residual quantization across levels

3. **Decoder** (main component needed)
   - Input convolution (or direct from quantizer output)
   - Local attention (optional)
   - Decoder blocks (4 blocks for standard config):
     - Transposed convolution for upsampling
     - 3 residual units with dilations [1, 3, 9]
     - Snake activation
   - Output convolution + tanh

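The quantizer description above suggests how decoding from codes could work: look up each level's codebook entry, project it back to the latent dimension, repeat it along time by that level's stride, and sum across levels. A minimal pure-Python sketch under those assumptions follows. The function names, the bias-free dense `out_proj`, and the toy dimensions are illustrative and are not the commit's implementation.

```python
def project(vec, weight):
    """Apply a dense projection; weight has shape [out_dim][in_dim]."""
    return [sum(w * x for w, x in zip(row, vec)) for row in weight]


def decode_level(codes, codebook, out_proj, stride):
    """Look up codes, project to latent space, repeat each frame `stride` times."""
    frames = []
    for code in codes:
        latent = project(codebook[code], out_proj)
        frames.extend([latent] * stride)
    return frames


def from_codes(levels):
    """Sum the time-expanded contributions of all quantizer levels.

    levels: list of (codes, codebook, out_proj, stride) tuples.
    All levels must expand to the same number of frames.
    """
    total = None
    for codes, codebook, out_proj, stride in levels:
        frames = decode_level(codes, codebook, out_proj, stride)
        if total is None:
            total = [list(f) for f in frames]  # copy so sums do not alias
            continue
        if len(frames) != len(total):
            raise ValueError("levels expand to different lengths")
        for acc, frame in zip(total, frames):
            for i, value in enumerate(frame):
                acc[i] += value
    return total


# Toy example: 2 levels, codebook_dim 1, latent_dim 2.
codebook = [[1.0], [2.0]]
out_proj = [[1.0], [10.0]]
coarse = ([1], codebook, out_proj, 2)   # one code covers two frames
fine = ([0, 1], codebook, out_proj, 1)  # one code per frame
z = from_codes([coarse, fine])
```

In a real model each level would have its own codebook and projection; sharing them here only keeps the example short.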
### Snake Activation

Formula: `x + (1/alpha) * sin^2(alpha * x)`

Can be implemented using existing ggml operations:
```c
// x_scaled = x * alpha
// sin_x = sin(x_scaled)
// sin2_x = sin_x * sin_x
// result = x + sin2_x / alpha
```

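As a reference for checking a ggml implementation numerically, the same four steps can be written in plain Python. This is a sketch only: the per-channel `alphas` layout in `snake_channels` is an assumption about how `snake_alpha` is stored, and both function names are illustrative.

```python
import math


def snake(x, alpha):
    """Snake activation: x + (1/alpha) * sin^2(alpha * x), for alpha != 0."""
    s = math.sin(alpha * x)
    return x + (s * s) / alpha


def snake_channels(frames, alphas):
    """Apply Snake per channel: frames[t][c] uses alphas[c] (assumed layout)."""
    return [[snake(v, a) for v, a in zip(frame, alphas)] for frame in frames]


y = snake_channels([[0.0, math.pi / 4]], [1.0, 2.0])
```

An implementation also has to decide what to do when `alpha` is zero or very small; this sketch leaves that case to the caller.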
### Tensor Naming Convention

Decoder tensors:
- `decoder.conv_in` - Input convolution
- `decoder.attn_norm`, `decoder.attn_q/k/v/out` - Attention (if present)
- `decoder.block.{i}.conv_up` - Upsampling transposed conv
- `decoder.block.{i}.conv1/2/3` - Residual unit convolutions
- `decoder.block.{i}.snake_alpha` - Snake activation parameters
- `decoder.conv_out` - Output convolution

Quantizer tensors:
- `quantizer.{i}.in_proj` - Input projection for level i
- `quantizer.{i}.out_proj` - Output projection for level i
- `quantizer.{i}.codebook` - Codebook embeddings for level i

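To see which concrete names the per-block and per-level templates produce, they can be expanded with `str.format`, mirroring the `{bid}` placeholders used in `gguf-py/gguf/constants.py`. The dictionaries and the `expand_names` helper below are illustrative and do not exist in the codebase.

```python
# Templates copied from the naming convention above; {bid} is the block or level index.
DECODER_BLOCK_TEMPLATES = {
    "conv_up": "decoder.block.{bid}.conv_up",
    "conv1": "decoder.block.{bid}.conv1",
    "conv2": "decoder.block.{bid}.conv2",
    "conv3": "decoder.block.{bid}.conv3",
    "snake_alpha": "decoder.block.{bid}.snake_alpha",
}

QUANTIZER_TEMPLATES = {
    "in_proj": "quantizer.{bid}.in_proj",
    "out_proj": "quantizer.{bid}.out_proj",
    "codebook": "quantizer.{bid}.codebook",
}


def expand_names(templates, count):
    """Return every template formatted for indices 0..count-1, index-major."""
    return [t.format(bid=i) for i in range(count) for t in templates.values()]


quantizer_names = expand_names(QUANTIZER_TEMPLATES, 2)
decoder_names = expand_names(DECODER_BLOCK_TEMPLATES, 4)
```

A listing like this can be compared against the tensor names in a converted GGUF file to spot missing or misnamed tensors.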
## Model Conversion

### Converting SNAC Models

```bash
# Download SNAC model
git clone https://huggingface.co/hubertsiuzdak/snac_24khz

# Convert to GGUF
python convert_hf_to_gguf.py snac_24khz \
    --outfile snac-24khz-f16.gguf \
    --outtype f16
```

### Expected Hyperparameters

Example values in the shape of a SNAC `config.json` (check them against the `config.json` of the checkpoint being converted, since the published 24 kHz, 32 kHz, and 44 kHz models differ):
```json
{
  "sampling_rate": 24000,
  "encoder_dim": 64,
  "encoder_rates": [3, 3, 7, 7],
  "latent_dim": 1344,
  "decoder_dim": 1536,
  "decoder_rates": [7, 7, 3, 3],
  "attn_window_size": 32,
  "codebook_size": 4096,
  "codebook_dim": 8,
  "vq_strides": [8, 4, 2, 1]
}
```

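A few derived quantities follow directly from these fields and are useful when sizing buffers in a forward pass. The sketch below takes the listed values at face value; the function names are illustrative, and the numbers change with the checkpoint's actual `config.json`.

```python
import math

# Subset of the hyperparameters listed above.
config = {
    "sampling_rate": 24000,
    "decoder_rates": [7, 7, 3, 3],
    "codebook_size": 4096,
}


def upsampling_factor(cfg):
    """Audio samples produced per latent frame: product of the decoder rates."""
    return math.prod(cfg["decoder_rates"])


def latent_frame_rate(cfg):
    """Latent frames per second of audio at the decoder input."""
    return cfg["sampling_rate"] / upsampling_factor(cfg)


def bits_per_code(cfg):
    """Bits needed to index one codebook entry."""
    return math.log2(cfg["codebook_size"])


factor = upsampling_factor(config)
rate = latent_frame_rate(config)
bits = bits_per_code(config)
```

With these values each latent frame expands to 441 samples and each code carries 12 bits.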
## Integration with Orpheus TTS

Orpheus TTS uses a two-model architecture:
1. **Text-to-Codes Model**: LLM that generates hierarchical audio codes
2. **Codes-to-Speech Model**: SNAC decoder that converts codes to audio

Usage flow:
```
Text → Orpheus LLM → Multi-scale codes → SNAC Decoder → Audio waveform
```

## References

- SNAC Paper: https://arxiv.org/abs/2410.14411
- SNAC GitHub: https://github.com/hubertsiuzdak/snac
- Orpheus Models: https://huggingface.co/collections/canopylabs/orpheus-tts-67d9ea3f6c05a941c06ad9d2
- OuteTTS Reference: PR #10784 in llama.cpp

## Implementation Notes

### Key Differences from WavTokenizer

1. **Multi-scale Quantization**: SNAC uses 4 levels with different temporal resolutions
2. **Snake Activation**: Custom activation function (WavTokenizer uses standard activations)
3. **Simpler Architecture**: No PosNet or ConvNext blocks
4. **Hierarchical Codes**: Variable-length codes at different scales

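Because the levels run at different temporal resolutions, the code sequences handed to the decoder have different lengths, and a loader could validate them before decoding. The sketch below assumes every stride divides the largest one and that the level with the largest stride is the coarsest. Both helper names are hypothetical.

```python
def expected_lengths(coarse_frames, strides):
    """Codes per level when the coarsest level has `coarse_frames` codes."""
    coarsest = max(strides)
    return [coarse_frames * (coarsest // s) for s in strides]


def check_codes(levels, strides):
    """Raise ValueError unless each level has the length its stride implies."""
    if len(levels) != len(strides):
        raise ValueError("one code sequence is required per stride")
    coarse_idx = strides.index(max(strides))
    expected = expected_lengths(len(levels[coarse_idx]), strides)
    for i, (codes, want) in enumerate(zip(levels, expected)):
        if len(codes) != want:
            raise ValueError(f"level {i}: got {len(codes)} codes, expected {want}")
    return expected


strides = [8, 4, 2, 1]
levels = [[0] * 2, [0] * 4, [0] * 8, [0] * 16]
lengths = check_codes(levels, strides)
```

Rejecting mismatched lengths early gives a clearer error than a shape failure deep inside the decoder graph.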
### Performance Considerations

- SNAC is designed for low bitrate (0.98-2.6 kbps)
- Decoder is relatively lightweight
- Main computation in transposed convolutions and residual blocks
- Attention is optional and can be disabled for faster inference

## Next Steps

1. Implement model loading in `llama-model.cpp`
2. Implement forward pass in `llama.cpp`
3. Add SNAC support to TTS tool
4. Test with Orpheus models
5. Add documentation and examples
6. Performance optimization

gguf-py/gguf/constants.py

Lines changed: 61 additions & 0 deletions
@@ -400,6 +400,7 @@ class MODEL_ARCH(IntEnum):
     GRANITE_HYBRID = auto()
     CHAMELEON = auto()
     WAVTOKENIZER_DEC = auto()
+    SNAC_DEC = auto()
     PLM = auto()
     BAILINGMOE = auto()
     BAILINGMOE2 = auto()
@@ -600,6 +601,33 @@ class MODEL_TENSOR(IntEnum):
     SHORTCONV_CONV = auto()
     SHORTCONV_INPROJ = auto()
     SHORTCONV_OUTPROJ = auto()
+    SNAC_ENC_CONV_IN = auto()
+    SNAC_ENC_BLK_CONV1 = auto()
+    SNAC_ENC_BLK_CONV2 = auto()
+    SNAC_ENC_BLK_CONV3 = auto()
+    SNAC_ENC_BLK_CONV_DS = auto()
+    SNAC_ENC_BLK_SNAKE_ALPHA = auto()
+    SNAC_ENC_CONV_OUT = auto()
+    SNAC_ENC_ATTN_NORM = auto()
+    SNAC_ENC_ATTN_Q = auto()
+    SNAC_ENC_ATTN_K = auto()
+    SNAC_ENC_ATTN_V = auto()
+    SNAC_ENC_ATTN_OUT = auto()
+    SNAC_VQ_IN_PROJ = auto()
+    SNAC_VQ_OUT_PROJ = auto()
+    SNAC_VQ_CODEBOOK = auto()
+    SNAC_DEC_CONV_IN = auto()
+    SNAC_DEC_ATTN_NORM = auto()
+    SNAC_DEC_ATTN_Q = auto()
+    SNAC_DEC_ATTN_K = auto()
+    SNAC_DEC_ATTN_V = auto()
+    SNAC_DEC_ATTN_OUT = auto()
+    SNAC_DEC_BLK_CONV_UP = auto()
+    SNAC_DEC_BLK_CONV1 = auto()
+    SNAC_DEC_BLK_CONV2 = auto()
+    SNAC_DEC_BLK_CONV3 = auto()
+    SNAC_DEC_BLK_SNAKE_ALPHA = auto()
+    SNAC_DEC_CONV_OUT = auto()
     # vision
     V_MMPROJ = auto()
     V_MMPROJ_FC = auto()
@@ -745,6 +773,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.GRANITE_HYBRID: "granitehybrid",
     MODEL_ARCH.CHAMELEON: "chameleon",
     MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
+    MODEL_ARCH.SNAC_DEC: "snac-dec",
     MODEL_ARCH.PLM: "plm",
     MODEL_ARCH.BAILINGMOE: "bailingmoe",
     MODEL_ARCH.BAILINGMOE2: "bailingmoe2",
@@ -946,6 +975,21 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.SHORTCONV_CONV: "blk.{bid}.shortconv.conv",
     MODEL_TENSOR.SHORTCONV_INPROJ: "blk.{bid}.shortconv.in_proj",
     MODEL_TENSOR.SHORTCONV_OUTPROJ: "blk.{bid}.shortconv.out_proj",
+    MODEL_TENSOR.SNAC_DEC_CONV_IN: "decoder.conv_in",
+    MODEL_TENSOR.SNAC_DEC_ATTN_NORM: "decoder.attn_norm",
+    MODEL_TENSOR.SNAC_DEC_ATTN_Q: "decoder.attn_q",
+    MODEL_TENSOR.SNAC_DEC_ATTN_K: "decoder.attn_k",
+    MODEL_TENSOR.SNAC_DEC_ATTN_V: "decoder.attn_v",
+    MODEL_TENSOR.SNAC_DEC_ATTN_OUT: "decoder.attn_out",
+    MODEL_TENSOR.SNAC_DEC_BLK_CONV_UP: "decoder.block.{bid}.conv_up",
+    MODEL_TENSOR.SNAC_DEC_BLK_CONV1: "decoder.block.{bid}.conv1",
+    MODEL_TENSOR.SNAC_DEC_BLK_CONV2: "decoder.block.{bid}.conv2",
+    MODEL_TENSOR.SNAC_DEC_BLK_CONV3: "decoder.block.{bid}.conv3",
+    MODEL_TENSOR.SNAC_DEC_BLK_SNAKE_ALPHA: "decoder.block.{bid}.snake_alpha",
+    MODEL_TENSOR.SNAC_DEC_CONV_OUT: "decoder.conv_out",
+    MODEL_TENSOR.SNAC_VQ_IN_PROJ: "quantizer.{bid}.in_proj",
+    MODEL_TENSOR.SNAC_VQ_OUT_PROJ: "quantizer.{bid}.out_proj",
+    MODEL_TENSOR.SNAC_VQ_CODEBOOK: "quantizer.{bid}.codebook",
     # vision
     MODEL_TENSOR.V_MMPROJ: "mm.{bid}",
     MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc",
@@ -2518,6 +2562,23 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.POSNET_ATTN_V,
         MODEL_TENSOR.POSNET_ATTN_OUT,
     ],
+    MODEL_ARCH.SNAC_DEC: [
+        MODEL_TENSOR.SNAC_DEC_CONV_IN,
+        MODEL_TENSOR.SNAC_DEC_ATTN_NORM,
+        MODEL_TENSOR.SNAC_DEC_ATTN_Q,
+        MODEL_TENSOR.SNAC_DEC_ATTN_K,
+        MODEL_TENSOR.SNAC_DEC_ATTN_V,
+        MODEL_TENSOR.SNAC_DEC_ATTN_OUT,
+        MODEL_TENSOR.SNAC_DEC_BLK_CONV_UP,
+        MODEL_TENSOR.SNAC_DEC_BLK_CONV1,
+        MODEL_TENSOR.SNAC_DEC_BLK_CONV2,
+        MODEL_TENSOR.SNAC_DEC_BLK_CONV3,
+        MODEL_TENSOR.SNAC_DEC_BLK_SNAKE_ALPHA,
+        MODEL_TENSOR.SNAC_DEC_CONV_OUT,
+        MODEL_TENSOR.SNAC_VQ_IN_PROJ,
+        MODEL_TENSOR.SNAC_VQ_OUT_PROJ,
+        MODEL_TENSOR.SNAC_VQ_CODEBOOK,
+    ],
     MODEL_ARCH.BAILINGMOE: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,

src/llama-arch.cpp

Lines changed: 21 additions & 0 deletions
@@ -83,6 +83,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_GRANITE_HYBRID, "granitehybrid" },
     { LLM_ARCH_CHAMELEON, "chameleon" },
     { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
+    { LLM_ARCH_SNAC_DEC, "snac-dec" },
     { LLM_ARCH_PLM, "plm" },
     { LLM_ARCH_BAILINGMOE, "bailingmoe" },
     { LLM_ARCH_BAILINGMOE2, "bailingmoe2" },
@@ -1926,6 +1927,26 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" },
         },
     },
+    {
+        LLM_ARCH_SNAC_DEC,
+        {
+            { LLM_TENSOR_SNAC_DEC_CONV_IN, "decoder.conv_in" },
+            { LLM_TENSOR_SNAC_DEC_ATTN_NORM, "decoder.attn_norm" },
+            { LLM_TENSOR_SNAC_DEC_ATTN_Q, "decoder.attn_q" },
+            { LLM_TENSOR_SNAC_DEC_ATTN_K, "decoder.attn_k" },
+            { LLM_TENSOR_SNAC_DEC_ATTN_V, "decoder.attn_v" },
+            { LLM_TENSOR_SNAC_DEC_ATTN_OUT, "decoder.attn_out" },
+            { LLM_TENSOR_SNAC_DEC_BLK_CONV_UP, "decoder.block.%d.conv_up" },
+            { LLM_TENSOR_SNAC_DEC_BLK_CONV1, "decoder.block.%d.conv1" },
+            { LLM_TENSOR_SNAC_DEC_BLK_CONV2, "decoder.block.%d.conv2" },
+            { LLM_TENSOR_SNAC_DEC_BLK_CONV3, "decoder.block.%d.conv3" },
+            { LLM_TENSOR_SNAC_DEC_BLK_SNAKE_ALPHA, "decoder.block.%d.snake_alpha" },
+            { LLM_TENSOR_SNAC_DEC_CONV_OUT, "decoder.conv_out" },
+            { LLM_TENSOR_SNAC_VQ_IN_PROJ, "quantizer.%d.in_proj" },
+            { LLM_TENSOR_SNAC_VQ_OUT_PROJ, "quantizer.%d.out_proj" },
+            { LLM_TENSOR_SNAC_VQ_CODEBOOK, "quantizer.%d.codebook" },
+        },
+    },
     {
         LLM_ARCH_BAILINGMOE,
         {

src/llama-arch.h

Lines changed: 28 additions & 0 deletions
@@ -87,6 +87,7 @@ enum llm_arch {
     LLM_ARCH_GRANITE_HYBRID,
     LLM_ARCH_CHAMELEON,
     LLM_ARCH_WAVTOKENIZER_DEC,
+    LLM_ARCH_SNAC_DEC,
     LLM_ARCH_PLM,
     LLM_ARCH_BAILINGMOE,
     LLM_ARCH_BAILINGMOE2,
@@ -461,6 +462,33 @@ enum llm_tensor {
     LLM_TENSOR_NEXTN_HNORM,
     LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
     LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
+    LLM_TENSOR_SNAC_ENC_CONV_IN,
+    LLM_TENSOR_SNAC_ENC_BLK_CONV1,
+    LLM_TENSOR_SNAC_ENC_BLK_CONV2,
+    LLM_TENSOR_SNAC_ENC_BLK_CONV3,
+    LLM_TENSOR_SNAC_ENC_BLK_CONV_DS,
+    LLM_TENSOR_SNAC_ENC_BLK_SNAKE_ALPHA,
+    LLM_TENSOR_SNAC_ENC_CONV_OUT,
+    LLM_TENSOR_SNAC_ENC_ATTN_NORM,
+    LLM_TENSOR_SNAC_ENC_ATTN_Q,
+    LLM_TENSOR_SNAC_ENC_ATTN_K,
+    LLM_TENSOR_SNAC_ENC_ATTN_V,
+    LLM_TENSOR_SNAC_ENC_ATTN_OUT,
+    LLM_TENSOR_SNAC_VQ_IN_PROJ,
+    LLM_TENSOR_SNAC_VQ_OUT_PROJ,
+    LLM_TENSOR_SNAC_VQ_CODEBOOK,
+    LLM_TENSOR_SNAC_DEC_CONV_IN,
+    LLM_TENSOR_SNAC_DEC_ATTN_NORM,
+    LLM_TENSOR_SNAC_DEC_ATTN_Q,
+    LLM_TENSOR_SNAC_DEC_ATTN_K,
+    LLM_TENSOR_SNAC_DEC_ATTN_V,
+    LLM_TENSOR_SNAC_DEC_ATTN_OUT,
+    LLM_TENSOR_SNAC_DEC_BLK_CONV_UP,
+    LLM_TENSOR_SNAC_DEC_BLK_CONV1,
+    LLM_TENSOR_SNAC_DEC_BLK_CONV2,
+    LLM_TENSOR_SNAC_DEC_BLK_CONV3,
+    LLM_TENSOR_SNAC_DEC_BLK_SNAKE_ALPHA,
+    LLM_TENSOR_SNAC_DEC_CONV_OUT,
 };
 
 enum llm_tensor_layer {
