diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py index a63c370..31f863f 100644 --- a/parler_tts/modeling_parler_tts.py +++ b/parler_tts/modeling_parler_tts.py @@ -3289,7 +3289,7 @@ def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_l cache_dtype = self.dtype cache_kwargs = { "config": self.config.decoder, - "max_batch_size": max_batch_size, + "batch_size": max_batch_size, "max_cache_len": max_cache_len, "device": self.device, "dtype": cache_dtype,