|
| 1 | +import random |
1 | 2 | from math import prod |
2 | | -from typing import Optional, Sequence |
| 3 | +from typing import Optional, Sequence, Union |
3 | 4 |
|
4 | 5 | import torch |
5 | 6 | from torch import Tensor, nn |
|
15 | 16 | Schedule, |
16 | 17 | ) |
17 | 18 | from .modules import Encoder1d, ResnetBlock1d, UNet1d |
| 19 | +from .utils import default, to_list |
18 | 20 |
|
19 | 21 | """ Diffusion Classes (generic for 1d data) """ |
20 | 22 |
|
@@ -95,24 +97,32 @@ def sample( |
95 | 97 |
|
96 | 98 |
|
class DiffusionUpsampler1d(Model1d):
    """Diffusion model conditioned on a decimated copy of its own input.

    For training, the input is decimated along dim 2 by an integer factor,
    stretched back by repeating each kept sample, and passed to the diffusion
    objective as context. For sampling, a low-rate tensor is stretched the
    same way and used as context for generation from Gaussian noise.
    """

    def __init__(
        self, factor: Union[int, Sequence[int]], in_channels: int, *args, **kwargs
    ):
        """Store the candidate factor(s) and configure the parent model.

        Args:
            factor: A single upsampling factor or a sequence of factors; it is
                normalized with `to_list` (presumably to a list of ints --
                confirm against `.utils.to_list`).
            in_channels: Channel count, used as default for both
                `in_channels` and the single entry of `context_channels`.
            *args: Forwarded positionally to the parent constructor.
            **kwargs: Forwarded to the parent constructor; entries here take
                precedence over the defaults built from `in_channels`.
        """
        # Assigned before the parent constructor runs, as in the original.
        self.factor = to_list(factor)
        default_kwargs = dict(
            in_channels=in_channels,
            context_channels=[in_channels],
        )
        # Caller-supplied kwargs override the defaults above.
        super().__init__(*args, **{**default_kwargs, **kwargs})  # type: ignore

    def forward(self, x: Tensor, factor: Optional[int] = None, **kwargs) -> Tensor:
        """Run the diffusion objective on `x` with a degraded copy as context.

        Args:
            x: Input tensor, indexed with three dimensions; decimation and
                repetition both act on dim 2.
            factor: Decimation factor. When omitted, one is drawn with
                `random.choice` from the factors given at construction.
            **kwargs: Forwarded to `self.diffusion`.

        Returns:
            Whatever `self.diffusion` returns for `x` and the context.
        """
        # Either user provides factor or we pick one at random
        factor = default(factor, random.choice(self.factor))
        # Downsample by picking every `factor` item
        downsampled = x[:, :, ::factor]
        # Upsample by interleaving to get context
        context = torch.repeat_interleave(downsampled, repeats=factor, dim=2)
        # The strided slice keeps ceil(L / factor) items, so the repeated
        # context overshoots `x` whenever L is not a multiple of `factor`.
        # Trim it back to L; this is a no-op when L divides evenly.
        context = context[:, :, : x.shape[2]]
        return self.diffusion(x, context=[context], **kwargs)

    def sample(  # type: ignore
        self, undersampled: Tensor, factor: Optional[int] = None, *args, **kwargs
    ):
        """Generate an upsampled tensor from a low-rate one.

        Args:
            undersampled: Low-rate tensor; each item along dim 2 is repeated
                `factor` times to form the context.
            factor: Upsampling factor. When omitted, the first factor given at
                construction is used.
            *args: Accepted but not forwarded to the parent `sample`.
            **kwargs: Forwarded to the parent `sample`; a `context` entry here
                replaces the context built from `undersampled`.

        Returns:
            The result of the parent class's `sample` called on noise shaped
            like the context.
        """
        # NOTE(review): extra positional arguments are silently dropped below;
        # confirm whether they should be passed on to the parent `sample`.
        # Either user provides factor or we pick the first
        factor = default(factor, self.factor[0])
        # Upsample context by interleaving
        context = torch.repeat_interleave(undersampled, repeats=factor, dim=2)
        # Start from noise with the same shape, dtype and device as the context.
        noise = torch.randn_like(context)
        default_kwargs = dict(context=[context])
        return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore
|
0 commit comments