|
| 1 | +from kokoro import KModel |
| 2 | +import torch |
| 3 | +from torch import nn |
| 4 | +from torch.export import default_decompositions, Dim, export_for_training |
| 5 | + |
# Fix the RNG seed so the randomly generated example inputs below
# (input_ids / style / speed) are reproducible across runs.
torch.manual_seed(42)
| 7 | + |
| 8 | +class WrappedModel(nn.Module): |
| 9 | + def __init__(self, model: nn.Module) -> None: |
| 10 | + super().__init__() |
| 11 | + self.model = model |
| 12 | + |
| 13 | + def forward( |
| 14 | + self, |
| 15 | + input_ids: torch.LongTensor, |
| 16 | + ref_s: torch.FloatTensor, |
| 17 | + speed: float = 1 |
| 18 | + ) -> tuple[torch.FloatTensor, torch.LongTensor]: |
| 19 | + return self.model.forward_with_tokens( |
| 20 | + input_ids, |
| 21 | + ref_s, |
| 22 | + speed, |
| 23 | + ) |
| 24 | + |
# --- Load the model and build one concrete set of example inputs ------------
repo_id = "hexgrad/Kokoro-82M"
model = KModel(repo_id=repo_id).eval()
wrapped_model = WrappedModel(model)

# 48 random token ids, framed by a 0 pad token on each side -> shape [1, 50].
# NOTE: the three torch RNG calls below stay in this exact order so the
# seeded values match the reference output recorded further down.
raw_tokens = torch.randint(1, 100, (48,))
input_ids = torch.LongTensor([[0, *raw_tokens, 0]])  # S = [1, 50]
style = torch.randn(1, 256)  # reference style vector
speed = torch.randint(1, 10, (1,)).int()
example_inputs = (input_ids, style, speed)
| 34 | + |
"""
Original model output is:
(tensor([-0.1578, 0.0960, 0.0831, ..., 0.1224, -0.0831, 0.1492]), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 2]))
"""
# Run the model eagerly once before exporting, so the printed output can be
# compared by eye against the reference output recorded above.
eager_output = wrapped_model(*example_inputs)
print(eager_output)
| 42 | + |
| 43 | + |
| 44 | +dynamic_shapes = { |
| 45 | + # "input_ids": {0: 1, 1: Dim("input_ids", min=2, max=100)}, |
| 46 | + "input_ids": {}, |
| 47 | + "ref_s": {}, |
| 48 | + "speed":{}, |
| 49 | +} |
| 50 | + |
| 51 | +exported_program = export_for_training(wrapped_model, args=example_inputs, dynamic_shapes=dynamic_shapes, strict=True) |
| 52 | +exported_program = exported_program.run_decompositions(default_decompositions()) |
| 53 | +exported_program.run_decompositions() |