[QST] What is the better way to fuse bias vector into GeMM? #2655

@Algy

I want to use the CUTLASS GeMM API to implement a fully connected layer with a bias vector, i.e., Y = W@X + b.

As far as I know, it seems there are two methods to fuse bias into GeMM.

  1. Since GeMM is defined as D = αAB + βC, we can take advantage of the "stride trick": set β to 1 and set the leading-dimension stride of C to 0. C then acts as the bias vector, broadcast across the output.

  2. Use LinCombPerRowBias<...>, or EVT more directly, and pass bias_ptr to the epilogue.
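To make the stride trick in method 1 concrete, here is a minimal host-side sketch of the arithmetic only. It does not use any CUTLASS API and does not reproduce the failure described below. The names `gemm_reference` and `linear_with_bias` are hypothetical, and the layouts (row-major A and B, column-major C and D) are assumptions chosen so that a leading dimension of 0 for C yields a per-row bias.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side reference: D = alpha * A * B + beta * C.
// Assumed layouts: A (M x K) and B (K x N) are densely packed row-major;
// C and D are column-major, so element (m, n) of C is C[m + n * ldc].
std::vector<float> gemm_reference(int M, int N, int K,
                                  float alpha,
                                  const std::vector<float>& A,
                                  const std::vector<float>& B,
                                  float beta,
                                  const std::vector<float>& C, int ldc) {
  assert(A.size() == static_cast<std::size_t>(M) * K);
  assert(B.size() == static_cast<std::size_t>(K) * N);
  // With ldc == 0 only the first M elements of C are ever read.
  assert(C.size() >= static_cast<std::size_t>(M) +
                     static_cast<std::size_t>(N - 1) * ldc);

  std::vector<float> D(static_cast<std::size_t>(M) * N);
  for (int n = 0; n < N; ++n) {
    for (int m = 0; m < M; ++m) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[m * K + k] * B[k * N + n];
      }
      // ldc == 0 makes every column n read the same C[m]: a per-row bias.
      D[m + n * M] = alpha * acc + beta * C[m + n * ldc];
    }
  }
  return D;
}

// Hypothetical wrapper: Y = W @ X + b via the stride trick
// (alpha = 1, beta = 1, C = b with leading dimension 0).
std::vector<float> linear_with_bias(int M, int N, int K,
                                    const std::vector<float>& W,
                                    const std::vector<float>& X,
                                    const std::vector<float>& b) {
  return gemm_reference(M, N, K, 1.0f, W, X, 1.0f, b, /*ldc=*/0);
}
```

For example, with W = [[1, 2], [3, 4]], X = [[1, 0, 1], [0, 1, 1]] and b = [10, 20], calling `linear_with_bias(2, 3, 2, W, X, b)` returns the column-major Y = [[11, 12, 13], [23, 24, 27]]. The bias buffer holds only M elements even though the GeMM logically treats C as M x N, which is the property the question below is about.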

For my use cases, the first approach is the simpler way to implement the bias fusion. However, on Blackwell SM 12.0, I found some problematic GeMM shapes (5120 x 13 x 2880, FYI) with the first approach. When I apply it to an mxfp8 x mxfp8 -> bf16 GeMM, it occasionally fails with an odd "warp illegal address" error, and I don't know why.
The second approach appears to be much more robust and has caused no problems so far. However, it is somewhat cumbersome, because I use a no-TMA epilogue in some cases for better latency.

The question is: is method #1 considered a valid way to do the fusion? Are there any concerns regarding TMA loads, or any conflict with an underlying mechanism?
