170 changes: 122 additions & 48 deletions tests/pytorch/attention/run_attention_with_cp.py
@@ -307,6 +307,7 @@ def run_dpa_with_cp(
if config.attn_bias_type not in ["no_bias", "alibi"]:
attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)
bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda()
bias.requires_grad = True
else:
bias = None

@@ -338,7 +339,7 @@ def run_dpa_with_cp(
out.backward(dout_fp8)
else:
out.backward(dout)
dq, dk, dv = q.grad, k.grad, v.grad
dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad
Contributor commented:
bias is None when attn_bias_type is "no_bias" or "alibi" (line 312), so bias.grad will raise an AttributeError.

Suggested change:
dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad
dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad if bias is not None else None
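For context, a minimal standalone sketch of the failure mode and the guard (tensor names and shapes here are illustrative, not taken from the test file):

import torch

q = torch.randn(2, 4, requires_grad=True)
bias = None  # e.g. when attn_bias_type is "no_bias" or "alibi"

out = (q * 2).sum()
out.backward()

# bias.grad would raise: AttributeError: 'NoneType' object has no attribute 'grad'
dbias = bias.grad if bias is not None else None
print(q.grad.shape, dbias)  # torch.Size([2, 4]) None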

d_softmax_offset = None
if config.softmax_type != "vanilla":
d_softmax_offset = core_attn.softmax_offset.grad
@@ -394,6 +395,7 @@ def run_dpa_with_cp(
)
bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
bias_.requires_grad = True
# set up environment
core_attn.set_context_parallel_group(
cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
@@ -433,23 +435,23 @@ def run_dpa_with_cp(
out_.backward(dout_fp8_)
else:
out_.backward(dout_)
dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad
dq_, dk_, dv_, dbias_ = q_.grad, k_.grad, v_.grad, bias_.grad
Contributor commented:
bias_ is None when bias is None (line 355), so bias_.grad will raise an AttributeError.

Suggested change:
dq_, dk_, dv_, dbias_ = q_.grad, k_.grad, v_.grad, bias_.grad
dq_, dk_, dv_, dbias_ = q_.grad, k_.grad, v_.grad, bias_.grad if bias_ is not None else None

d_softmax_offset_ = None
if config.softmax_type != "vanilla":
d_softmax_offset_ = core_attn.softmax_offset.grad.clone()

# get outputs
tensors = [out, dq, dk, dv, out_, dq_, dk_, dv_]
tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_]
if fp8_mha:
tensors_to_deq = [out, out_] if not fp8_bwd else tensors
for i, tensor in enumerate(tensors_to_deq):
tensors_to_deq[i] = tensor.dequantize()
if not fp8_bwd:
tensors[0], tensors[4] = tensors_to_deq
tensors[0], tensors[5] = tensors_to_deq
for tensor in tensors:
assert torch.all(~torch.isnan(tensor))
assert torch.all(~torch.isinf(tensor))
out, dq, dk, dv, out_, dq_, dk_, dv_ = tensors
out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors

############ compare results between CP and no-CP ############
if qkv_format == "bshd" or qkv_format == "sbhd":
@@ -467,6 +469,22 @@ def run_dpa_with_cp(
x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :])
for x in [dq_, dk_, dv_, out_]
]
if dbias is not None and dbias_ is not None:
dbias = dbias.view(
dbias.shape[0],
dbias.shape[1],
2 * world_size,
dbias.shape[2] // (2 * world_size),
dbias.shape[3],
)
# dbias has shape (1, 1, max_seqlen_q, max_seqlen_kv); select this rank's chunks along the Q sequence axis (dim 2)
Collaborator commented:
I think our CP implementation (after your C changes) should support all bias shapes, not just 111s. I also think your reshaping here should work for all shapes. Could you run the tests to confirm?
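A hedged sanity check of that point: the view/index_select round-trip below is independent of the leading batch/head dims, assuming dim 2 carries the full Q sequence (shapes like 11ss, 1hss, bhss; a bias with a size-1 Q axis has nothing to split). The seq_idx here is the rank-0 chunk selection and the sizes are arbitrary, not from the test:

import torch

world_size, sq, skv = 2, 16, 16
seq_idx = torch.tensor([0, 2 * world_size - 1])  # chunks owned by rank 0 under CP load balancing
for b, h in [(1, 1), (1, 4), (2, 4)]:  # 11ss, 1hss, bhss
    dbias = torch.randn(b, h, sq, skv)
    x = dbias.view(b, h, 2 * world_size, sq // (2 * world_size), skv)
    x = x.index_select(2, seq_idx)
    x = x.view(b, h, -1, skv)
    assert x.shape == (b, h, sq // world_size, skv)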

dbias = dbias.index_select(2, seq_idx)
# Flatten the selected chunks back into the sequence dimension
dbias = dbias.view(dbias.shape[0], dbias.shape[1], -1, dbias.shape[-1])
dbias_ = dbias_.view(
dbias_.shape[0], dbias_.shape[1], 2, dbias_.shape[2] // 2, dbias_.shape[3]
)

elif qkv_format == "thd":
dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]]
dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]]
@@ -509,57 +527,113 @@ def run_dpa_with_cp(
)

atol, rtol, rmse_tol = get_tols(config, dtype)
tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_, max_logit_]
tensors_no_cp = [out, dq, dk, dv, d_softmax_offset, max_logit]
names = ["out", "dq", "dk", "dv", "d_softmax_offset", "max_logit"]
tensors_cp = [out_, dq_, dk_, dv_, dbias_, d_softmax_offset_, max_logit_]
tensors_no_cp = [out, dq, dk, dv, dbias, d_softmax_offset, max_logit]
names = ["out", "dq", "dk", "dv", "dbias", "d_softmax_offset", "max_logit"]
names_cp = [x + "_cp" for x in names]
names_no_cp = [x + "_no_cp" for x in names]
is_fp8 = dtype == "fp8"
for i, t in enumerate(tensors_no_cp):
if t is not None:
if "softmax_offset" not in names[i] and "max_logit" not in names[i]:
if qkv_format == "bshd":
compare_and_assert(
t[:, 0],
tensors_cp[i][:, 0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[:, 1],
tensors_cp[i][:, 1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
# Compare the two sequence chunks separately
# Compare dbias
if names[i] == "dbias":
# After reshaping: (1, 1, 2, seq_q//2, seq_kv)
# Compare along dimension 2 (the split sequence dimension)
compare_and_assert(
t[:, :, 0], # First sequence chunk
tensors_cp[i][:, :, 0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[:, :, 1], # Second sequence chunk
tensors_cp[i][:, :, 1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
# Compare Q/K/V/out
else:
# Compare along dimension 1 (the split sequence dimension)
compare_and_assert(
t[:, 0],
tensors_cp[i][:, 0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[:, 1],
tensors_cp[i][:, 1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
elif qkv_format == "sbhd":
compare_and_assert(
t[0],
tensors_cp[i][0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[1],
tensors_cp[i][1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
# Compare the two sequence chunks separately
# Compare dbias (same as BSHD)
if names[i] == "dbias":
# After reshaping: (1, 1, 2, seq_q//2, seq_kv)
# Compare along dimension 2 (the split sequence dimension)
compare_and_assert(
t[:, :, 0], # First sequence chunk
tensors_cp[i][:, :, 0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[:, :, 1], # Second sequence chunk
tensors_cp[i][:, :, 1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
# Compare Q/K/V/out
else:
# Compare along dimension 0 (the split sequence dimension)
compare_and_assert(
t[0],
tensors_cp[i][0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[1],
tensors_cp[i][1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
elif qkv_format == "thd":
compare_and_assert(
t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8
5 changes: 2 additions & 3 deletions tests/pytorch/attention/test_attention.py
@@ -505,7 +505,7 @@ def test_dpa_mask(dtype, model_configs, model):

model_configs_bias = {
# test: ModelConfig(b, sq, hq, dqk)
"bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"),
"bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="111s"),
"bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"),
"bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"),
"bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"),
@@ -1118,11 +1118,10 @@ def _run_dot_product_attention(
bias = None
if config.attn_bias_type == "post_scale_bias":
shape = "_".join(config.bias_shape)
shape = shape.replace("_1_s", "_1_skv")
shape = shape.replace("_s_s", "_sq_skv")
tensor_shape = [dim_to_num[j] for j in shape.split("_")]
bias = torch.randn(tensor_shape, dtype=dtype, device="cuda")
if config.bias_shape != "1hss":
bias.requires_grad = False
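For reference, a small sketch of how the bias_shape string expands into a tensor shape in this test (the dim_to_num values below are illustrative placeholders, not the real config):

# Hypothetical dim_to_num; the real one is built from the test's ModelConfig.
dim_to_num = {"1": 1, "b": 4, "h": 16, "sq": 128, "skv": 256}

shape = "_".join("111s")                  # "1_1_1_s"
shape = shape.replace("_1_s", "_1_skv")   # "1_1_1_skv"
shape = shape.replace("_s_s", "_sq_skv")  # no-op here; turns "1_h_s_s" into "1_h_sq_skv"
tensor_shape = [dim_to_num[j] for j in shape.split("_")]
print(tensor_shape)  # [1, 1, 1, 256]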

# Create RNG
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()