
Commit 78a9cb8

[Fix] fix seed in dmd denoising loop (#736)
1 parent a0bff12 commit 78a9cb8

File tree

3 files changed: +28 −36 lines changed


csrc/attn/tests/test_vsa.py

Lines changed: 23 additions & 26 deletions
@@ -13,9 +13,9 @@
 BLOCK_N = 64
 
 def pytorch_test(Q, K, V, block_sparse_mask, dO):
-    q_ = Q.clone().requires_grad_()
-    k_ = K.clone().requires_grad_()
-    v_ = V.clone().requires_grad_()
+    q_ = Q.clone().float().requires_grad_()
+    k_ = K.clone().float().requires_grad_()
+    v_ = V.clone().float().requires_grad_()
 
     QK = torch.matmul(q_, k_.transpose(-2, -1))
     QK /= (q_.size(-1) ** 0.5)
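
The reference path now upcasts Q, K, and V to float32 before enabling gradients, so the PyTorch baseline runs in full precision while the kernel under test stays in bfloat16. A minimal sketch of that comparison pattern, with hypothetical tensors standing in for the test's inputs:

import torch

# bf16 inputs, as the test generates them
x = torch.randn(4, 8, dtype=torch.bfloat16)
w = torch.randn(8, 8, dtype=torch.bfloat16)

out_low = x @ w                   # low-precision path (stands in for the kernel)
out_ref = x.float() @ w.float()   # fp32 reference, mirroring clone().float()

# With an fp32 reference, the measured gap reflects only the low-precision
# path's rounding, not error baked into the baseline itself.
print((out_low.float() - out_ref).abs().max())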
@@ -35,9 +35,9 @@ def pytorch_test(Q, K, V, block_sparse_mask, dO):
 
 
 def block_sparse_kernel_test(Q, K, V, block_sparse_mask, variable_block_sizes, non_pad_index, dO):
-    Q = Q.clone().requires_grad_()
-    K = K.clone().requires_grad_()
-    V = V.clone().requires_grad_()
+    Q = Q.detach().requires_grad_()
+    K = K.detach().requires_grad_()
+    V = V.detach().requires_grad_()
 
     q_padded = vsa_pad(Q, non_pad_index, variable_block_sizes.shape[0], BLOCK_M)
     k_padded = vsa_pad(K, non_pad_index, variable_block_sizes.shape[0], BLOCK_M)
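
On the kernel side, clone() is replaced with detach(): both produce a fresh autograd leaf, but detach() shares storage with the original tensor instead of copying it. A small illustration of the difference, assuming plain CPU tensors:

import torch

base = torch.randn(3, 3)

leaf_shared = base.detach().requires_grad_()  # new leaf, same underlying memory
leaf_copied = base.clone().requires_grad_()   # new leaf, separate copy

print(leaf_shared.data_ptr() == base.data_ptr())  # True: storage is shared
print(leaf_copied.data_ptr() == base.data_ptr())  # False: data was duplicated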
@@ -60,11 +60,9 @@ def get_non_pad_index(
 
     return index_pad[index_mask]
 
-def generate_tensor(shape, mean, std, dtype, device):
+def generate_tensor(shape, dtype, device):
     tensor = torch.randn(shape, dtype=dtype, device=device)
-    magnitude = torch.norm(tensor, dim=-1, keepdim=True)
-    scaled_tensor = tensor * (torch.randn(magnitude.shape, dtype=dtype, device=device) * std + mean) / magnitude
-    return scaled_tensor.contiguous()
+    return tensor
 
 def generate_variable_block_sizes(num_blocks, min_size=32, max_size=64, device="cuda"):
     return torch.randint(min_size, max_size + 1, (num_blocks,), device=device, dtype=torch.int32)
@@ -75,7 +73,7 @@ def vsa_pad(x, non_pad_index, num_blocks, block_size):
     padded_x[:, :, non_pad_index, :] = x
     return padded_x
 
-def check_correctness(h, d, num_blocks, k, mean, std, num_iterations=20, error_mode='all'):
+def check_correctness(h, d, num_blocks, k, num_iterations=20, error_mode='all'):
     results = {
         'gO': {'sum_diff': 0.0, 'sum_abs': 0.0, 'max_diff': 0.0},
         'gQ': {'sum_diff': 0.0, 'sum_abs': 0.0, 'max_diff': 0.0},
@@ -91,10 +89,10 @@ def check_correctness(h, d, num_blocks, k, mean, std, num_iterations=20, error_mode='all'):
     block_mask = generate_block_sparse_mask_for_function(h, num_blocks, k, device)
     full_mask = create_full_mask_from_block_mask(block_mask, variable_block_sizes, device)
     for _ in range(num_iterations):
-        Q = generate_tensor((1, h, S, d), mean, std, torch.bfloat16, device)
-        K = generate_tensor((1, h, S, d), mean, std, torch.bfloat16, device)
-        V = generate_tensor((1, h, S, d), mean, std, torch.bfloat16, device)
-        dO = generate_tensor((1, h, S, d), mean, std, torch.bfloat16, device)
+        Q = generate_tensor((1, h, S, d), torch.bfloat16, device)
+        K = generate_tensor((1, h, S, d), torch.bfloat16, device)
+        V = generate_tensor((1, h, S, d), torch.bfloat16, device)
+        dO = generate_tensor((1, h, S, d), torch.bfloat16, device)
 
         # dO_padded = torch.zeros_like(dO_padded)
         # dO_padded[:, :, non_pad_index, :] = dO
@@ -107,7 +105,8 @@ def check_correctness(h, d, num_blocks, k, mean, std, num_iterations=20, error_mode='all'):
             abs_diff = torch.abs(diff)
             results[name]['sum_diff'] += torch.sum(abs_diff).item()
             results[name]['sum_abs'] += torch.sum(torch.abs(pt)).item()
-            results[name]['max_diff'] = max(results[name]['max_diff'], torch.max(abs_diff).item())
+            rel_max_diff = torch.max(abs_diff) / torch.mean(torch.abs(pt))
+            results[name]['max_diff'] = max(results[name]['max_diff'], rel_max_diff.item())
         if torch.cuda.is_available():
             torch.cuda.empty_cache()

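The max-error statistic is now normalized: instead of reporting the raw max|diff|, the test divides it by the mean magnitude of the reference tensor, so outliers are judged relative to the data's overall scale. The metric in isolation, as a hypothetical standalone helper:

import torch

def relative_max_diff(test: torch.Tensor, ref: torch.Tensor) -> float:
    # Largest absolute deviation, scaled by the reference's mean magnitude.
    abs_diff = torch.abs(test - ref)
    return (torch.max(abs_diff) / torch.mean(torch.abs(ref))).item()

ref = torch.randn(1024)
test = ref + 1e-3 * torch.randn(1024)
print(relative_max_diff(test, ref))  # on the order of 1e-3 for unit-scale data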
@@ -119,27 +118,27 @@ def check_correctness(h, d, num_blocks, k, mean, std, num_iterations=20, error_mode='all'):
 
     return results
 
-def generate_error_graphs(h, d, mean, std, error_mode='all'):
+def generate_error_graphs(h, d, error_mode='all'):
     test_configs = [
         {"num_blocks": 16, "k": 2, "description": "Small sequence"},
         {"num_blocks": 32, "k": 4, "description": "Medium sequence"},
         {"num_blocks": 53, "k": 6, "description": "Large sequence"},
     ]
 
-    print(f"\nError Analysis for h={h}, d={d}, mean={mean}, std={std}, mode={error_mode}")
+    print(f"\nError Analysis for h={h}, d={d}, mode={error_mode}")
     print("=" * 150)
     print(f"{'Config':<20} {'Blocks':<8} {'K':<4} "
-          f"{'gQ Avg':<12} {'gQ Max':<12} "
-          f"{'gK Avg':<12} {'gK Max':<12} "
-          f"{'gV Avg':<12} {'gV Max':<12} "
-          f"{'gO Avg':<12} {'gO Max':<12}")
+          f"{'gQ Avg':<12} {'Rel gQ Max':<12} "
+          f"{'gK Avg':<12} {'Rel gK Max':<12} "
+          f"{'gV Avg':<12} {'Rel gV Max':<12} "
+          f"{'gO Avg':<12} {'Rel gO Max':<12}")
     print("-" * 150)
 
     for config in test_configs:
         num_blocks = config["num_blocks"]
         k = config["k"]
         description = config["description"]
-        results = check_correctness(h, d, num_blocks, k, mean, std, error_mode=error_mode)
+        results = check_correctness(h, d, num_blocks, k, error_mode=error_mode)
         print(f"{description:<20} {num_blocks:<8} {k:<4} "
               f"{results['gQ']['avg_diff']:<12.6e} {results['gQ']['max_diff']:<12.6e} "
               f"{results['gK']['avg_diff']:<12.6e} {results['gK']['max_diff']:<12.6e} "
@@ -150,10 +149,8 @@ def generate_error_graphs(h, d, mean, std, error_mode='all'):
 
 if __name__ == "__main__":
     h, d = 16, 128
-    mean = 0.0
-    std = 1
     print("Block Sparse Attention with Variable Block Sizes Analysis")
     print("=" * 60)
     for mode in ['backward']:
-        generate_error_graphs(h, d, mean, std, error_mode=mode)
+        generate_error_graphs(h, d, error_mode=mode)
     print("\nAnalysis completed for all modes.")

csrc/attn/vsa/block_sparse_h100.cu

Lines changed: 0 additions & 2 deletions
@@ -568,7 +568,6 @@ void bwd_attend_ker(const __grid_constant__ bwd_globals<D> g) {
     __syncthreads(); // wait for sd_smem shared memory write
     warpgroup::mm_AtB(qg_reg, ds_smem_t[0], k_smem[0]); //delat dQ = dSK
     warpgroup::mma_commit_group();
-    tma::store_async_wait();
     warpgroup::mma_async_wait();
     // store qg to shared memory
     warpgroup::store(qg_smem, qg_reg);
@@ -625,7 +624,6 @@ void bwd_attend_ker(const __grid_constant__ bwd_globals<D> g) {
     __syncthreads(); // wait for sd_smem shared memory write
     warpgroup::mm_AtB(qg_reg, ds_smem_t[0], k_smem[0]); //delat dQ = dSK
     warpgroup::mma_commit_group();
-    tma::store_async_wait();
     warpgroup::mma_async_wait();
     // store qg to shared memory
     warpgroup::store(qg_smem, qg_reg);

fastvideo/pipelines/stages/denoising.py

Lines changed: 5 additions & 8 deletions
@@ -673,12 +673,8 @@ def forward(
         # Get latents and embeddings
         assert batch.latents is not None, "latents must be provided"
         latents = batch.latents
-        # TODO(yongqi) hard code prepare latents
-        latents = torch.randn(
-            latents.permute(0, 2, 1, 3, 4).shape,
-            dtype=torch.bfloat16,
-            device="cuda",
-            generator=torch.Generator(device="cuda").manual_seed(42))
+        latents = latents.permute(0, 2, 1, 3, 4)
+
         video_raw_latent_shape = latents.shape
         prompt_embeds = batch.prompt_embeds
         assert torch.isnan(prompt_embeds[0]).sum() == 0
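
With the hard-coded randn (and its fixed manual_seed(42)) removed, the stage consumes the latents prepared upstream and only reorders their dimensions: permute(0, 2, 1, 3, 4) swaps the second and third axes, which for video latents typically means exchanging the channel and frame dimensions. A shape-only sketch with hypothetical dimensions:

import torch

# Hypothetical video latents: (batch, channels, frames, height, width)
latents = torch.randn(1, 16, 21, 30, 52)
latents = latents.permute(0, 2, 1, 3, 4)  # swap dims 1 and 2
print(latents.shape)  # torch.Size([1, 21, 16, 30, 52])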
@@ -795,8 +791,9 @@ def forward(
                 next_timestep = timesteps[i + 1] * torch.ones(
                     [1], dtype=torch.long, device=pred_video.device)
                 noise = torch.randn(video_raw_latent_shape,
-                                    device=self.device,
-                                    dtype=pred_video.dtype)
+                                    dtype=pred_video.dtype,
+                                    generator=batch.generator[0]).to(
+                                        self.device)
                 if sp_group:
                     noise = rearrange(noise,
                                       "b (n t) c h w -> b n t c h w",
