@@ -371,6 +371,23 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
371371
372372// -----
373373
374+ #blocked = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [1 , 32 ], warpsPerCTA = [1 , 2 ], order = [1 , 0 ]}>
375+ #blocked1 = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [16 , 2 ], warpsPerCTA = [2 , 1 ], order = [1 , 0 ]}>
376+ #blocked2 = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [2 , 16 ], warpsPerCTA = [2 , 1 ], order = [1 , 0 ]}>
377+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 2 : i32 , ttg.target = " cuda:100" , " ttg.threads-per-warp" = 32 : i32 } {
378+ // Make sure we fall back to mmav2 when num warps < 4. This module is configured with 2 warps, and the directives below only require that the e4m3 x e4m3 tt.dot_scaled in the function is rewritten so that a plain tt.dot appears before the return.
379+ // CHECK-LABEL: block_scaled_2_warps
380+ // CHECK: tt.dot
381+ // CHECK: tt.return
382+ tt.func public @block_scaled_2_warps (%a: tensor <128 x64 xf8 E4 M3 FN, #blocked2 >, %scale_a: tensor <128 x2 xi8 , #blocked1 >, %b: tensor <64 x128 xf8 E4 M3 FN, #blocked >, %scale_b: tensor <128 x2 xi8 , #blocked1 >) -> tensor <128 x128 xf32 , #blocked > {
383+ %cst = arith.constant dense <0.000000e+00 > : tensor <128 x128 xf32 , #blocked >
384+ %d = tt.dot_scaled %a scale %scale_a , %b scale %scale_b , %cst lhs = e4m3 rhs = e4m3 {fastMath = false } : tensor <128 x64 xf8 E4 M3 FN, #blocked2 >, tensor <128 x2 xi8 , #blocked1 > * tensor <64 x128 xf8 E4 M3 FN, #blocked >, tensor <128 x2 xi8 , #blocked1 > -> tensor <128 x128 xf32 , #blocked >
385+ tt.return %d : tensor <128 x128 xf32 , #blocked >
386+ }
387+ }
388+
389+ // -----
390+
374391// Verify that dot_scaled (mxfp4 x {bf16,fp8}) decomposes to mmav3 if it's bf16; otherwise it falls back to mmav2
375392#blocked = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [1 , 32 ], warpsPerCTA = [1 , 4 ], order = [1 , 0 ]}>
376393#blocked1 = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [16 , 2 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
0 commit comments