Description
I want to use the CUTLASS GeMM API to implement a fully-connected layer with a bias vector, i.e., Y = W@X + b.
As far as I know, there are two ways to fuse the bias into the GeMM:
- Since GeMM is defined as D = αAB + βC, we can take advantage of the "stride trick": set β to 1 and set the leading dimension (stride) of C to 0. Then we can treat `C` as the bias vector.
- Use `LinCombPerRowBias<...>`, or more directly EVT, and add `bias_ptr` into the epilogue.
For my use cases, the first approach is the simpler way to implement the bias fusion. However, on Blackwell SM 12.0, I find some problematic GEMM shapes (5120 x 13 x 2880, FYI) with the first approach. When I apply it to a GEMM of mxfp8 x mxfp8 -> bf16, it occasionally fails with an odd "warp illegal address" error, and I don't know why.
The second approach appears far more robust, causing no problems so far. However, it is somewhat cumbersome, since in some cases I use a non-TMA epilogue for better latency.
The question is: is method #1 considered a valid way to do the fusion? Is there any concern regarding the TMA load of C, or any conflict with an underlying mechanism?
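For reference, this is roughly how I wire up approach #2. It is a sketch against the CUTLASS 3.x collective-builder API from memory, not verified code: `TileShape_`, `ClusterShape_`, the layouts, alignments, and element types are placeholders, and the exact `LinCombPerRowBias` template-argument order may differ between releases.

```cpp
// Sketch only -- placeholder shapes/layouts, template args from memory.
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/fusion/operations.hpp"

using ElementAccumulator = float;
using ElementCompute     = float;
using ElementBias        = float;
using ElementC           = cutlass::bfloat16_t;
using ElementD           = cutlass::bfloat16_t;

// LinCombPerRowBias computes D = alpha * acc + beta * C + bias[row].
using FusionOp = cutlass::epilogue::fusion::LinCombPerRowBias<
    ElementD, ElementCompute, ElementBias, ElementC>;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag_, cutlass::arch::OpClassTensorOp,   // placeholder arch tag
    TileShape_, ClusterShape_,                  // placeholder CTA/cluster shapes
    cutlass::epilogue::collective::EpilogueTileAuto,
    ElementAccumulator, ElementCompute,
    ElementC, LayoutC_, AlignC_,                // placeholder layout/alignment
    ElementD, LayoutD_, AlignD_,
    cutlass::epilogue::collective::EpilogueScheduleAuto,
    FusionOp
>::CollectiveOp;

// At launch, the bias pointer goes into the fusion arguments, e.g.:
//   args.epilogue.thread.bias_ptr = bias_device_ptr;
```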