Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from transformers import (
ClapFeatureExtractor,
ClapModel,
GPT2Model,
GPT2LMHeadModel,
RobertaTokenizer,
RobertaTokenizerFast,
SpeechT5HifiGan,
Expand Down Expand Up @@ -196,7 +196,7 @@ def __init__(
text_encoder: ClapModel,
text_encoder_2: Union[T5EncoderModel, VitsModel],
projection_model: AudioLDM2ProjectionModel,
language_model: GPT2Model,
language_model: GPT2LMHeadModel,
tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
tokenizer_2: Union[T5Tokenizer, T5TokenizerFast, VitsTokenizer],
feature_extractor: ClapFeatureExtractor,
Expand Down Expand Up @@ -259,7 +259,10 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
)

device_type = torch_device.type
device = torch.device(f"{device_type}:{gpu_id or torch_device.index}")
device_str = device_type
if gpu_id or torch_device.index:
device_str = f"{device_str}:{gpu_id or torch_device.index}"
device = torch.device(device_str)
Comment on lines +262 to +265
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note this is a partial fix to enable_model_cpu_offload. This may have been known as this won't be tested from fast tests.


if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
Expand Down Expand Up @@ -316,9 +319,9 @@ def generate_language_model(
model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)

# forward pass to get next hidden states
output = self.language_model(**model_inputs, return_dict=True)
output = self.language_model(**model_inputs, output_hidden_states=True, return_dict=True)

next_hidden_states = output.last_hidden_state
next_hidden_states = output.hidden_states[-1]

# Update the model input
inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1)
Expand Down
16 changes: 14 additions & 2 deletions tests/pipelines/audioldm2/test_audioldm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
ClapModel,
ClapTextConfig,
GPT2Config,
GPT2Model,
GPT2LMHeadModel,
RobertaTokenizer,
SpeechT5HifiGan,
SpeechT5HifiGanConfig,
Expand Down Expand Up @@ -162,7 +162,7 @@ def get_dummy_components(self):
n_ctx=99,
n_positions=99,
)
language_model = GPT2Model(language_model_config)
language_model = GPT2LMHeadModel(language_model_config)
language_model.config.max_new_tokens = 8

torch.manual_seed(0)
Expand Down Expand Up @@ -516,6 +516,18 @@ def test_sequential_cpu_offload_forward_pass(self):
def test_encode_prompt_works_in_isolation(self):
pass

@unittest.skip("Not supported yet due to CLAPModel.")
def test_sequential_offload_forward_pass_twice(self):
pass

@unittest.skip("Not supported yet, the second forward has mixed devices and `vocoder` is not offloaded.")
def test_cpu_offload_forward_pass_twice(self):
pass

@unittest.skip("Not supported yet. `vocoder` is not offloaded.")
def test_model_cpu_offload_forward_pass(self):
pass
Comment on lines +519 to +529
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As above these may have been known, added a skip for now.



@nightly
class AudioLDM2PipelineSlowTests(unittest.TestCase):
Expand Down
Loading