@@ -40,6 +40,7 @@
     logging,
     replace_example_docstring,
 )
+from ...utils.import_utils import is_transformers_version
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
 from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
@@ -312,8 +313,19 @@ def generate_language_model(
             `inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
                 The sequence of generated hidden-states.
         """
+        cache_position_kwargs = {}
+        if is_transformers_version("<", "4.52.0.dev0"):
+            cache_position_kwargs["input_ids"] = inputs_embeds
+            cache_position_kwargs["model_kwargs"] = model_kwargs
+        else:
+            cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
+            cache_position_kwargs["device"] = (
+                self.language_model.device if getattr(self, "language_model", None) is not None else self.device
+            )
+            cache_position_kwargs["model_kwargs"] = model_kwargs
         max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
-        model_kwargs = self.language_model._get_initial_cache_position(inputs_embeds, model_kwargs)
+        model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
+
         for _ in range(max_new_tokens):
             # prepare model inputs
             model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)
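For reference, below is a minimal standalone sketch of the compatibility-shim pattern the diff applies: build the keyword arguments that match whichever `_get_initial_cache_position` signature the installed transformers version exposes, then forward them with `**` expansion. It mirrors the keyword names and value choices in the diff above (including `inputs_embeds.shape[0]` and the device fallback) and only assumes that `is_transformers_version` is importable from `diffusers.utils.import_utils`, as the diff itself does; the helper name `build_cache_position_kwargs` is illustrative and not part of the pipeline.

```python
# Illustrative sketch only: selects kwargs for whichever private
# `_get_initial_cache_position` signature the installed transformers exposes,
# mirroring the version gate introduced in the diff above.
from diffusers.utils.import_utils import is_transformers_version


def build_cache_position_kwargs(language_model, inputs_embeds, model_kwargs):
    """Return kwargs matching the installed transformers' helper signature."""
    kwargs = {"model_kwargs": model_kwargs}
    if is_transformers_version("<", "4.52.0.dev0"):
        # Older releases: _get_initial_cache_position(input_ids, model_kwargs);
        # the pipeline passes the embeddings tensor in the input_ids slot.
        kwargs["input_ids"] = inputs_embeds
    else:
        # Newer releases take an explicit length and device instead; shape[0]
        # and the device fallback mirror the choices made in the diff above.
        kwargs["seq_length"] = inputs_embeds.shape[0]
        kwargs["device"] = getattr(language_model, "device", inputs_embeds.device)
    return kwargs


# Hypothetical usage, assuming `lm` exposes the same private helper:
# model_kwargs = lm._get_initial_cache_position(**build_cache_position_kwargs(lm, embeds, model_kwargs))
```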