Commit ea1f83d

hotfix accuracy and step num (#27)
* fix
* update settings
* Update dtype to torch.float32 in DiT model
* Remove unnecessary assert statement in DiT class
* Refactor logging and checkpoint saving in train_img.py
1 parent 527b13a commit ea1f83d

4 files changed, 44 additions(+), 17 deletions(-)

opendit/models/dit.py (20 additions, 11 deletions)

@@ -15,9 +15,8 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-import torch.nn.functional as F
 import torch.utils.checkpoint
-from timm.models.vision_transformer import Mlp, PatchEmbed, use_fused_attn
+from timm.models.vision_transformer import Mlp, PatchEmbed
 from torch.jit import Final

 from opendit.models.clip import TextEmbedder
@@ -158,7 +157,6 @@ def __init__(
         self.num_heads = num_heads
         self.head_dim = dim // num_heads
         self.scale = self.head_dim**-0.5
-        self.fused_attn = use_fused_attn()

         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
@@ -236,13 +234,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 dropout_p=self.attn_drop.p if self.training else 0.0,
                 softmax_scale=self.scale,
             )
-        elif self.fused_attn:
-            x = F.scaled_dot_product_attention(
-                q,
-                k,
-                v,
-                dropout_p=self.attn_drop.p if self.training else 0.0,
-            )
         else:
             dtype = q.dtype
             q = q * self.scale
@@ -260,7 +251,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             if self.sequence_parallel_size == 1
             else (B, N * self.sequence_parallel_size, num_heads * self.head_dim)
         )
-        x = x.transpose(1, 2).reshape(x_output_shape)
+        if self.enable_flashattn:
+            x = x.reshape(x_output_shape)
+        else:
+            x = x.transpose(1, 2).reshape(x_output_shape)
+
         if self.sequence_parallel_size > 1:
             # Todo: Use all_to_all_single for x
             # x = x.reshape(1, -1, num_heads * self.head_dim)
@@ -355,6 +350,7 @@ def __init__(
         enable_layernorm_kernel=False,
         enable_modulate_kernel=False,
         sequence_parallel_size=1,
+        dtype=torch.float32,
     ):
         super().__init__()
         self.learn_sigma = learn_sigma
@@ -363,6 +359,12 @@ def __init__(
         self.patch_size = patch_size
         self.num_heads = num_heads
         self.sequence_parallel_size = sequence_parallel_size
+        self.dtype = dtype
+        if enable_flashattn:
+            assert dtype in [
+                torch.float16,
+                torch.bfloat16,
+            ], f"Flash attention only supports float16 and bfloat16, but got {self.dtype}"

         self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
         self.t_embedder = TimestepEmbedder(hidden_size)
@@ -470,6 +472,10 @@ def forward(self, x, t, y):

         # Todo: Mock video input by repeating the same frame for all timesteps
         # x = torch.randn(2, 256, 1152).to(torch.bfloat16).cuda()
+
+        # origin inputs should be float32, cast to specified dtype
+        x = x.to(self.dtype)
+
         x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
         t = self.t_embedder(t, dtype=x.dtype)  # (N, D)
         y = self.y_embedder(y, self.training)  # (N, D)
@@ -490,6 +496,9 @@ def forward(self, x, t, y):

         x = self.final_layer(x, c)  # (N, T, patch_size ** 2 * out_channels)
         x = self.unpatchify(x)  # (N, out_channels, H, W)
+
+        # cast to float32 for better accuracy
+        x = x.to(torch.float32)
         return x

     def forward_with_cfg(self, x, t, y, cfg_scale):
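
Taken together, the dit.py changes give the model an explicit compute dtype: the constructor stores it (and, when flash attention is requested, asserts it is float16 or bfloat16), forward() casts the incoming float32 tensors down to that dtype before patch embedding, and the output is cast back to float32 at the end for better accuracy. The attention change also drops the fused-SDPA fallback and reshapes the flash-attention output directly, since flash attention returns its output in (B, N, heads, head_dim) layout and needs no transpose(1, 2). Below is a minimal, self-contained sketch of the dtype round-trip only; ToyDiT is a hypothetical stand-in, not code from this repository.

import torch
import torch.nn as nn


class ToyDiT(nn.Module):
    """Hypothetical stand-in for DiT; mirrors only the dtype handling in this commit."""

    def __init__(self, hidden_size=64, enable_flashattn=False, dtype=torch.float32):
        super().__init__()
        self.dtype = dtype
        if enable_flashattn:
            # flash-attention kernels only run in half precision
            assert dtype in [torch.float16, torch.bfloat16], (
                f"Flash attention only supports float16 and bfloat16, but got {dtype}"
            )
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        # inputs arrive as float32; cast to the compute dtype first
        x = x.to(self.dtype)
        x = self.proj(x)
        # cast back to float32 so the diffusion/loss math keeps full accuracy
        return x.to(torch.float32)


# for float16/bfloat16 you would also call model.to(dtype), as the scripts below do
model = ToyDiT(dtype=torch.float32)
out = model(torch.randn(2, 16, 64))
print(out.dtype)  # torch.float32, regardless of the internal compute dtype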

sample.py (12 additions, 1 deletion)

@@ -34,7 +34,18 @@ def main(args):

     # Load model:
     latent_size = args.image_size // 8
-    model = DiT_models[args.model](input_size=latent_size, num_classes=args.num_classes).to(device)
+    dtype = torch.float32
+    model = (
+        DiT_models[args.model](
+            input_size=latent_size,
+            num_classes=args.num_classes,
+            enable_flashattn=False,
+            enable_layernorm_kernel=False,
+            dtype=dtype,
+        )
+        .to(device)
+        .to(dtype)
+    )
     # Auto-download a pre-trained model or load a custom DiT checkpoint from train.py:
     ckpt_path = args.ckpt or f"DiT-XL-2-{args.image_size}x{args.image_size}.pt"
     state_dict = find_model(ckpt_path)
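
sample.py now builds the model with explicit kernel flags and a fixed torch.float32 dtype, which is consistent with the new constructor assert: with flash attention disabled, full precision is allowed. The helper below is a hypothetical, self-contained restatement of that compatibility rule (check_dtype is not part of the repository); a half-precision run would pass float16 or bfloat16 and could then enable the flash-attention path.

import torch


def check_dtype(dtype: torch.dtype, enable_flashattn: bool) -> torch.dtype:
    """Hypothetical helper mirroring the dtype check the DiT constructor now enforces."""
    if enable_flashattn and dtype not in (torch.float16, torch.bfloat16):
        raise ValueError(f"Flash attention only supports float16 and bfloat16, but got {dtype}")
    return dtype


# sample.py keeps sampling in float32, so the fused kernels stay off:
print(check_dtype(torch.float32, enable_flashattn=False))  # torch.float32
# a half-precision run could enable flash attention:
print(check_dtype(torch.bfloat16, enable_flashattn=True))  # torch.bfloat16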

train_img.py (11 additions, 5 deletions)

@@ -14,6 +14,7 @@

 import colossalai
 import torch
+import torch.distributed as dist
 from colossalai.booster import Booster
 from colossalai.booster.plugin import LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
@@ -60,6 +61,7 @@ def main(args):
     model_string_name = args.model.replace("/", "-")
     # Create an experiment folder
     experiment_dir = f"{args.outputs}/{experiment_index:03d}-{model_string_name}"
+    dist.barrier()
     if coordinator.is_master():
         os.makedirs(experiment_dir, exist_ok=True)
         with open(f"{experiment_dir}/config.txt", "w") as f:
@@ -113,6 +115,7 @@ def main(args):
             enable_layernorm_kernel=args.enable_layernorm_kernel,
             enable_modulate_kernel=args.enable_modulate_kernel,
             sequence_parallel_size=args.sequence_parallel_size,
+            dtype=dtype,
         )
         .to(device)
         .to(dtype)
@@ -208,7 +211,6 @@ def main(args):
             with torch.no_grad():
                 # Map input images to latent space + normalize latents:
                 x = vae.encode(x).latent_dist.sample().mul_(0.18215)
-                x = x.to(dtype)

             # Diffusion
             t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
@@ -224,11 +226,15 @@ def main(args):

             # Log loss values:
             all_reduce_mean(loss)
-            if coordinator.is_master() and (step + 1) % args.log_every == 0:
-                pbar.set_postfix({"loss": loss.item()})
-                writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step)
+            global_step = epoch * num_steps_per_epoch + step
+            pbar.set_postfix({"loss": loss.item(), "step": step, "global_step": global_step})

-            if args.ckpt_every > 0 and (step + 1) % args.ckpt_every == 0:
+            # Log to tensorboard
+            if coordinator.is_master() and (global_step + 1) % args.log_every == 0:
+                writer.add_scalar("loss", loss.item(), global_step)
+
+            # Save checkpoint
+            if args.ckpt_every > 0 and (global_step + 1) % args.ckpt_every == 0:
                 logger.info(f"Saving checkpoint...")
                 save(
                     booster,
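
The train_img.py half of the hotfix is the "step num" part of the title: logging and checkpointing now key off a single global_step = epoch * num_steps_per_epoch + step, so the tensorboard x-axis and the checkpoint interval keep counting across epochs instead of resetting every epoch. The added dist.barrier() synchronizes all ranks before the master creates the experiment folder (plausibly so every rank resolves the same experiment_dir before it appears on disk). A minimal, self-contained sketch of the step bookkeeping follows; the epoch, step, and interval values are illustrative, not the script's defaults.

# Minimal sketch of the global-step bookkeeping added above.
num_epochs = 2
num_steps_per_epoch = 5
log_every = 3
ckpt_every = 4

for epoch in range(num_epochs):
    for step in range(num_steps_per_epoch):
        # keeps counting across epoch boundaries: 0..4 in epoch 0, 5..9 in epoch 1
        global_step = epoch * num_steps_per_epoch + step

        if (global_step + 1) % log_every == 0:
            print(f"log  loss at epoch={epoch} step={step} global_step={global_step}")
        if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0:
            print(f"save ckpt at epoch={epoch} step={step} global_step={global_step}")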

train_video.py (1 addition, 0 deletions)

@@ -106,6 +106,7 @@ def main(args):
             enable_layernorm_kernel=args.enable_layernorm_kernel,
             enable_modulate_kernel=args.enable_modulate_kernel,
             sequence_parallel_size=args.sequence_parallel_size,
+            dtype=dtype,
         )
         .to(device)
         .to(dtype)
