Skip to content

Commit 36fc9be

Browse files
feat: add option to train upsampler with multiple factors
1 parent 73a0d5c commit 36fc9be

File tree

3 files changed

+28
-10
lines changed

3 files changed

+28
-10
lines changed

audio_diffusion_pytorch/model.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import random
12
from math import prod
2-
from typing import Optional, Sequence
3+
from typing import Optional, Sequence, Union
34

45
import torch
56
from torch import Tensor, nn
@@ -15,6 +16,7 @@
1516
Schedule,
1617
)
1718
from .modules import Encoder1d, ResnetBlock1d, UNet1d
19+
from .utils import default, to_list
1820

1921
""" Diffusion Classes (generic for 1d data) """
2022

@@ -95,24 +97,32 @@ def sample(
9597

9698

9799
class DiffusionUpsampler1d(Model1d):
98-
def __init__(self, factor: int, in_channels: int, *args, **kwargs):
99-
self.factor = factor
100+
def __init__(
101+
self, factor: Union[int, Sequence[int]], in_channels: int, *args, **kwargs
102+
):
103+
self.factor = to_list(factor)
100104
default_kwargs = dict(
101105
in_channels=in_channels,
102106
context_channels=[in_channels],
103107
)
104108
super().__init__(*args, **{**default_kwargs, **kwargs}) # type: ignore
105109

106-
def forward(self, x: Tensor, **kwargs) -> Tensor:
110+
def forward(self, x: Tensor, factor: Optional[int] = None, **kwargs) -> Tensor:
111+
# Either user provides factor or we pick one at random
112+
factor = default(factor, random.choice(self.factor))
107113
# Downsample by picking every `factor` item
108-
downsampled = x[:, :, :: self.factor]
114+
downsampled = x[:, :, ::factor]
109115
# Upsample by interleaving to get context
110-
context = torch.repeat_interleave(downsampled, repeats=self.factor, dim=2)
116+
context = torch.repeat_interleave(downsampled, repeats=factor, dim=2)
111117
return self.diffusion(x, context=[context], **kwargs)
112118

113-
def sample(self, undersampled: Tensor, *args, **kwargs): # type: ignore
119+
def sample( # type: ignore
120+
self, undersampled: Tensor, factor: Optional[int] = None, *args, **kwargs
121+
):
122+
# Either user provides factor or we pick the first
123+
factor = default(factor, self.factor[0])
114124
# Upsample context by interleaving
115-
context = torch.repeat_interleave(undersampled, repeats=self.factor, dim=2)
125+
context = torch.repeat_interleave(undersampled, repeats=factor, dim=2)
116126
noise = torch.randn_like(context)
117127
default_kwargs = dict(context=[context])
118128
return super().sample(noise, **{**default_kwargs, **kwargs}) # type: ignore

audio_diffusion_pytorch/utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from inspect import isfunction
2-
from typing import Callable, Optional, TypeVar, Union
2+
from typing import Callable, List, Optional, Sequence, TypeVar, Union
33

44
from typing_extensions import TypeGuard
55

@@ -22,3 +22,11 @@ def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
2222
if exists(val):
2323
return val
2424
return d() if isfunction(d) else d
25+
26+
27+
def to_list(val: Union[T, Sequence[T]]) -> List[T]:
28+
if isinstance(val, tuple):
29+
return list(val)
30+
if isinstance(val, list):
31+
return val
32+
return [val] # type: ignore

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.24",
6+
version="0.0.25",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)