diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 6f3faed8ff72..c3bc98b46e7c 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -74,10 +74,18 @@ def calculate_shift( base_shift: float = 0.5, max_shift: float = 1.15, ): + if ( + base_seq_len is _DEFAULT_BASE_SEQ_LEN and + max_seq_len is _DEFAULT_MAX_SEQ_LEN and + base_shift is _DEFAULT_BASE_SHIFT and + max_shift is _DEFAULT_MAX_SHIFT + ): + # Fast path for defaults + return image_seq_len * _m_DEFAULT + _b_DEFAULT + + # Fallback for non-defaults (kept as a one-liner for efficiency) m = (max_shift - base_shift) / (max_seq_len - base_seq_len) - b = base_shift - m * base_seq_len - mu = image_seq_len * m + b - return mu + return image_seq_len * m + (base_shift - m * base_seq_len) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps @@ -799,3 +807,15 @@ def __call__( return (video,) return LTXPipelineOutput(frames=video) + +_DEFAULT_BASE_SEQ_LEN = 256 + +_DEFAULT_MAX_SEQ_LEN = 4096 + +_DEFAULT_BASE_SHIFT = 0.5 + +_DEFAULT_MAX_SHIFT = 1.15 + +_m_DEFAULT = (_DEFAULT_MAX_SHIFT - _DEFAULT_BASE_SHIFT) / (_DEFAULT_MAX_SEQ_LEN - _DEFAULT_BASE_SEQ_LEN) + +_b_DEFAULT = _DEFAULT_BASE_SHIFT - _m_DEFAULT * _DEFAULT_BASE_SEQ_LEN