|
40 | 40 | logging, |
41 | 41 | replace_example_docstring, |
42 | 42 | ) |
| 43 | +from ...utils.import_utils import is_transformers_version |
43 | 44 | from ...utils.torch_utils import randn_tensor |
44 | 45 | from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline |
45 | 46 | from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel |
@@ -312,8 +313,19 @@ def generate_language_model( |
312 | 313 | `inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): |
313 | 314 | The sequence of generated hidden-states. |
314 | 315 | """ |
| 316 | + cache_position_kwargs = {} |
| 317 | + if is_transformers_version("<", "4.52.0.dev0"): |
| 318 | + cache_position_kwargs["input_ids"] = inputs_embeds |
| 319 | + cache_position_kwargs["model_kwargs"] = model_kwargs |
| 320 | + else: |
| 321 | + cache_position_kwargs["seq_length"] = inputs_embeds.shape[0] |
| 322 | + cache_position_kwargs["device"] = ( |
| 323 | + self.language_model.device if getattr(self, "language_model", None) is not None else self.device |
| 324 | + ) |
| 325 | + cache_position_kwargs["model_kwargs"] = model_kwargs |
315 | 326 | max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens |
316 | | - model_kwargs = self.language_model._get_initial_cache_position(inputs_embeds, model_kwargs) |
| 327 | + model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs) |
| 328 | + |
317 | 329 | for _ in range(max_new_tokens): |
318 | 330 | # prepare model inputs |
319 | 331 | model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs) |
|
0 commit comments