diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
index 05379270c13b..c0fa031b9faf 100644
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -180,7 +180,6 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
             ],
             dim=1,
         )
-        self.rope_cache = {}
 
         # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
         self.scale_rope = scale_rope
@@ -195,10 +194,20 @@ def rope_params(self, index, dim, theta=10000):
         freqs = torch.polar(torch.ones_like(freqs), freqs)
         return freqs
 
-    def forward(self, video_fhw, txt_seq_lens, device):
+    def forward(
+        self,
+        video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]],
+        txt_seq_lens: List[int],
+        device: torch.device,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
-        txt_length: [bs] a list of 1 integers representing the length of the text
+        Args:
+            video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`):
+                A tuple of 3 integers [frame, height, width] representing the shape of the video, or a list of such tuples.
+            txt_seq_lens (`List[int]`):
+                A list of integers of length batch_size representing the length of each text prompt.
+            device (`torch.device`):
+                The device on which to perform the RoPE computation.
         """
         if self.pos_freqs.device != device:
             self.pos_freqs = self.pos_freqs.to(device)
@@ -213,14 +222,8 @@ def forward(self, video_fhw, txt_seq_lens, device):
         max_vid_index = 0
         for idx, fhw in enumerate(video_fhw):
             frame, height, width = fhw
-            rope_key = f"{idx}_{height}_{width}"
-
-            if not torch.compiler.is_compiling():
-                if rope_key not in self.rope_cache:
-                    self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
-                video_freq = self.rope_cache[rope_key]
-            else:
-                video_freq = self._compute_video_freqs(frame, height, width, idx)
+            # RoPE frequencies are cached via an lru_cache decorator on _compute_video_freqs
+            video_freq = self._compute_video_freqs(frame, height, width, idx)
             video_freq = video_freq.to(device)
             vid_freqs.append(video_freq)
 
@@ -235,8 +238,8 @@ def forward(self, video_fhw, txt_seq_lens, device):
         return vid_freqs, txt_freqs
 
-    @functools.lru_cache(maxsize=None)
-    def _compute_video_freqs(self, frame, height, width, idx=0):
+    @functools.lru_cache(maxsize=128)
+    def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor:
         seq_lens = frame * height * width
         freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
         freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
 
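
For reviewers, here is a minimal, self-contained sketch of the caching pattern this patch switches to. It is an illustration only: the `RopeFreqs` class and its random-tensor body are hypothetical stand-ins, not the actual `QwenEmbedRope` code. The point it demonstrates is that `functools.lru_cache` on the method memoizes results keyed on the call arguments (including `self`), which is what lets the patch drop both the hand-rolled `self.rope_cache` dict and the `torch.compiler.is_compiling()` branch.

```python
import functools

import torch


class RopeFreqs:
    """Hypothetical stand-in for the cached-method pattern used in the patch."""

    @functools.lru_cache(maxsize=128)
    def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor:
        # Stands in for the expensive frequency computation; lru_cache keys on
        # (self, frame, height, width, idx) and returns the stored tensor on
        # repeated calls, just like the removed manual rope_cache dict did.
        return torch.randn(frame * height * width, 16)


rope = RopeFreqs()
a = rope._compute_video_freqs(1, 32, 32)
b = rope._compute_video_freqs(1, 32, 32)  # cache hit: no recomputation
assert a is b  # the very same tensor object is returned
print(rope._compute_video_freqs.cache_info())  # hits=1, misses=1, maxsize=128, currsize=1
```

One trade-off worth noting: because the cache key includes `self`, `lru_cache` holds a strong reference to the instance and to each cached tensor, so entries outlive their embedder until evicted; the bounded `maxsize=128` (versus the previous unbounded `maxsize=None`) caps that memory growth.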