Skip to content

Fix incorrect error message and inconsistent ndim check#1377

Open
Mr-Neutr0n wants to merge 1 commit intofacebookresearch:mainfrom
Mr-Neutr0n:fix/attention-input-validation-bugs
Open

Fix incorrect error message and inconsistent ndim check#1377
Mr-Neutr0n wants to merge 1 commit intofacebookresearch:mainfrom
Mr-Neutr0n:fix/attention-input-validation-bugs

Conversation

@Mr-Neutr0n
Copy link

Summary

  • Fix misleading error message in _rand_seqlens_padded_k (attn_bias_utils.py): When q_len > kv_len, the function raises ValueError("need more queries than keys"), but the comments and logic clearly indicate the constraint is that there must be more keys than queries (bottom-right causal mask). Fixed the message to "need more keys than queries (kv_len must be >= q_len)".

  • Fix inconsistent ndim check in Inputs.get_qkv_in_bmghk (ops/fmha/common.py): The first two branches check self.query.ndim (for 5D and 4D), but the third branch checks self.value.ndim == 3 instead of self.query.ndim == 3. While validate_inputs() ensures all tensors share the same ndim, this inconsistency could cause silent wrong behavior if get_qkv_in_bmghk is called without prior validation and the tensors have mismatched dimensions.

Test plan

  • Verified the error message now correctly describes the constraint
  • Verified the ndim check is now consistent across all branches of get_qkv_in_bmghk
  • Existing tests should continue to pass (both fixes are behavioral no-ops when inputs are well-formed)

…code

Fix two bugs in attention input handling:

1. _rand_seqlens_padded_k in attn_bias_utils.py: The error message
   said "need more queries than keys" when the condition `q_len > kv_len`
   is triggered, but the intent (as documented in the comments) is that
   there must be more keys than queries. Fixed the message to say
   "need more keys than queries (kv_len must be >= q_len)".

2. get_qkv_in_bmghk in ops/fmha/common.py: The first two branches
   check `self.query.ndim` (for ndim==5 and ndim==4), but the third
   branch inconsistently checked `self.value.ndim == 3` instead of
   `self.query.ndim == 3`. While validate_inputs() ensures all tensors
   share the same ndim, this inconsistency could cause silent
   wrong behavior if get_qkv_in_bmghk is called without prior
   validation and the tensors have mismatched dimensions.
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 11, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant