When ElementAccumulatorQK = float, the result of the first GEMM, acc_qk = q @ k, is stored in tensor memory as float. In the next GEMM, acc_qk @ v, acc_qk is float while v is cutlass::half_t, so I would expect acc_qk to be converted to half_t first. However, I cannot find the conversion code in sm100_fmha_fwd_kernel_tma_warpspecialized.hpp. Where is it?

Question 2 is about this line:

tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0);

Why does tOrP0 need to be offset by 32 before the matmul with v? Shouldn't it be tmem_base_addr?

I also see that the layout of tOrP0 is

tmem_16b o ((_128,_16),_1,(_4,_2)):((_131072,_1),_0,(_16,_64))

Why is the row coordinate stride _131072 rather than the original _65536?

I am very confused and would appreciate it if you could answer my questions. Thank you.