@@ -417,3 +417,22 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
417417 tt.return
418418 }
419419}
420+
421+ // -----
422+
423+ // CHECK: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1, 1], repCluster = [1, 4, 2], A = [1, 32, 16], B = [1, 16, 32], C = [1, 32, 32]}>
424+ #blocked = #ttg.blocked <{sizePerThread = [1 , 4 , 4 ], threadsPerWarp = [1 , 1 , 16 ], warpsPerCTA = [1 , 4 , 1 ], order = [2 , 1 , 0 ]}>
425+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , " ttg.threads-per-warp" = 16 : i32 , " ttig.min_sg_size" = 16 : i32 , ttig.support_dpas } {
426+ tt.func public @_helion_repro_baddbmm_kernel (%A: tensor <1 x64 x64 xbf16 , #ttg.dot_op <{opIdx = 0 , parent = #blocked }>>, %B: tensor <1 x64 x64 xbf16 , #ttg.dot_op <{opIdx = 1 , parent = #blocked }>>, %C: tensor <1 x64 x64 x!tt.ptr <bf16 >, #blocked >) {
427+ %cst = arith.constant dense <0.000000e+00 > : tensor <1 x64 x64 xf32 , #blocked >
428+ // CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<1x64x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<1x64x64xbf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<1x64x64xf32, #[[$DPAS]]>
429+ %31 = tt.dot %A , %B , %cst , inputPrecision = tf32 : tensor <1 x64 x64 xbf16 , #ttg.dot_op <{opIdx = 0 , parent = #blocked }>> * tensor <1 x64 x64 xbf16 , #ttg.dot_op <{opIdx = 1 , parent = #blocked }>> -> tensor <1 x64 x64 xf32 , #blocked >
430+ %39 = arith.truncf %31 : tensor <1 x64 x64 xf32 , #blocked > to tensor <1 x64 x64 xbf16 , #blocked >
431+ %40 = ttg.convert_layout %39 : tensor <1 x64 x64 xbf16 , #blocked > -> tensor <1 x64 x64 xbf16 , #ttg.dot_op <{opIdx = 0 , parent = #blocked }>>
432+ // CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<1x64x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<1x64x64xbf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<1x64x64xf32, #[[$DPAS]]>
433+ %42 = tt.dot %40 , %B , %cst , inputPrecision = tf32 : tensor <1 x64 x64 xbf16 , #ttg.dot_op <{opIdx = 0 , parent = #blocked }>> * tensor <1 x64 x64 xbf16 , #ttg.dot_op <{opIdx = 1 , parent = #blocked }>> -> tensor <1 x64 x64 xf32 , #blocked >
434+ %43 = arith.truncf %42 : tensor <1 x64 x64 xf32 , #blocked > to tensor <1 x64 x64 xbf16 , #blocked >
435+ tt.store %C , %43 : tensor <1 x64 x64 x!tt.ptr <bf16 >, #blocked >
436+ tt.return
437+ }
438+ }
0 commit comments