|
1 | | -import random |
2 | 1 | from typing import Any, Optional, Sequence, Tuple, Union |
3 | 2 |
|
4 | 3 | import torch |
|
15 | 14 | Schedule, |
16 | 15 | ) |
17 | 16 | from .modules import Bottleneck, MultiEncoder1d, UNet1d, UNetConditional1d |
18 | | -from .utils import default, exists, to_list |
| 17 | +from .utils import default, downsample, exists, to_list, upsample |
19 | 18 |
|
20 | 19 | """ |
21 | 20 | Diffusion Classes (generic for 1d data) |
@@ -68,29 +67,41 @@ class DiffusionUpsampler1d(Model1d): |
def __init__(
    self, factor: Union[int, Sequence[int]], in_channels: int, *args, **kwargs
):
    """Set up the upsampler and forward the remaining arguments to the parent.

    Args:
        factor: One factor or a sequence of factors; stored on the instance
            as `self.factors` after being passed through `to_list`.
        in_channels: Forwarded to the parent as `in_channels` and, wrapped
            in a one-element list, as `context_channels`.
        *args: Passed positionally to the parent constructor.
        **kwargs: Passed to the parent constructor; entries here take
            precedence over the two defaults derived from `in_channels`.
    """
    self.factors = to_list(factor)
    # Start from the defaults, then let caller-supplied kwargs override them
    model_kwargs = {
        "in_channels": in_channels,
        "context_channels": [in_channels],
    }
    model_kwargs.update(kwargs)
    super().__init__(*args, **model_kwargs)  # type: ignore
77 | 76 |
|
78 | | - def forward(self, x: Tensor, factor: Optional[int] = None, **kwargs) -> Tensor: |
79 | | - # Either user provides factor or we pick one at random |
80 | | - factor = default(factor, random.choice(self.factor)) |
81 | | - # Downsample by picking every `factor` item |
82 | | - downsampled = x[:, :, ::factor] |
83 | | - # Upsample by interleaving to get context |
84 | | - channels = torch.repeat_interleave(downsampled, repeats=factor, dim=2) |
def random_reupsample(self, x: Tensor) -> Tensor:
    """Return a degraded copy of `x`: each item is downsampled then reupsampled.

    Every item along dim 0 draws one entry of `self.factors` uniformly at
    random. Items that drew the same factor are processed together through
    the `downsample` and `upsample` helpers.

    Args:
        x: Batch indexed by item along dim 0 (presumably waveforms shaped
            (batch, channels, time) — confirm against `downsample`).

    Returns:
        A new tensor with the same shape as `x` holding the reupsampled
        items. `x` itself is not modified.
    """
    factors = self.factors
    # Pick random factor for each batch element
    factor_batch_idx = torch.randint(0, len(factors), (x.shape[0],))
    # Write into a copy: assigning into `x` would overwrite the caller's
    # batch and make the returned context the very same tensor as the input.
    reupsampled_batch = x.clone()

    for i, factor in enumerate(factors):
        mask = factor_batch_idx == i
        # Skip factors that no batch element drew
        if not mask.any():
            continue
        # Downsample and reupsample the selected items, reading from the
        # untouched input and storing into the copy
        downsampled = downsample(x[mask], factor=factor)
        reupsampled_batch[mask] = upsample(downsampled, factor=factor)
    return reupsampled_batch
| 93 | + |
def forward(self, x: Tensor, **kwargs) -> Tensor:
    """Call `self.diffusion` on `x` with a randomly reupsampled context.

    Args:
        x: Input batch; used both to build the context through
            `random_reupsample` and as the first argument to
            `self.diffusion`.
        **kwargs: Forwarded unchanged to `self.diffusion`.

    Returns:
        Whatever `self.diffusion` returns.
    """
    # NOTE(review): `x` is used again below, so this only conditions on a
    # degraded signal if `random_reupsample` returns a separate tensor.
    context = self.random_reupsample(x)
    return self.diffusion(x, channels_list=[context], **kwargs)
86 | 97 |
|
def sample(  # type: ignore
    self, undersampled: Tensor, factor: Optional[int] = None, *args, **kwargs
):
    """Sample from the parent model conditioned on an upsampled input.

    Args:
        undersampled: Tensor passed to `upsample` to build the context.
        factor: Upsampling factor; when not provided, the first entry of
            `self.factors` is used (resolved through `default`).
        *args: Accepted for signature compatibility.
            NOTE(review): these are not forwarded to the parent `sample`.
        **kwargs: Forwarded to the parent `sample`; a `channels_list` entry
            here replaces the one built from `undersampled`.

    Returns:
        Whatever the parent `sample` returns.
    """
    # Either user provides factor or we pick the first
    factor = default(factor, self.factors[0])
    # Upsample to obtain the conditioning context
    context = upsample(undersampled, factor=factor)
    # Starting noise matches the context's shape, dtype and device
    noise = torch.randn_like(context)
    sample_kwargs = {"channels_list": [context], **kwargs}
    return super().sample(noise, **sample_kwargs)  # type: ignore
|
0 commit comments