Skip to content

Commit a07ff71

Browse files
committed
Export Kokoro until LSTM
1 parent 413dee4 commit a07ff71

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

examples/models/kokoro/export.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from kokoro import KModel
2+
import torch
3+
from torch import nn
4+
from torch.export import default_decompositions, Dim, export_for_training
5+
6+
torch.manual_seed(42)
7+
8+
class WrappedModel(nn.Module):
9+
def __init__(self, model: nn.Module) -> None:
10+
super().__init__()
11+
self.model = model
12+
13+
def forward(
14+
self,
15+
input_ids: torch.LongTensor,
16+
ref_s: torch.FloatTensor,
17+
speed: float = 1
18+
) -> tuple[torch.FloatTensor, torch.LongTensor]:
19+
return self.model.forward_with_tokens(
20+
input_ids,
21+
ref_s,
22+
speed,
23+
)
24+
25+
repo_id = "hexgrad/Kokoro-82M"
26+
model = KModel(repo_id=repo_id).eval()
27+
wrapped_model = WrappedModel(model)
28+
29+
input_ids = torch.randint(1, 100, (48,))
30+
input_ids = torch.LongTensor([[0, *input_ids, 0]]) # S = [1, 50]
31+
style = torch.randn(1, 256)
32+
speed = torch.randint(1, 10, (1,)).int()
33+
example_inputs = (input_ids, style, speed)
34+
35+
"""
36+
Original model output is:
37+
(tensor([-0.1578, 0.0960, 0.0831, ..., 0.1224, -0.0831, 0.1492]), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
38+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
39+
1, 2]))
40+
"""
41+
print(wrapped_model(*example_inputs))
42+
43+
44+
dynamic_shapes = {
45+
# "input_ids": {0: 1, 1: Dim("input_ids", min=2, max=100)},
46+
"input_ids": {},
47+
"ref_s": {},
48+
"speed":{},
49+
}
50+
51+
exported_program = export_for_training(wrapped_model, args=example_inputs, dynamic_shapes=dynamic_shapes, strict=True)
52+
exported_program = exported_program.run_decompositions(default_decompositions())
53+
exported_program.run_decompositions()

0 commit comments

Comments
 (0)