
[QST] Some questions about Blackwell FMHA #2909

@liuqi123123

Description


Question 1: when `ElementAccumulatorQK` is `float`, the result of `acc_qk = q @ k` is stored in tensor memory as `float`. For the next GEMM, `acc_qk @ v`, `acc_qk` is `float` while `v` is `cutlass::half_t`, so `acc_qk` should first be converted to `half_t`. However, I can't find this conversion in `sm100_fmha_fwd_kernel_tma_warpspecialized.hpp`. Where is it?

Question 2: in the line `tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0);`, why is `tOrP0` offset by 32 before the matmul with `v`? Shouldn't it point to `tmem_base_addr`? I see that the layout of `tOrP0` is `tmem_16b o ((_128,_16),_1,(_4,_2)):((_131072,_1),_0,(_16,_64))`. Why is the row-coordinate stride `_131072` rather than the original `_65536`?
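To make the layout in Question 2 concrete, here is a small Python sketch that evaluates a CuTe-style layout by hand. It is not CuTe; it only applies the rule that a coordinate's offset is the inner product of the flattened coordinate with the flattened stride. The helper names `flatten`, `offset`, and `row_offsets` are hypothetical, and the shape and stride are copied from the printed `tOrP0` layout with the `tmem_16b` iterator left out.

```python
from itertools import product
from typing import List, Tuple, Union

# A nested tuple of ints, like a CuTe shape or stride.
Nested = Union[int, Tuple["Nested", ...]]

def flatten(t: Nested) -> List[int]:
    """Flatten a nested tuple of ints into a flat list (leftmost first)."""
    if isinstance(t, int):
        return [t]
    out: List[int] = []
    for x in t:
        out.extend(flatten(x))
    return out

def offset(coord: Nested, stride: Nested) -> int:
    """Layout function: offset = sum(coord_i * stride_i) over flattened modes."""
    c, d = flatten(coord), flatten(stride)
    assert len(c) == len(d), "coordinate and stride must have the same rank"
    return sum(ci * di for ci, di in zip(c, d))

# Layout printed for tOrP0 in the question (iterator omitted).
SHAPE = ((128, 16), 1, (4, 2))
STRIDE = ((131072, 1), 0, (16, 64))

def row_offsets(row: int) -> List[int]:
    """Sorted offsets of every element in one row, enumerating all non-row modes."""
    shp, std = flatten(SHAPE), flatten(STRIDE)
    # Flattened modes are [row, 16, 1, 4, 2]; mode 0 is the row coordinate.
    ranges = [range(s) for s in shp[1:]]
    return sorted(row * std[0] + sum(i * d for i, d in zip(idx, std[1:]))
                  for idx in product(*ranges))

if __name__ == "__main__":
    print("row step:", offset(((1, 0), 0, (0, 0)), STRIDE))
    print("column step:", offset(((0, 1), 0, (0, 0)), STRIDE))
    print("row 0 contiguous:", row_offsets(0) == list(range(128)))
    print("row stride / 65536:", STRIDE[0][0] // 65536)
```

Under this reading, each row of `tOrP0` covers 128 contiguous offsets (16 × 4 × 2 elements), and consecutive rows are 131072 = 2 × 65536 offset units apart, which is the factor of two the question is about. The sketch only shows the arithmetic; it does not establish which unit the `tmem_16b` iterator counts offsets in.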
I am very confused and would appreciate it if you could answer my questions. Thank you.
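For Question 1, the precision flow the question expects can be written as a host-side reference: `S = Q·Kᵀ` accumulated in float, then the probabilities narrowed to half before the second GEMM with `V`, again accumulated in float. Standard attention applies a row-wise softmax between the two GEMMs, so the sketch includes it. This is plain Python that uses `struct`'s half-precision format (`'e'`) to emulate `cutlass::half_t` rounding; the helper names are hypothetical, and it is not the CUTLASS kernel, nor does it say where the kernel performs the conversion.

```python
import math
import struct
from typing import List

Matrix = List[List[float]]

def to_half(x: float) -> float:
    """Round to IEEE binary16 and back, emulating storage as half_t.
    struct raises OverflowError if x is outside the half range."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_half_matrix(m: Matrix) -> Matrix:
    return [[to_half(x) for x in row] for row in m]

def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]

def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Naive GEMM; Python floats (fp64) stand in for the fp32 accumulator."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]

def softmax_rows(s: Matrix) -> Matrix:
    """Row-wise softmax with max subtraction for numerical stability."""
    out = []
    for row in s:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        total = sum(e)
        out.append([x / total for x in e])
    return out

def attention_reference(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """One head: half inputs, wide accumulators, P narrowed to half for P @ V."""
    q, k, v = to_half_matrix(q), to_half_matrix(k), to_half_matrix(v)
    s = matmul(q, transpose(k))   # acc_qk, accumulated in float
    p = softmax_rows(s)           # softmax computed in float
    p_half = to_half_matrix(p)    # float -> half narrowing before the second GEMM
    return matmul(p_half, v)      # half P times half V, accumulated in float

if __name__ == "__main__":
    print(attention_reference([[1.0, 0.0]],
                              [[1.0, 0.0], [0.0, 1.0]],
                              [[1.0, 2.0], [3.0, 4.0]]))
```

Narrowing `P` rather than `S` matches the question's point that the second GEMM's A operand has to be in the same type as `V`; the rounding error it introduces is on the order of half precision (about 1e-3 relative).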
