From 5af2e4b8c0d1ff305da820bb80fa6e35ae4db829 Mon Sep 17 00:00:00 2001
From: hlky
Date: Wed, 9 Apr 2025 08:18:59 +0100
Subject: [PATCH] AudioLDM2 Fixes

---
 .../pipelines/audioldm2/pipeline_audioldm2.py | 13 ++++++++-----
 tests/pipelines/audioldm2/test_audioldm2.py   | 16 ++++++++++++++--
 2 files changed, 22 insertions(+), 7 deletions(-)

diff --git a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
index b8b5d07af529..1616d94ff1ff 100644
--- a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
+++ b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
@@ -20,7 +20,7 @@
 from transformers import (
     ClapFeatureExtractor,
     ClapModel,
-    GPT2Model,
+    GPT2LMHeadModel,
     RobertaTokenizer,
     RobertaTokenizerFast,
     SpeechT5HifiGan,
@@ -196,7 +196,7 @@ def __init__(
         text_encoder: ClapModel,
         text_encoder_2: Union[T5EncoderModel, VitsModel],
         projection_model: AudioLDM2ProjectionModel,
-        language_model: GPT2Model,
+        language_model: GPT2LMHeadModel,
         tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
         tokenizer_2: Union[T5Tokenizer, T5TokenizerFast, VitsTokenizer],
         feature_extractor: ClapFeatureExtractor,
@@ -259,7 +259,10 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
             )
 
         device_type = torch_device.type
-        device = torch.device(f"{device_type}:{gpu_id or torch_device.index}")
+        device_str = device_type
+        if gpu_id or torch_device.index:
+            device_str = f"{device_str}:{gpu_id or torch_device.index}"
+        device = torch.device(device_str)
 
         if self.device.type != "cpu":
             self.to("cpu", silence_dtype_warnings=True)
@@ -316,9 +319,9 @@ def generate_language_model(
             model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)
 
             # forward pass to get next hidden states
-            output = self.language_model(**model_inputs, return_dict=True)
+            output = self.language_model(**model_inputs, output_hidden_states=True, return_dict=True)
 
-            next_hidden_states = output.last_hidden_state
+            next_hidden_states = output.hidden_states[-1]
 
             # Update the model input
             inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1)
diff --git a/tests/pipelines/audioldm2/test_audioldm2.py b/tests/pipelines/audioldm2/test_audioldm2.py
index 66052392f07f..80b2b9c8cd34 100644
--- a/tests/pipelines/audioldm2/test_audioldm2.py
+++ b/tests/pipelines/audioldm2/test_audioldm2.py
@@ -26,7 +26,7 @@
     ClapModel,
     ClapTextConfig,
     GPT2Config,
-    GPT2Model,
+    GPT2LMHeadModel,
     RobertaTokenizer,
     SpeechT5HifiGan,
     SpeechT5HifiGanConfig,
@@ -162,7 +162,7 @@ def get_dummy_components(self):
             n_ctx=99,
             n_positions=99,
         )
-        language_model = GPT2Model(language_model_config)
+        language_model = GPT2LMHeadModel(language_model_config)
         language_model.config.max_new_tokens = 8
 
         torch.manual_seed(0)
@@ -516,6 +516,18 @@ def test_sequential_cpu_offload_forward_pass(self):
     def test_encode_prompt_works_in_isolation(self):
         pass
 
+    @unittest.skip("Not supported yet due to CLAPModel.")
+    def test_sequential_offload_forward_pass_twice(self):
+        pass
+
+    @unittest.skip("Not supported yet, the second forward has mixed devices and `vocoder` is not offloaded.")
+    def test_cpu_offload_forward_pass_twice(self):
+        pass
+
+    @unittest.skip("Not supported yet. `vocoder` is not offloaded.")
+    def test_model_cpu_offload_forward_pass(self):
+        pass
+
 
 @nightly
 class AudioLDM2PipelineSlowTests(unittest.TestCase):
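
Note on the language-model change (annotation, not part of the patch):
GPT2LMHeadModel wraps the base GPT2Model, so with output_hidden_states=True
the last entry of hidden_states is the same tensor that GPT2Model's
last_hidden_state used to provide, which is why generate_language_model can
switch from one to the other without changing behavior. A minimal sketch
checking that equivalence, assuming only a standard transformers install
(the tiny config below is illustrative, not the pipeline's):

    import torch
    from transformers import GPT2Config, GPT2LMHeadModel

    config = GPT2Config(vocab_size=100, n_embd=32, n_layer=2, n_head=2)
    torch.manual_seed(0)
    lm = GPT2LMHeadModel(config).eval()
    base = lm.transformer  # the underlying GPT2Model

    input_ids = torch.randint(0, 100, (1, 8))
    with torch.no_grad():
        lm_out = lm(input_ids, output_hidden_states=True, return_dict=True)
        base_out = base(input_ids, return_dict=True)

    # The final entry of hidden_states (taken after the final layer norm)
    # matches the base model's last_hidden_state, i.e. exactly what
    # generate_language_model consumed before the change.
    assert torch.allclose(lm_out.hidden_states[-1], base_out.last_hidden_state)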
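
Note on the enable_model_cpu_offload change (annotation, not part of the
patch): when neither gpu_id nor a device index is given, the old f-string
interpolated None and produced a string like "cpu:None", which torch.device
rejects. A minimal repro of the old failure next to the guarded construction
the patch uses:

    import torch

    device_type, gpu_id, index = "cpu", None, None

    # Old construction: a None index yields "cpu:None".
    try:
        torch.device(f"{device_type}:{gpu_id or index}")
    except RuntimeError as e:
        print(e)  # invalid device string

    # Patched construction: append the index only when one is present.
    device_str = device_type
    if gpu_id or index:
        device_str = f"{device_str}:{gpu_id or index}"
    print(torch.device(device_str))  # cpu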