[PyT] Plumbing correct bias dims from TE to cudnn #2537
base: main
Changes from all commits:
- 8da3252
- d3f15bf
- 95bdac9
- 0e7fcae
- 1ce6ffb
- 143ede5
```diff
@@ -307,6 +307,7 @@ def run_dpa_with_cp(
     if config.attn_bias_type not in ["no_bias", "alibi"]:
         attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)
         bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda()
+        bias.requires_grad = True
     else:
         bias = None
```
```diff
@@ -338,7 +339,7 @@ def run_dpa_with_cp(
         out.backward(dout_fp8)
     else:
         out.backward(dout)
-    dq, dk, dv = q.grad, k.grad, v.grad
+    dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad
     d_softmax_offset = None
     if config.softmax_type != "vanilla":
         d_softmax_offset = core_attn.softmax_offset.grad
```
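Both hunks above lean on the standard autograd contract: a leaf tensor only gets a populated `.grad` after `backward()` if `requires_grad` was set before the forward pass; otherwise `.grad` stays `None`. A minimal standalone sketch with toy shapes (not the test's real attention call):

```python
import torch

# Toy stand-in for the attention bias created in the test (hypothetical shapes).
bias = torch.randn(1, 1, 8, 8)
bias.requires_grad = True  # without this, bias.grad stays None after backward()

# Any differentiable use of the bias; a dummy reduction is enough for the demo.
loss = (bias * 2.0).sum()
loss.backward()

assert bias.grad is not None and bias.grad.shape == bias.shape
```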
```diff
@@ -394,6 +395,7 @@ def run_dpa_with_cp(
         )
         bias_ = bias_.index_select(2, seq_idx)
         bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
+        bias_.requires_grad = True
     # set up environment
     core_attn.set_context_parallel_group(
         cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
```
```diff
@@ -433,23 +435,23 @@ def run_dpa_with_cp(
         out_.backward(dout_fp8_)
     else:
         out_.backward(dout_)
-    dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad
+    dq_, dk_, dv_, dbias_ = q_.grad, k_.grad, v_.grad, bias_.grad
```
**Contributor:** Suggested change
```diff
     d_softmax_offset_ = None
     if config.softmax_type != "vanilla":
         d_softmax_offset_ = core_attn.softmax_offset.grad.clone()

     # get outputs
-    tensors = [out, dq, dk, dv, out_, dq_, dk_, dv_]
+    tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_]
     if fp8_mha:
         tensors_to_deq = [out, out_] if not fp8_bwd else tensors
         for i, tensor in enumerate(tensors_to_deq):
             tensors_to_deq[i] = tensor.dequantize()
         if not fp8_bwd:
-            tensors[0], tensors[4] = tensors_to_deq
+            tensors[0], tensors[5] = tensors_to_deq
     for tensor in tensors:
         assert torch.all(~torch.isnan(tensor))
         assert torch.all(~torch.isinf(tensor))
-    out, dq, dk, dv, out_, dq_, dk_, dv_ = tensors
+    out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors

     ############ compare results between CP and no-CP ############
     if qkv_format == "bshd" or qkv_format == "sbhd":
```
```diff
@@ -467,6 +469,22 @@ def run_dpa_with_cp(
             x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :])
             for x in [dq_, dk_, dv_, out_]
         ]
+        if dbias is not None and dbias_ is not None:
+            dbias = dbias.view(
+                dbias.shape[0],
+                dbias.shape[1],
+                2 * world_size,
+                dbias.shape[2] // (2 * world_size),
+                dbias.shape[3],
+            )
+            # bias has fixed axis (2) as dbias shape: (1, 1, max_seqlen_q, max_seqlen_kv)
```
**Collaborator:** I think our CP implementation (after your C changes) should support all bias shapes, not just 111s. I also think your reshaping here should work for all shapes. Could you run the tests to confirm?
```diff
+            dbias = dbias.index_select(2, seq_idx)
+            # Flatten
+            dbias = dbias.view(dbias.shape[0], dbias.shape[1], -1, dbias.shape[-1])
+            dbias_ = dbias_.view(
+                dbias_.shape[0], dbias_.shape[1], 2, dbias_.shape[2] // 2, dbias_.shape[3]
+            )
     elif qkv_format == "thd":
         dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]]
         dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]]
```
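On the reviewer's question above: the reshaping in this hunk only manipulates dim 2 (the query-sequence axis), so batch and head dims of any size pass through unchanged. A toy sketch of that claim, with hypothetical sizes and a `seq_idx` assumed to follow the usual two-chunks-per-rank load-balanced split (assumptions, not taken from the test):

```python
import torch

# Hypothetical sizes: batch/head dims > 1 to exercise non-(1, 1, ...) bias shapes.
b, h, s_q, s_kv, world_size, rank = 2, 4, 16, 16, 2, 0

# Full-sequence bias gradient, as the no-CP reference run would produce it.
dbias = torch.randn(b, h, s_q, s_kv)

# Chunks owned by this rank under the assumed load-balanced split:
# chunk `rank` and chunk `2 * world_size - 1 - rank`.
seq_idx = torch.tensor([rank, 2 * world_size - 1 - rank])

# Same reordering as the hunk above, with b/h left symbolic via shape[0]/shape[1].
reordered = dbias.view(
    dbias.shape[0],
    dbias.shape[1],
    2 * world_size,
    dbias.shape[2] // (2 * world_size),
    dbias.shape[3],
)
reordered = reordered.index_select(2, seq_idx)
reordered = reordered.view(reordered.shape[0], reordered.shape[1], -1, reordered.shape[-1])

# The result has the local sequence length s_q // world_size, regardless of b and h.
assert reordered.shape == (b, h, s_q // world_size, s_kv)
```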
```diff
@@ -509,57 +527,113 @@ def run_dpa_with_cp(
         )

     atol, rtol, rmse_tol = get_tols(config, dtype)
-    tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_, max_logit_]
-    tensors_no_cp = [out, dq, dk, dv, d_softmax_offset, max_logit]
-    names = ["out", "dq", "dk", "dv", "d_softmax_offset", "max_logit"]
+    tensors_cp = [out_, dq_, dk_, dv_, dbias_, d_softmax_offset_, max_logit_]
+    tensors_no_cp = [out, dq, dk, dv, dbias, d_softmax_offset, max_logit]
+    names = ["out", "dq", "dk", "dv", "dbias", "d_softmax_offset", "max_logit"]
     names_cp = [x + "_cp" for x in names]
     names_no_cp = [x + "_no_cp" for x in names]
     is_fp8 = dtype == "fp8"
     for i, t in enumerate(tensors_no_cp):
         if t is not None:
             if "softmax_offset" not in names[i] and "max_logit" not in names[i]:
                 if qkv_format == "bshd":
-                    compare_and_assert(
-                        t[:, 0],
-                        tensors_cp[i][:, 0],
-                        names_no_cp[i],
-                        names_cp[i],
-                        atol,
-                        rtol,
-                        rmse_tol,
-                        is_fp8,
-                    )
-                    compare_and_assert(
-                        t[:, 1],
-                        tensors_cp[i][:, 1],
-                        names_no_cp[i],
-                        names_cp[i],
-                        atol,
-                        rtol,
-                        rmse_tol,
-                        is_fp8,
-                    )
+                    # Compare the two sequence chunks separately
+                    # Compare dbias
+                    if names[i] == "dbias":
+                        # After reshaping: (1, 1, 2, seq_q//2, seq_kv)
+                        # Compare along dimension 2 (the split sequence dimension)
+                        compare_and_assert(
+                            t[:, :, 0],  # First sequence chunk
+                            tensors_cp[i][:, :, 0],
+                            names_no_cp[i],
+                            names_cp[i],
+                            atol,
+                            rtol,
+                            rmse_tol,
+                            is_fp8,
+                        )
+                        compare_and_assert(
+                            t[:, :, 1],  # Second sequence chunk
+                            tensors_cp[i][:, :, 1],
+                            names_no_cp[i],
+                            names_cp[i],
+                            atol,
+                            rtol,
+                            rmse_tol,
+                            is_fp8,
+                        )
+                    # Compare Q/K/V/out
+                    else:
+                        # Compare along dimension 1 (the split sequence dimension)
+                        compare_and_assert(
+                            t[:, 0],
+                            tensors_cp[i][:, 0],
+                            names_no_cp[i],
+                            names_cp[i],
+                            atol,
+                            rtol,
+                            rmse_tol,
+                            is_fp8,
+                        )
+                        compare_and_assert(
+                            t[:, 1],
+                            tensors_cp[i][:, 1],
+                            names_no_cp[i],
+                            names_cp[i],
+                            atol,
+                            rtol,
+                            rmse_tol,
+                            is_fp8,
+                        )
                 elif qkv_format == "sbhd":
-                    compare_and_assert(
-                        t[0],
-                        tensors_cp[i][0],
-                        names_no_cp[i],
-                        names_cp[i],
-                        atol,
-                        rtol,
-                        rmse_tol,
-                        is_fp8,
-                    )
-                    compare_and_assert(
-                        t[1],
-                        tensors_cp[i][1],
-                        names_no_cp[i],
-                        names_cp[i],
-                        atol,
-                        rtol,
-                        rmse_tol,
-                        is_fp8,
-                    )
+                    # Compare the two sequence chunks separately
+                    # Compare dbias (same as BSHD)
+                    if names[i] == "dbias":
+                        # After reshaping: (1, 1, 2, seq_q//2, seq_kv)
+                        # Compare along dimension 2 (the split sequence dimension)
+                        compare_and_assert(
+                            t[:, :, 0],  # First sequence chunk
+                            tensors_cp[i][:, :, 0],
+                            names_no_cp[i],
+                            names_cp[i],
+                            atol,
+                            rtol,
+                            rmse_tol,
+                            is_fp8,
+                        )
+                        compare_and_assert(
+                            t[:, :, 1],  # Second sequence chunk
+                            tensors_cp[i][:, :, 1],
+                            names_no_cp[i],
+                            names_cp[i],
+                            atol,
+                            rtol,
+                            rmse_tol,
+                            is_fp8,
+                        )
+                    # Compare Q/K/V/out
+                    else:
+                        # Compare along dimension 0 (the split sequence dimension)
+                        compare_and_assert(
+                            t[0],
+                            tensors_cp[i][0],
+                            names_no_cp[i],
+                            names_cp[i],
+                            atol,
+                            rtol,
+                            rmse_tol,
+                            is_fp8,
+                        )
+                        compare_and_assert(
+                            t[1],
+                            tensors_cp[i][1],
+                            names_no_cp[i],
+                            names_cp[i],
+                            atol,
+                            rtol,
+                            rmse_tol,
+                            is_fp8,
+                        )
                 elif qkv_format == "thd":
                     compare_and_assert(
                         t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8
```
`bias` is `None` when `attn_bias_type` is "no_bias" or "alibi" (line 312), so `bias.grad` will raise `AttributeError`.
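One way the test could accommodate this reviewer comment, sketched under the assumption that no-bias configurations should simply skip the bias-gradient read (a hypothetical guard, not the PR's actual fix):

```python
import torch


def grads_with_optional_bias(q, k, v, bias):
    """Guarded gradient read; bias may be None for "no_bias"/"alibi" configs.

    Hypothetical helper for illustration only, not taken from the PR.
    """
    dbias = bias.grad if bias is not None else None
    return q.grad, k.grad, v.grad, dbias


# Tiny demo: with bias absent, dbias comes back as None instead of raising AttributeError.
q = torch.randn(2, 3, requires_grad=True)
k = torch.randn(2, 3, requires_grad=True)
v = torch.randn(2, 3, requires_grad=True)
(q * k * v).sum().backward()
dq, dk, dv, dbias = grads_with_optional_bias(q, k, v, bias=None)
assert dbias is None and dq is not None
```

The later NaN/Inf assertions and the dbias comparisons would likewise need to skip `None` entries for those configurations.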