You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Improves backward handling when bias is broadcast across sequence or batch by allocating correctly shaped scratch buffers and adjusting reduction paths. Adds a kernel parameter to accumulate along sequence for S=1 bias, and uses fp32 buffers for numerically stable accumulation.
Corrects the previous over-eager scratch allocation on batch-size mismatch to only trigger for shared (B=1) or head-grouped cases, aligning with broadcasting semantics (incl. MQA/GQA). Leaves the variable-length path unchanged (no accumulation).
Results in correct dbias reductions and gradients for broadcasted bias with better numerical stability.
0 commit comments