|
| 1 | +import numpy as np |
1 | 2 | import torch |
2 | 3 | import torch.amp as amp |
3 | 4 | from torch.backends.cuda import sdp_kernel |
@@ -59,6 +60,17 @@ def rope_apply(x, grid_sizes, freqs): |
59 | 60 | return torch.stack(output).float() |
60 | 61 |
|
61 | 62 |
|
def broadcast_should_calc(should_calc: bool, src: int = 0) -> bool:
    """Synchronize a TeaCache compute/skip decision across distributed ranks.

    Rank ``src`` decides whether the transformer blocks must actually be
    recomputed this step; every other rank adopts that decision so all ranks
    in the process group take the same branch (diverging would deadlock the
    subsequent collectives).

    Args:
        should_calc: This rank's local decision. Only the value held on rank
            ``src`` matters after the broadcast.
        src: Rank whose decision is authoritative. Defaults to 0, matching
            the original hard-coded behavior.

    Returns:
        The boolean decision held by rank ``src``.
    """
    # Imported lazily so importing this module does not require an
    # initialized distributed environment.
    import torch.distributed as dist

    # NOTE(review): assumes an NCCL backend — the flag tensor is placed on
    # the current CUDA device so the broadcast can run over GPU transport.
    device = torch.cuda.current_device()
    flag = torch.tensor([1 if should_calc else 0], device=device, dtype=torch.int8)
    dist.broadcast(flag, src=src)
    return flag.item() == 1
| 72 | + |
| 73 | + |
62 | 74 | def usp_dit_forward(self, x, t, context, clip_fea=None, y=None, fps=None): |
63 | 75 | """ |
64 | 76 | x: A list of videos each with shape [C, T, H, W]. |
@@ -135,20 +147,84 @@ def usp_dit_forward(self, x, t, context, clip_fea=None, y=None, fps=None): |
135 | 147 | e0 = torch.chunk(e0, get_sequence_parallel_world_size(), dim=2)[get_sequence_parallel_rank()] |
136 | 148 | kwargs = dict(e=e0, grid_sizes=grid_sizes, freqs=self.freqs, context=context, block_mask=self.block_mask) |
137 | 149 |
|
138 | | - # Context Parallel |
139 | | - x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] |
| 150 | + if self.enable_teacache: |
| 151 | + modulated_inp = e0 if self.use_ref_steps else e |
| 152 | + # teacache |
| 153 | +        if self.cnt % 2 == 0:  # even -> condition |
| 154 | + self.is_even = True |
| 155 | + if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps: |
| 156 | + should_calc_even = True |
| 157 | + self.accumulated_rel_l1_distance_even = 0 |
| 158 | + else: |
| 159 | + rescale_func = np.poly1d(self.coefficients) |
| 160 | + self.accumulated_rel_l1_distance_even += rescale_func( |
| 161 | + ((modulated_inp - self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()) |
| 162 | + .cpu() |
| 163 | + .item() |
| 164 | + ) |
| 165 | + if self.accumulated_rel_l1_distance_even < self.teacache_thresh: |
| 166 | + should_calc_even = False |
| 167 | + else: |
| 168 | + should_calc_even = True |
| 169 | + self.accumulated_rel_l1_distance_even = 0 |
| 170 | + self.previous_e0_even = modulated_inp.clone() |
| 171 | +        else:  # odd -> uncondition |
| 172 | + self.is_even = False |
| 173 | + if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps: |
| 174 | + should_calc_odd = True |
| 175 | + self.accumulated_rel_l1_distance_odd = 0 |
| 176 | + else: |
| 177 | + rescale_func = np.poly1d(self.coefficients) |
| 178 | + self.accumulated_rel_l1_distance_odd += rescale_func( |
| 179 | + ((modulated_inp - self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean()) |
| 180 | + .cpu() |
| 181 | + .item() |
| 182 | + ) |
| 183 | + if self.accumulated_rel_l1_distance_odd < self.teacache_thresh: |
| 184 | + should_calc_odd = False |
| 185 | + else: |
| 186 | + should_calc_odd = True |
| 187 | + self.accumulated_rel_l1_distance_odd = 0 |
| 188 | + self.previous_e0_odd = modulated_inp.clone() |
140 | 189 |
|
141 | | - for block in self.blocks: |
142 | | - x = block(x, **kwargs) |
| 190 | + x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] |
| 191 | + if self.enable_teacache: |
| 192 | + if self.is_even: |
| 193 | + should_calc_even = broadcast_should_calc(should_calc_even) |
| 194 | + if not should_calc_even: |
| 195 | + x += self.previous_residual_even |
| 196 | + else: |
| 197 | + ori_x = x.clone() |
| 198 | + for block in self.blocks: |
| 199 | + x = block(x, **kwargs) |
| 200 | + ori_x.mul_(-1) |
| 201 | + ori_x.add_(x) |
| 202 | + self.previous_residual_even = ori_x |
| 203 | + else: |
| 204 | + should_calc_odd = broadcast_should_calc(should_calc_odd) |
| 205 | + if not should_calc_odd: |
| 206 | + x += self.previous_residual_odd |
| 207 | + else: |
| 208 | + ori_x = x.clone() |
| 209 | + for block in self.blocks: |
| 210 | + x = block(x, **kwargs) |
| 211 | + ori_x.mul_(-1) |
| 212 | + ori_x.add_(x) |
| 213 | + self.previous_residual_odd = ori_x |
| 214 | + self.cnt += 1 |
| 215 | + if self.cnt >= self.num_steps: |
| 216 | + self.cnt = 0 |
| 217 | + else: |
| 218 | + # Context Parallel |
| 219 | + for block in self.blocks: |
| 220 | + x = block(x, **kwargs) |
143 | 221 |
|
144 | 222 | # head |
145 | 223 | if e.ndim == 3: |
146 | 224 | e = torch.chunk(e, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] |
147 | 225 | x = self.head(x, e) |
148 | | - |
149 | 226 | # Context Parallel |
150 | 227 | x = get_sp_group().all_gather(x, dim=1) |
151 | | - |
152 | 228 | # unpatchify |
153 | 229 | x = self.unpatchify(x, grid_sizes) |
154 | 230 | return x.float() |
|
0 commit comments