Skip to content

Commit 4b5b4b0

Browse files
committed
integration tests for stable audio decoder.
1 parent a9de100 commit 4b5b4b0

File tree

2 files changed

+118
-0
lines changed

2 files changed

+118
-0
lines changed

tests/models/autoencoders/test_models_autoencoder_oobleck.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,20 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import gc
1617
import unittest
1718

1819
import torch
20+
from datasets import load_dataset
21+
from parameterized import parameterized
1922

2023
from diffusers import AutoencoderOobleck
2124
from diffusers.utils.testing_utils import (
25+
backend_empty_cache,
2226
enable_full_determinism,
2327
floats_tensor,
2428
require_torch_accelerator_with_training,
29+
slow,
2530
torch_all_close,
2631
torch_device,
2732
)
@@ -151,3 +156,116 @@ def test_forward_with_norm_groups(self):
151156
@unittest.skip("No attention module used in this model")
152157
def test_set_attn_processor_for_determinism(self):
153158
return
159+
160+
161+
@slow
162+
class AutoencoderOobleckIntegrationTests(unittest.TestCase):
163+
def tearDown(self):
164+
# clean up the VRAM after each test
165+
super().tearDown()
166+
gc.collect()
167+
backend_empty_cache(torch_device)
168+
169+
def _load_datasamples(self, num_samples):
170+
ds = load_dataset(
171+
"hf-internal-testing/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True
172+
)
173+
# automatic decoding with librispeech
174+
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
175+
176+
return torch.nn.utils.rnn.pad_sequence(
177+
[torch.from_numpy(x["array"]) for x in speech_samples], batch_first=True
178+
)
179+
180+
def get_audio(self, audio_sample_size=2097152, fp16=False):
181+
dtype = torch.float16 if fp16 else torch.float32
182+
audio = self._load_datasamples(2).to(torch_device).to(dtype)
183+
184+
# pad / crop to audio_sample_size
185+
audio = torch.nn.functional.pad(audio[:, :audio_sample_size], pad=(0, audio_sample_size - audio.shape[-1]))
186+
187+
# todo channel
188+
audio = audio.unsqueeze(1).repeat(1, 2, 1).to(torch_device)
189+
190+
return audio
191+
192+
def get_oobleck_vae_model(self, model_id="stabilityai/stable-audio-open-1.0", fp16=False):
193+
torch_dtype = torch.float16 if fp16 else torch.float32
194+
195+
model = AutoencoderOobleck.from_pretrained(
196+
model_id,
197+
subfolder="vae",
198+
torch_dtype=torch_dtype,
199+
)
200+
model.to(torch_device)
201+
202+
return model
203+
204+
def get_generator(self, seed=0):
205+
generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
206+
if torch_device != "mps":
207+
return torch.Generator(device=generator_device).manual_seed(seed)
208+
return torch.manual_seed(seed)
209+
210+
@parameterized.expand(
211+
[
212+
# fmt: off
213+
[33, [1.193e-4, 6.56e-05, 1.314e-4, 3.80e-05, -4.01e-06], 0.001192],
214+
[44, [2.77e-05, -2.65e-05, 1.18e-05, -6.94e-05, -9.57e-05], 0.001196],
215+
# fmt: on
216+
]
217+
)
218+
def test_stable_diffusion(self, seed, expected_slice, expected_mean_absolute_diff):
219+
model = self.get_oobleck_vae_model()
220+
audio = self.get_audio()
221+
generator = self.get_generator(seed)
222+
223+
with torch.no_grad():
224+
sample = model(audio, generator=generator, sample_posterior=True).sample
225+
226+
assert sample.shape == audio.shape
227+
assert ((sample - audio).abs().mean() - expected_mean_absolute_diff).abs() <= 1e-6
228+
229+
output_slice = sample[-1, 1, 5:10].cpu()
230+
expected_output_slice = torch.tensor(expected_slice)
231+
232+
assert torch_all_close(output_slice, expected_output_slice, atol=1e-5)
233+
234+
def test_stable_diffusion_mode(self):
235+
model = self.get_oobleck_vae_model()
236+
audio = self.get_audio()
237+
238+
with torch.no_grad():
239+
sample = model(audio, sample_posterior=False).sample
240+
241+
assert sample.shape == audio.shape
242+
243+
@parameterized.expand(
244+
[
245+
# fmt: off
246+
[33, [1.193e-4, 6.56e-05, 1.314e-4, 3.80e-05, -4.01e-06], 0.001192],
247+
[44, [2.77e-05, -2.65e-05, 1.18e-05, -6.94e-05, -9.57e-05], 0.001196],
248+
# fmt: on
249+
]
250+
)
251+
def test_stable_diffusion_encode_decode(self, seed, expected_slice, expected_mean_absolute_diff):
252+
model = self.get_oobleck_vae_model()
253+
audio = self.get_audio()
254+
generator = self.get_generator(seed)
255+
256+
with torch.no_grad():
257+
x = audio
258+
posterior = model.encode(x).latent_dist
259+
z = posterior.sample(generator=generator)
260+
sample = model.decode(z).sample
261+
262+
# (batch_size, latent_dim, sequence_length)
263+
assert posterior.mean.shape == (audio.shape[0], model.config.decoder_input_channels, 1024)
264+
265+
assert sample.shape == audio.shape
266+
assert ((sample - audio).abs().mean() - expected_mean_absolute_diff).abs() <= 1e-6
267+
268+
output_slice = sample[-1, 1, 5:10].cpu()
269+
expected_output_slice = torch.tensor(expected_slice)
270+
271+
assert torch_all_close(output_slice, expected_output_slice, atol=1e-5)

tests/models/autoencoders/test_models_consistency_decoder_vae.py

Whitespace-only changes.

0 commit comments

Comments
 (0)