@@ -13,7 +13,13 @@
     Sampler,
     Schedule,
 )
-from .modules import Bottleneck, MultiEncoder1d, UNet1d, UNetConditional1d
+from .modules import (
+    Bottleneck,
+    MultiEncoder1d,
+    SinusoidalEmbedding,
+    UNet1d,
+    UNetConditional1d,
+)
 from .utils import default, downsample, exists, to_list, upsample
 
 """
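The newly imported SinusoidalEmbedding is what turns the integer upsampling factor into a conditioning vector further down in the diff. A minimal sketch of what such a module computes, assuming the standard transformer-style sine/cosine embedding with an even `dim` (the actual implementation in .modules may differ in details such as the frequency range):

import math

import torch
from torch import Tensor, nn


class SinusoidalEmbedding(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        # x has shape (batch,); returns shape (batch, dim)
        half_dim = self.dim // 2
        # Log-spaced frequencies from 1 down to ~1/10000
        scale = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(-scale * torch.arange(half_dim, device=x.device))
        angles = x[:, None].float() * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)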
@@ -65,46 +71,64 @@ def sample(
 
 class DiffusionUpsampler1d(Model1d):
     def __init__(
-        self, factor: Union[int, Sequence[int]], in_channels: int, *args, **kwargs
+        self,
+        in_channels: int,
+        factor: Union[int, Sequence[int]],
+        factor_features: Optional[int] = None,
+        *args,
+        **kwargs
     ):
         self.factors = to_list(factor)
+        self.use_conditioning = exists(factor_features)
+
         default_kwargs = dict(
             in_channels=in_channels,
             context_channels=[in_channels],
+            context_features=factor_features if self.use_conditioning else None,
         )
         super().__init__(*args, **{**default_kwargs, **kwargs})  # type: ignore
 
-    def random_reupsample(self, x: Tensor) -> Tensor:
+        if self.use_conditioning:
+            assert exists(factor_features)
+            self.to_features = SinusoidalEmbedding(dim=factor_features)
+
+    def random_reupsample(self, x: Tensor) -> Tuple[Tensor, Tensor]:
         batch_size, factors = x.shape[0], self.factors
         # Pick random factor for each batch element
-        factor_batch_idx = torch.randint(0, len(factors), (batch_size,))
+        random_factors = torch.randint(0, len(factors), (batch_size,))
         x = x.clone()
 
         for i, factor in enumerate(factors):
             # Pick random items with current factor, skip if 0
-            n = torch.count_nonzero(factor_batch_idx == i)
+            n = torch.count_nonzero(random_factors == i)
             if n > 0:
-                waveforms = x[factor_batch_idx == i]
+                waveforms = x[random_factors == i]
                 # Downsample and reupsample items
                 downsampled = downsample(waveforms, factor=factor)
                 reupsampled = upsample(downsampled, factor=factor)
                 # Save reupsampled version in place
-                x[factor_batch_idx == i] = reupsampled
-        return x
+                x[random_factors == i] = reupsampled
+        return x, random_factors
 
     def forward(self, x: Tensor, **kwargs) -> Tensor:
-        channels = self.random_reupsample(x)
-        return self.diffusion(x, channels_list=[channels], **kwargs)
+        channels, factors = self.random_reupsample(x)
+        features = self.to_features(factors) if self.use_conditioning else None
+        return self.diffusion(x, channels_list=[channels], features=features, **kwargs)
 
     def sample(  # type: ignore
         self, undersampled: Tensor, factor: Optional[int] = None, *args, **kwargs
     ):
         # Either user provides factor or we pick the first
+        batch_size, device = undersampled.shape[0], undersampled.device
         factor = default(factor, self.factors[0])
-        # Upsample channels
+        # Upsample channels by interpolation
         channels = upsample(undersampled, factor=factor)
+        # Compute features if conditioning on factor
+        factors = torch.tensor([factor] * batch_size, device=device)
+        features = self.to_features(factors) if self.use_conditioning else None
+        # Diffuse upsampled
         noise = torch.randn_like(channels)
-        default_kwargs = dict(channels_list=[channels])
+        default_kwargs = dict(channels_list=[channels], features=features)
         return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore
 
 
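Taken together, the changes make the upsampler train on a random factor per batch element and, optionally, condition on that factor. A hedged usage sketch of the training side (the import path, the extra Model1d kwargs, and the loss-returning forward are assumptions based on how this codebase is typically used, not things spelled out in the diff):

import torch

from audio_diffusion_pytorch import DiffusionUpsampler1d  # import path assumed

upsampler = DiffusionUpsampler1d(
    in_channels=1,
    factor=[2, 4, 8],    # random_reupsample draws one of these per batch element
    factor_features=32,  # enables SinusoidalEmbedding conditioning on the factor
    # ...plus the usual UNet1d/diffusion kwargs forwarded to Model1d...
)

x = torch.randn(4, 1, 2 ** 15)  # (batch, channels, samples)
loss = upsampler(x)  # internally: downsample/reupsample, then diffusion loss
loss.backward()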
|
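At inference time, sample now builds the factor-conditioning features itself from the requested factor. Continuing the sketch above (num_steps and any other sampler kwargs are pass-throughs whose exact names depend on the repo version):

undersampled = torch.randn(1, 1, 2 ** 13)  # audio at 1/4 the target rate
upsampled = upsampler.sample(undersampled, factor=4, num_steps=50)
print(upsampled.shape)  # expected: torch.Size([1, 1, 32768])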