diff --git a/acestep/cpu_offload.py b/acestep/cpu_offload.py
index 17f3f309..23e87132 100644
--- a/acestep/cpu_offload.py
+++ b/acestep/cpu_offload.py
@@ -19,7 +19,7 @@ def __exit__(self, *args):
         self.model.to("cpu")
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
-            torch.cuda.synchronize()
+            # torch.cuda.synchronize()
 
 T = TypeVar('T')
 
@@ -31,8 +31,8 @@ def wrapper(self, *args, **kwargs):
         if not self.cpu_offload:
             return func(self, *args, **kwargs)
 
-        # Get the device from the class
-        device = self.device
+        # Get the device from the class device map
+        device = getattr(self, "device_map", {}).get(model_attr, self.device)
 
         # Get the model from the class attribute
         model = getattr(self, model_attr)
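
Note: the wrapper change above makes the offload decorator honor an optional per-model device map. A minimal runnable sketch of the lookup (the Host class and its attribute values are illustrative, not from the codebase):

    import torch

    class Host:
        device = torch.device("cpu")  # stand-in for the pipeline's default device
        device_map = {"music_dcae": torch.device("cpu")}

    host = Host()
    # Same expression the wrapper uses: the per-model entry wins when present,
    # otherwise the lookup falls back to the default device.
    print(getattr(host, "device_map", {}).get("music_dcae", host.device))
    print(getattr(host, "device_map", {}).get("ace_step_transformer", host.device))

Instances without a device_map attribute fall back to self.device as before, so the decorator stays backward compatible.
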
diff --git a/acestep/music_dcae/music_dcae_pipeline.py b/acestep/music_dcae/music_dcae_pipeline.py
index 8112638c..543b5e84 100644
--- a/acestep/music_dcae/music_dcae_pipeline.py
+++ b/acestep/music_dcae/music_dcae_pipeline.py
@@ -118,6 +118,7 @@ def decode(self, latents, audio_lengths=None, sr=None):
         pred_wavs = []
         for latent in latents:
+            latent = latent.to(self.device)
             mels = self.dcae.decoder(latent.unsqueeze(0))
             mels = mels * 0.5 + 0.5
             mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
diff --git a/acestep/pipeline_ace_step.py b/acestep/pipeline_ace_step.py
index 81fe1d07..ab393d74 100644
--- a/acestep/pipeline_ace_step.py
+++ b/acestep/pipeline_ace_step.py
@@ -105,6 +105,7 @@ def __init__(
         cpu_offload=False,
         quantized=False,
         overlapped_decode=False,
+        device_map=None,
         **kwargs,
     ):
         if not checkpoint_dir:
@@ -137,9 +138,18 @@
         self.loaded = False
         self.torch_compile = torch_compile
         self.cpu_offload = cpu_offload
+        self.cpu_offload_device = torch.device("cpu") if cpu_offload else None
         self.quantized = quantized
         self.overlapped_decode = overlapped_decode
-
+        if device_map is not None:
+            self.device_map = device_map
+        else:
+            self.device_map = {
+                'ace_step_transformer': self.device,
+                'text_encoder_model': self.device,
+                'music_dcae': self.device,
+            }
+
     def cleanup_memory(self):
         """Clean up GPU and CPU memory to prevent VRAM overflow during multiple generations."""
         # Clear CUDA cache
@@ -190,15 +200,9 @@ def load_checkpoint(self, checkpoint_dir=None, export_quantized_weights=False):
         self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(
             ace_step_checkpoint_path, torch_dtype=self.dtype
         )
-        # self.ace_step_transformer.to(self.device).eval().to(self.dtype)
-        if self.cpu_offload:
-            self.ace_step_transformer = (
-                self.ace_step_transformer.to("cpu").eval().to(self.dtype)
-            )
-        else:
-            self.ace_step_transformer = (
-                self.ace_step_transformer.to(self.device).eval().to(self.dtype)
-            )
+        self.ace_step_transformer = (
+            self.ace_step_transformer.to(self.cpu_offload_device or self.device_map['ace_step_transformer']).eval().to(self.dtype)
+        )
         if self.torch_compile:
             self.ace_step_transformer = torch.compile(self.ace_step_transformer)
 
@@ -206,11 +210,9 @@ def load_checkpoint(self, checkpoint_dir=None, export_quantized_weights=False):
             dcae_checkpoint_path=dcae_checkpoint_path,
             vocoder_checkpoint_path=vocoder_checkpoint_path,
         )
-        # self.music_dcae.to(self.device).eval().to(self.dtype)
-        if self.cpu_offload:  # might be redundant
-            self.music_dcae = self.music_dcae.to("cpu").eval().to(self.dtype)
-        else:
-            self.music_dcae = self.music_dcae.to(self.device).eval().to(self.dtype)
+        self.music_dcae = (
+            self.music_dcae.to(self.cpu_offload_device or self.device_map['music_dcae']).eval().to(self.dtype)
+        )
         if self.torch_compile:
             self.music_dcae = torch.compile(self.music_dcae)
 
@@ -222,11 +224,9 @@ def load_checkpoint(self, checkpoint_dir=None, export_quantized_weights=False):
         text_encoder_model = UMT5EncoderModel.from_pretrained(
             text_encoder_checkpoint_path, torch_dtype=self.dtype
         ).eval()
-        # text_encoder_model = text_encoder_model.to(self.device).to(self.dtype)
-        if self.cpu_offload:
-            text_encoder_model = text_encoder_model.to("cpu").eval().to(self.dtype)
-        else:
-            text_encoder_model = text_encoder_model.to(self.device).eval().to(self.dtype)
+        text_encoder_model = (
+            text_encoder_model.to(self.cpu_offload_device or self.device_map['text_encoder_model']).eval().to(self.dtype)
+        )
         text_encoder_model.requires_grad_(False)
         self.text_encoder_model = text_encoder_model
         if self.torch_compile:
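
Note: every loading path above now reduces to one expression, self.cpu_offload_device or self.device_map[name]. This relies on torch.device objects being truthy: when cpu_offload is enabled, cpu_offload_device is torch.device("cpu") and short-circuits the or; when it is None, the per-model map entry is used. A runnable sketch of the precedence rule (the map contents are illustrative):

    import torch

    device_map = {"text_encoder_model": torch.device("cpu")}  # could be any device

    cpu_offload_device = torch.device("cpu")  # set when cpu_offload=True
    print(cpu_offload_device or device_map["text_encoder_model"])  # cpu -- offload wins

    cpu_offload_device = None  # cpu_offload=False
    print(cpu_offload_device or device_map["text_encoder_model"])  # the mapped device
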
@@ -290,19 +290,19 @@ def load_quantized_checkpoint(self, checkpoint_dir=None):
             dcae_checkpoint_path=dcae_checkpoint_path,
             vocoder_checkpoint_path=vocoder_checkpoint_path,
         )
-        if self.cpu_offload:
-            self.music_dcae.eval().to(self.dtype).to(self.device)
-        else:
-            self.music_dcae.eval().to(self.dtype).to('cpu')
+        self.music_dcae.to(self.cpu_offload_device or self.device_map['music_dcae']).eval().to(self.dtype)
         self.music_dcae = torch.compile(self.music_dcae)
 
         self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(ace_step_checkpoint_path)
+        # self.ace_step_transformer = (
+        #     self.ace_step_transformer.to(self.cpu_offload_device or self.device_map['ace_step_transformer']).eval().to(self.dtype)
+        # )
         self.ace_step_transformer.eval().to(self.dtype).to('cpu')
         self.ace_step_transformer = torch.compile(self.ace_step_transformer)
 
         self.ace_step_transformer.load_state_dict(
             torch.load(
                 os.path.join(ace_step_checkpoint_path, "diffusion_pytorch_model_int4wo.bin"),
+                # map_location=self.cpu_offload_device or self.device_map['ace_step_transformer'],
                 map_location=self.device,
             ),assign=True
         )
@@ -314,6 +314,7 @@ def load_quantized_checkpoint(self, checkpoint_dir=None):
         self.text_encoder_model.load_state_dict(
             torch.load(
                 os.path.join(text_encoder_checkpoint_path, "pytorch_model_int4wo.bin"),
+                # map_location=self.device_map['text_encoder_model'],
                 map_location=self.device,
             ),assign=True
         )
@@ -339,14 +340,14 @@ def get_text_embeddings(self, texts, text_max_length=256):
             truncation=True,
             max_length=text_max_length,
         )
-        inputs = {key: value.to(self.device) for key, value in inputs.items()}
-        if self.text_encoder_model.device != self.device:
-            self.text_encoder_model.to(self.device)
+        # if self.text_encoder_model.device != self.device:
+        #     self.text_encoder_model.to(self.device)
         with torch.no_grad():
+            inputs = {key: value.to(self.text_encoder_model.device) for key, value in inputs.items()}
             outputs = self.text_encoder_model(**inputs)
             last_hidden_states = outputs.last_hidden_state
         attention_mask = inputs["attention_mask"]
-        return last_hidden_states, attention_mask
+        return last_hidden_states.to(self.device), attention_mask.to(self.device)
 
     @cpu_offload("text_encoder_model")
     def get_text_embeddings_null(
@@ -359,9 +360,9 @@
             truncation=True,
             max_length=text_max_length,
         )
-        inputs = {key: value.to(self.device) for key, value in inputs.items()}
-        if self.text_encoder_model.device != self.device:
-            self.text_encoder_model.to(self.device)
+        # inputs = {key: value.to(self.device) for key, value in inputs.items()}
+        # if self.text_encoder_model.device != self.device:
+        #     self.text_encoder_model.to(self.device)
 
         def forward_with_temperature(inputs, tau=0.01, l_min=8, l_max=10):
             handlers = []
@@ -379,16 +380,17 @@ def hook(module, input, output):
                 handlers.append(handler)
 
             with torch.no_grad():
+                inputs = {key: value.to(self.text_encoder_model.device) for key, value in inputs.items()}
                 outputs = self.text_encoder_model(**inputs)
                 last_hidden_states = outputs.last_hidden_state
 
             for hook in handlers:
                 hook.remove()
 
-            return last_hidden_states
+            return last_hidden_states.to(self.device)
 
         last_hidden_states = forward_with_temperature(inputs, tau, l_min, l_max)
-        return last_hidden_states
+        return last_hidden_states.to(self.device)
 
     def set_seeds(self, batch_size, manual_seeds=None):
         processed_input_seeds = None
@@ -502,17 +504,23 @@ def calc_v(
             torch.cat([zt_src, zt_src]) if do_classifier_free_guidance else zt_src
         )
         timestep = t.expand(src_latent_model_input.shape[0])
+        args = {
+            "hidden_states": src_latent_model_input,
+            "attention_mask": attention_mask,
+            "encoder_text_hidden_states": encoder_text_hidden_states,
+            "text_attention_mask": text_attention_mask,
+            "speaker_embeds": speaker_embds,
+            "lyric_token_idx": lyric_token_ids,
+            "lyric_mask": lyric_mask,
+            "timestep": timestep,
+        }
+        for key, value in args.items():
+            if value is not None:
+                args[key] = value.to(self.ace_step_transformer.device)
         # source
         noise_pred_src = self.ace_step_transformer(
-            hidden_states=src_latent_model_input,
-            attention_mask=attention_mask,
-            encoder_text_hidden_states=encoder_text_hidden_states,
-            text_attention_mask=text_attention_mask,
-            speaker_embeds=speaker_embds,
-            lyric_token_idx=lyric_token_ids,
-            lyric_mask=lyric_mask,
-            timestep=timestep,
-        ).sample
+            **args,
+        ).sample.to(self.device)
 
         if do_classifier_free_guidance:
             noise_pred_with_cond_src, noise_pred_uncond_src = noise_pred_src.chunk(
@@ -536,17 +544,23 @@
             torch.cat([zt_tar, zt_tar]) if do_classifier_free_guidance else zt_tar
         )
         timestep = t.expand(tar_latent_model_input.shape[0])
+        args = {
+            "hidden_states": tar_latent_model_input,
+            "attention_mask": attention_mask,
+            "encoder_text_hidden_states": target_encoder_text_hidden_states,
+            "text_attention_mask": target_text_attention_mask,
+            "speaker_embeds": target_speaker_embeds,
+            "lyric_token_idx": target_lyric_token_ids,
+            "lyric_mask": target_lyric_mask,
+            "timestep": timestep,
+        }
+        for key, value in args.items():
+            if value is not None:
+                args[key] = value.to(self.ace_step_transformer.device)
         # target
         noise_pred_tar = self.ace_step_transformer(
-            hidden_states=tar_latent_model_input,
-            attention_mask=attention_mask,
-            encoder_text_hidden_states=target_encoder_text_hidden_states,
-            text_attention_mask=target_text_attention_mask,
-            speaker_embeds=target_speaker_embeds,
-            lyric_token_idx=target_lyric_token_ids,
-            lyric_mask=target_lyric_mask,
-            timestep=timestep,
-        ).sample
+            **args,
+        ).sample.to(self.device)
 
         if do_classifier_free_guidance:
             noise_pred_with_cond_tar, noise_pred_uncond_tar = noise_pred_tar.chunk(2)
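
Note: the move-arguments-to-model-device loop introduced in calc_v above reappears many times below. A hypothetical helper (not part of this diff) would keep the rule in one place; the dtype guard matters because integer token ids and masks must not be cast to the float compute dtype:

    import torch

    def to_model_device(kwargs, device, dtype=None):
        # Move tensor kwargs to `device`; cast only floating-point tensors to `dtype`.
        out = dict(kwargs)
        for k, v in out.items():
            if isinstance(v, torch.Tensor):
                v = v.to(device)
                if dtype is not None and v.is_floating_point():
                    v = v.to(dtype)
                out[k] = v
        return out

    moved = to_model_device(
        {"lyric_token_idx": torch.ones(4, dtype=torch.long), "speaker_embeds": torch.randn(4)},
        torch.device("cpu"),
        torch.float16,
    )
    print(moved["lyric_token_idx"].dtype, moved["speaker_embeds"].dtype)  # torch.int64 torch.float16
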
"encoder_text_hidden_states": ( + encoder_text_hidden_states_null + if encoder_text_hidden_states_null is not None + else torch.zeros_like(encoder_text_hidden_states) + ), + "text_attention_mask": text_attention_mask, + "speaker_embeds": torch.zeros_like(speaker_embds), + "lyric_token_idx": lyric_token_ids, + "lyric_mask": lyric_mask, + } + for k, v in inputs.items(): + if v is not None: + inputs[k] = v.to(self.ace_step_transformer.device, dtype=self.dtype) # P(null_speaker, text_weaker, lyric_weaker) encoder_hidden_states_null = forward_encoder_with_temperature( self, - inputs={ - "encoder_text_hidden_states": ( - encoder_text_hidden_states_null - if encoder_text_hidden_states_null is not None - else torch.zeros_like(encoder_text_hidden_states) - ), - "text_attention_mask": text_attention_mask, - "speaker_embeds": torch.zeros_like(speaker_embds), - "lyric_token_idx": lyric_token_ids, - "lyric_mask": lyric_mask, - }, + inputs=inputs, ) else: + inputs = { + "encoder_text_hidden_states": torch.zeros_like(encoder_text_hidden_states), + "text_attention_mask": text_attention_mask, + "speaker_embeds": torch.zeros_like(speaker_embds), + "lyric_token_idx": torch.zeros_like(lyric_token_ids), + "lyric_mask": lyric_mask, + } + for k, v in inputs.items(): + if v is not None: + inputs[k] = v.to(self.ace_step_transformer.device, dtype=self.dtype) # P(null_speaker, null_text, null_lyric) encoder_hidden_states_null, _ = self.ace_step_transformer.encode( - torch.zeros_like(encoder_text_hidden_states), - text_attention_mask, - torch.zeros_like(speaker_embds), - torch.zeros_like(lyric_token_ids), - lyric_mask, + **inputs, ) encoder_hidden_states_no_lyric = None if do_double_condition_guidance: # P(null_speaker, text, lyric_weaker) if use_erg_lyric: + inputs = { + "encoder_text_hidden_states": encoder_text_hidden_states, + "text_attention_mask": text_attention_mask, + "speaker_embeds": torch.zeros_like(speaker_embds), + "lyric_token_idx": lyric_token_ids, + "lyric_mask": lyric_mask, + } + for k, v in inputs.items(): + if v is not None: + inputs[k] = v.to(self.ace_step_transformer.device, dtype=self.dtype) encoder_hidden_states_no_lyric = forward_encoder_with_temperature( self, - inputs={ - "encoder_text_hidden_states": encoder_text_hidden_states, - "text_attention_mask": text_attention_mask, - "speaker_embeds": torch.zeros_like(speaker_embds), - "lyric_token_idx": lyric_token_ids, - "lyric_mask": lyric_mask, - }, + inputs=inputs, ) # P(null_speaker, text, no_lyric) else: + inputs = { + "encoder_text_hidden_states": encoder_text_hidden_states, + "text_attention_mask": text_attention_mask, + "speaker_embeds": torch.zeros_like(speaker_embds), + "lyric_token_idx": torch.zeros_like(lyric_token_ids), + "lyric_mask": lyric_mask, + } + for k, v in inputs.items(): + if v is not None: + inputs[k] = v.to(self.ace_step_transformer.device, dtype=self.dtype) encoder_hidden_states_no_lyric, _ = self.ace_step_transformer.encode( - encoder_text_hidden_states, - text_attention_mask, - torch.zeros_like(speaker_embds), - torch.zeros_like(lyric_token_ids), - lyric_mask, + **inputs, ) def forward_diffusion_with_temperature( @@ -1180,7 +1221,7 @@ def hook(module, input, output): sample = self.ace_step_transformer.decode( hidden_states=hidden_states, timestep=timestep, **inputs - ).sample + ).sample.to(self.device) for hook in handlers: hook.remove() @@ -1221,51 +1262,73 @@ def hook(module, input, output): latent_model_input = latents timestep = t.expand(latent_model_input.shape[0]) output_length = latent_model_input.shape[-1] + 
@@ -1221,51 +1262,73 @@ def hook(module, input, output):
                 latent_model_input = latents
             timestep = t.expand(latent_model_input.shape[0])
             output_length = latent_model_input.shape[-1]
+            inputs = {
+                "hidden_states": latent_model_input,
+                "attention_mask": attention_mask,
+                "encoder_hidden_states": encoder_hidden_states,
+                "encoder_hidden_mask": encoder_hidden_mask,
+                "output_length": output_length,
+                "timestep": timestep,
+            }
+            # move tensors to the transformer's device; output_length is an
+            # int and passes through unchanged
+            for k, v in inputs.items():
+                if isinstance(v, torch.Tensor):
+                    v = v.to(self.ace_step_transformer.device)
+                    inputs[k] = v.to(self.dtype) if v.is_floating_point() else v
             # P(x|speaker, text, lyric)
             noise_pred_with_cond = self.ace_step_transformer.decode(
-                hidden_states=latent_model_input,
-                attention_mask=attention_mask,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_hidden_mask=encoder_hidden_mask,
-                output_length=output_length,
-                timestep=timestep,
-            ).sample
+                **inputs,
+            ).sample.to(self.device)
 
             noise_pred_with_only_text_cond = None
             if (
                 do_double_condition_guidance
                 and encoder_hidden_states_no_lyric is not None
             ):
+                inputs = {
+                    "hidden_states": latent_model_input,
+                    "attention_mask": attention_mask,
+                    "encoder_hidden_states": encoder_hidden_states_no_lyric,
+                    "encoder_hidden_mask": encoder_hidden_mask,
+                    "output_length": output_length,
+                    "timestep": timestep,
+                }
+                for k, v in inputs.items():
+                    if isinstance(v, torch.Tensor):
+                        v = v.to(self.ace_step_transformer.device)
+                        inputs[k] = v.to(self.dtype) if v.is_floating_point() else v
                 noise_pred_with_only_text_cond = self.ace_step_transformer.decode(
-                    hidden_states=latent_model_input,
-                    attention_mask=attention_mask,
-                    encoder_hidden_states=encoder_hidden_states_no_lyric,
-                    encoder_hidden_mask=encoder_hidden_mask,
-                    output_length=output_length,
-                    timestep=timestep,
-                ).sample
+                    **inputs,
+                ).sample.to(self.device)
 
             if use_erg_diffusion:
+                inputs = {
+                    "encoder_hidden_states": encoder_hidden_states_null,
+                    "encoder_hidden_mask": encoder_hidden_mask,
+                    "output_length": output_length,
+                    "attention_mask": attention_mask,
+                }
+                for k, v in inputs.items():
+                    if isinstance(v, torch.Tensor):
+                        v = v.to(self.ace_step_transformer.device)
+                        inputs[k] = v.to(self.dtype) if v.is_floating_point() else v
                 noise_pred_uncond = forward_diffusion_with_temperature(
                     self,
                     hidden_states=latent_model_input,
                     timestep=timestep,
-                    inputs={
-                        "encoder_hidden_states": encoder_hidden_states_null,
-                        "encoder_hidden_mask": encoder_hidden_mask,
-                        "output_length": output_length,
-                        "attention_mask": attention_mask,
-                    },
+                    inputs=inputs,
                 )
             else:
+                inputs = {
+                    "hidden_states": latent_model_input,
+                    "attention_mask": attention_mask,
+                    "encoder_hidden_states": encoder_hidden_states_null,
+                    "encoder_hidden_mask": encoder_hidden_mask,
+                    "output_length": output_length,
+                    "timestep": timestep,
+                }
+                for k, v in inputs.items():
+                    if isinstance(v, torch.Tensor):
+                        v = v.to(self.ace_step_transformer.device)
+                        inputs[k] = v.to(self.dtype) if v.is_floating_point() else v
                 noise_pred_uncond = self.ace_step_transformer.decode(
-                    hidden_states=latent_model_input,
-                    attention_mask=attention_mask,
-                    encoder_hidden_states=encoder_hidden_states_null,
-                    encoder_hidden_mask=encoder_hidden_mask,
-                    output_length=output_length,
-                    timestep=timestep,
-                ).sample
+                    **inputs,
+                ).sample.to(self.device)
 
             if (
                 do_double_condition_guidance
@@ -1304,14 +1367,20 @@ def hook(module, input, output):
             else:
                 latent_model_input = latents
             timestep = t.expand(latent_model_input.shape[0])
+            inputs = {
+                "hidden_states": latent_model_input,
+                "attention_mask": attention_mask,
+                "encoder_hidden_states": encoder_hidden_states,
+                "encoder_hidden_mask": encoder_hidden_mask,
+                "output_length": latent_model_input.shape[-1],
+                "timestep": timestep,
+            }
+            for k, v in inputs.items():
+                if isinstance(v, torch.Tensor):
+                    v = v.to(self.ace_step_transformer.device)
+                    inputs[k] = v.to(self.dtype) if v.is_floating_point() else v
             noise_pred = self.ace_step_transformer.decode(
-                hidden_states=latent_model_input,
-                attention_mask=attention_mask,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_hidden_mask=encoder_hidden_mask,
-                output_length=latent_model_input.shape[-1],
-                timestep=timestep,
-            ).sample
+                **inputs,
+            ).sample.to(self.device)
 
             if is_repaint and i >= n_min:
                 t_i = t / 1000
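
Note: in the default single-device configuration all of these added .to(self.device) calls are free: Tensor.to returns the input tensor unchanged when the device and dtype already match, so no copy or synchronization happens. Runnable check:

    import torch

    x = torch.randn(2, 2)
    y = x.to(torch.device("cpu"))  # same device and dtype: .to() is a no-op
    print(y is x)  # True -- the very same tensor object, no copy made
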
@@ -1362,9 +1431,12 @@ def latents2audio(
         pred_latents = latents
         with torch.no_grad():
             if self.overlapped_decode and target_wav_duration_second > 48:
+                # decode_overlap is expected to move latents to the DCAE device itself
                 _, pred_wavs = self.music_dcae.decode_overlap(pred_latents, sr=sample_rate)
             else:
+                # decode moves each latent to the DCAE device internally (see music_dcae_pipeline.py above)
                 _, pred_wavs = self.music_dcae.decode(pred_latents, sr=sample_rate)
+        # pred_wavs are moved to CPU below, so no .to(self.device) is needed here
         pred_wavs = [pred_wav.cpu().float() for pred_wav in pred_wavs]
         for i in tqdm(range(bs)):
             output_audio_path = self.save_wav_file(
@@ -1408,9 +1480,9 @@ def infer_latents(self, input_audio_path):
             return None
         input_audio, sr = self.music_dcae.load_audio(input_audio_path)
         input_audio = input_audio.unsqueeze(0)
-        input_audio = input_audio.to(device=self.device, dtype=self.dtype)
+        input_audio = input_audio.to(device=self.music_dcae.device, dtype=self.dtype)
         latents, _ = self.music_dcae.encode(input_audio, sr=sr)
-        return latents
+        return latents.to(self.device)
 
     def load_lora(self, lora_name_or_path, lora_weight):
         if lora_name_or_path != self.lora_path and lora_name_or_path != "none":
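
Note: putting it together, the new device_map argument lets each component live on its own device while inputs follow each model and results are gathered back on the pipeline's main device. A hedged usage sketch; the constructor arguments match the signature shown in this diff, but the class name and import path are assumed from the repository layout and should be verified:

    import torch
    from acestep.pipeline_ace_step import ACEStepPipeline  # import path assumed

    # Hypothetical two-GPU split: transformer on cuda:0, encoder and VAE on cuda:1.
    pipeline = ACEStepPipeline(
        checkpoint_dir="./checkpoints",
        cpu_offload=False,  # enabling offload would take precedence over the map (see __init__)
        device_map={
            "ace_step_transformer": torch.device("cuda:0"),
            "text_encoder_model": torch.device("cuda:1"),
            "music_dcae": torch.device("cuda:1"),
        },
    )

When cpu_offload=True, cpu_offload_device short-circuits the map at load time, and the @cpu_offload decorator moves each model to its mapped device only for the duration of a call.
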