13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
| 16 | +import gc |
16 | 17 | import unittest |
17 | 18 |
18 | 19 | import torch |
| 20 | +from datasets import load_dataset |
| 21 | +from parameterized import parameterized |
19 | 22 |
20 | 23 | from diffusers import AutoencoderOobleck |
21 | 24 | from diffusers.utils.testing_utils import ( |
| 25 | + backend_empty_cache, |
22 | 26 | enable_full_determinism, |
23 | 27 | floats_tensor, |
24 | 28 | require_torch_accelerator_with_training, |
| 29 | + slow, |
25 | 30 | torch_all_close, |
26 | 31 | torch_device, |
27 | 32 | ) |
@@ -151,3 +156,116 @@ def test_forward_with_norm_groups(self): |
151 | 156 | @unittest.skip("No attention module used in this model") |
152 | 157 | def test_set_attn_processor_for_determinism(self): |
153 | 158 | return |
| 159 | + |
| 160 | + |
| 161 | +@slow |
| 162 | +class AutoencoderOobleckIntegrationTests(unittest.TestCase): |
 | 164 | +    def tearDown(self):
 | 165 | +        # free accelerator memory after each test
| 166 | + gc.collect() |
| 167 | + backend_empty_cache(torch_device) |
| 168 | + |
| 169 | + def _load_datasamples(self, num_samples): |
| 170 | + ds = load_dataset( |
| 171 | + "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True |
| 172 | + ) |
| 173 | + # automatic decoding with librispeech |
| 174 | + speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"] |
| 175 | + |
| 176 | + return torch.nn.utils.rnn.pad_sequence( |
| 177 | + [torch.from_numpy(x["array"]) for x in speech_samples], batch_first=True |
| 178 | + ) |
| 179 | + |
| 180 | + def get_audio(self, audio_sample_size=2097152, fp16=False): |
| 181 | + dtype = torch.float16 if fp16 else torch.float32 |
| 182 | + audio = self._load_datasamples(2).to(torch_device).to(dtype) |
| 183 | + |
| 184 | + # pad / crop to audio_sample_size |
| 185 | + audio = torch.nn.functional.pad(audio[:, :audio_sample_size], pad=(0, audio_sample_size - audio.shape[-1])) |
| 186 | + |
 | 187 | +        # duplicate the mono waveform into two channels (the model expects stereo input)
 | 188 | +        audio = audio.unsqueeze(1).repeat(1, 2, 1).to(torch_device)
| 189 | + |
| 190 | + return audio |
| 191 | + |
| 192 | + def get_oobleck_vae_model(self, model_id="stabilityai/stable-audio-open-1.0", fp16=False): |
| 193 | + torch_dtype = torch.float16 if fp16 else torch.float32 |
| 194 | + |
| 195 | + model = AutoencoderOobleck.from_pretrained( |
| 196 | + model_id, |
| 197 | + subfolder="vae", |
| 198 | + torch_dtype=torch_dtype, |
| 199 | + ) |
| 200 | + model.to(torch_device) |
| 201 | + |
| 202 | + return model |
| 203 | + |
| 204 | + def get_generator(self, seed=0): |
| 205 | + generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda" |
| 206 | + if torch_device != "mps": |
| 207 | + return torch.Generator(device=generator_device).manual_seed(seed) |
| 208 | + return torch.manual_seed(seed) |
| 209 | + |
| 210 | + @parameterized.expand( |
| 211 | + [ |
| 212 | + # fmt: off |
| 213 | + [33, [1.193e-4, 6.56e-05, 1.314e-4, 3.80e-05, -4.01e-06], 0.001192], |
| 214 | + [44, [2.77e-05, -2.65e-05, 1.18e-05, -6.94e-05, -9.57e-05], 0.001196], |
| 215 | + # fmt: on |
| 216 | + ] |
| 217 | + ) |
| 218 | + def test_stable_diffusion(self, seed, expected_slice, expected_mean_absolute_diff): |
| 219 | + model = self.get_oobleck_vae_model() |
| 220 | + audio = self.get_audio() |
| 221 | + generator = self.get_generator(seed) |
| 222 | + |
| 223 | + with torch.no_grad(): |
| 224 | + sample = model(audio, generator=generator, sample_posterior=True).sample |
| 225 | + |
| 226 | + assert sample.shape == audio.shape |
| 227 | + assert ((sample - audio).abs().mean() - expected_mean_absolute_diff).abs() <= 1e-6 |
| 228 | + |
| 229 | + output_slice = sample[-1, 1, 5:10].cpu() |
| 230 | + expected_output_slice = torch.tensor(expected_slice) |
| 231 | + |
| 232 | + assert torch_all_close(output_slice, expected_output_slice, atol=1e-5) |
| 233 | + |
| 234 | + def test_stable_diffusion_mode(self): |
| 235 | + model = self.get_oobleck_vae_model() |
| 236 | + audio = self.get_audio() |
| 237 | + |
| 238 | + with torch.no_grad(): |
| 239 | + sample = model(audio, sample_posterior=False).sample |
| 240 | + |
| 241 | + assert sample.shape == audio.shape |
| 242 | + |
| 243 | + @parameterized.expand( |
| 244 | + [ |
| 245 | + # fmt: off |
| 246 | + [33, [1.193e-4, 6.56e-05, 1.314e-4, 3.80e-05, -4.01e-06], 0.001192], |
| 247 | + [44, [2.77e-05, -2.65e-05, 1.18e-05, -6.94e-05, -9.57e-05], 0.001196], |
| 248 | + # fmt: on |
| 249 | + ] |
| 250 | + ) |
| 251 | + def test_stable_diffusion_encode_decode(self, seed, expected_slice, expected_mean_absolute_diff): |
| 252 | + model = self.get_oobleck_vae_model() |
| 253 | + audio = self.get_audio() |
| 254 | + generator = self.get_generator(seed) |
| 255 | + |
| 256 | + with torch.no_grad(): |
| 257 | + x = audio |
| 258 | + posterior = model.encode(x).latent_dist |
| 259 | + z = posterior.sample(generator=generator) |
| 260 | + sample = model.decode(z).sample |
| 261 | + |
| 262 | + # (batch_size, latent_dim, sequence_length) |
| 263 | + assert posterior.mean.shape == (audio.shape[0], model.config.decoder_input_channels, 1024) |
| 264 | + |
| 265 | + assert sample.shape == audio.shape |
| 266 | + assert ((sample - audio).abs().mean() - expected_mean_absolute_diff).abs() <= 1e-6 |
| 267 | + |
| 268 | + output_slice = sample[-1, 1, 5:10].cpu() |
| 269 | + expected_output_slice = torch.tensor(expected_slice) |
| 270 | + |
| 271 | + assert torch_all_close(output_slice, expected_output_slice, atol=1e-5) |
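
For reference, the encode/decode round trip exercised by `test_stable_diffusion_encode_decode` can be reproduced outside the test harness. The sketch below is illustrative only: it assumes access to the `stabilityai/stable-audio-open-1.0` checkpoint and substitutes a random stereo waveform for the LibriSpeech samples used in the tests.

```python
# Minimal sketch of the encode/decode round trip covered by the new integration tests.
# Assumption: the stabilityai/stable-audio-open-1.0 checkpoint is available; a random
# waveform stands in for the padded LibriSpeech audio used in the test suite.
import torch

from diffusers import AutoencoderOobleck

device = "cuda" if torch.cuda.is_available() else "cpu"
vae = AutoencoderOobleck.from_pretrained(
    "stabilityai/stable-audio-open-1.0", subfolder="vae", torch_dtype=torch.float32
).to(device)

# (batch, channels, samples): stereo audio padded/cropped to 2_097_152 samples, as in the tests
audio = torch.randn(1, 2, 2_097_152, device=device)

with torch.no_grad():
    posterior = vae.encode(audio).latent_dist
    latents = posterior.sample(generator=torch.Generator(device).manual_seed(0))
    reconstruction = vae.decode(latents).sample

print(posterior.mean.shape)   # expected: (1, vae.config.decoder_input_channels, 1024)
print(reconstruction.shape)   # matches the input audio shape
```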