@@ -345,3 +345,57 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32}
345345 tt.return
346346 }
347347}

// -----

// COM: Test case @reduce_loop_carried_values.
// COM: A tt.dot loop carries the accumulator plus two block pointers whose pointee
// COM: tensors use #blocked layouts; the loaded tiles are converted to dot-operand
// COM: layouts inside the loop. The expectations below require that the transformed
// COM: loop still has exactly 3 results and no layout conversion in its body, and
// COM: that the post-loop users of loop result #1 see a dot-operand pointer type and
// COM: get explicit conversions to/from #blocked.
// COM: NOTE(review): the pass being exercised is selected by this file's RUN line,
// COM: which is outside this excerpt - confirm there.
// CHECK: #[[BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
// CHECK: #[[DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>
#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
  tt.func public @reduce_loop_carried_values(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg5: i32) {
    %c1_i64 = arith.constant 1 : i64
    %c0_i32 = arith.constant 0 : i32
    %c0_i64 = arith.constant 0 : i64
    %c32_i32 = arith.constant 32 : i32
    %cst = arith.constant dense<0.000000e+00> : tensor<64x256xf32, #dpas>
    // COM: Block pointers for the two dot operands, built on %arg0 and %arg1.
    %18 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
    %22 = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked1>>
    // COM: The loop carries the accumulator (%arg10) and both block pointers (%arg11, %arg12).
    %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>) : i32 {
      // COM: Ensure there are only 3 loop results and no layout conversion in the loop.
      // CHECK: [[LOOP_RES:%.*]]:3 = scf.for
      // CHECK-NOT: ttg.convert_layout
      // CHECK: scf.yield
      %28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #blocked>>
      %29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked1>>
      %30 = ttg.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
      %31 = ttg.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1>
      %32 = tt.dot %30, %31, %arg10, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
      // COM: Step both block pointers by 32 along the reduction dimension.
      %33 = tt.advance %arg11, [%c0_i32, %c32_i32] : <tensor<64x32xf16, #blocked>>
      %34 = tt.advance %arg12, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #blocked1>>
      scf.yield %32, %33, %34 : tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>
    }
    // COM: Truncate the accumulated result to f16 and store it through %arg2.
    %24 = arith.truncf %23#0 : tensor<64x256xf32, #dpas> to tensor<64x256xf16, #dpas>
    %27 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #dpas>>
    tt.store %27, %24 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #dpas>>

    // COM: Post-loop load from loop result #1: the expected output loads in the
    // COM: dot-operand layout and converts back to the blocked layout before the store.
    // CHECK: [[LOAD1:%.*]] = tt.load [[LOOP_RES]]#1 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>>>
    // CHECK: [[CONV1:%.*]] = ttg.convert_layout [[LOAD1]] : tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>> -> tensor<64x32xf16, #[[BLOCKED]]>
    // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #[[BLOCKED]]>>
    // CHECK: tt.store [[PTR]], [[CONV1]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #[[BLOCKED]]>>
    %28 = tt.load %23#1 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #blocked>>
    %29 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
    tt.store %29, %28 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>

    // COM: Post-loop store to loop result #1: the expected output converts the
    // COM: blocked value to the dot-operand layout before storing through the pointer.
    // CHECK: [[LOAD2:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #[[BLOCKED]]>>
    // CHECK: [[CONV2:%.*]] = ttg.convert_layout [[LOAD2]] : tensor<64x32xf16, #[[BLOCKED]]> -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>>
    // CHECK: tt.store [[LOOP_RES]]#1, [[CONV2]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>>>
    %30 = tt.load %29 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #blocked>>
    tt.store %23#1, %30 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>

    tt.return
  }
}
0 commit comments