@@ -62,6 +62,88 @@ tt.func @matmul_loop_load_acc(%lb : index, %ub : index, %step : index,

 // -----

+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
+#smem = #ttg.shared_memory
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+
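+// The CHECK lines below verify the `loop.stage` / `loop.cluster` attributes that
+// the scheduling pass under test assigns to the ops tagged with `_test_marker_N`.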
+// CHECK-LABEL: @fused_loop
+tt.func public @fused_loop(%arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}) {
+  %c10_i32 = arith.constant 10 : i32
+  %false = arith.constant false
+  %0 = ub.poison : !tt.tensordesc<tensor<64x256xf16>>
+  %cst = arith.constant dense<0> : tensor<128x1xi64, #blocked>
+  %c-1_i32 = arith.constant -1 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %c0_i32 = arith.constant 0 : i32
+  %c64_i32 = arith.constant 64 : i32
+  %c1_i64 = arith.constant 1 : i64
+  %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
+
+  %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+  %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
+  %3 = arith.extsi %arg7 : i32 to i64
+  %4 = tt.make_tensor_descriptor %arg5, [%arg7, %arg7], [%3, %c1_i64] : <f16>, <tensor<64x256xf16>>
+  %5 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<128x64xi32, #blocked>
+  %7 = tt.splat %3 : i64 -> tensor<128x1xi64, #blocked>
+
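+  // %arg30 is an inner counter that increments each iteration and wraps back to
+  // 0 once it reaches 10; %12 is true on the iteration where it wraps, which is
+  // when the descriptor, tile counter, and row offsets are re-initialized below.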
+  // CHECK: scf.for
+  %8:9 = scf.for %arg29 = %c0_i32 to %arg7 step %c1_i32 iter_args(%arg30 = %c-1_i32, %arg31 = %4, %arg32 = %c0_i32, %arg33 = %arg5, %arg34 = %cst_0, %arg35 = %c0_i32, %arg36 = %cst, %arg37 = %0, %arg38 = %false) -> (i32, !tt.tensordesc<tensor<64x256xf16>>, i32, !tt.ptr<f16>, tensor<128x256xf32, #mma>, i32, tensor<128x1xi64, #blocked>, !tt.tensordesc<tensor<64x256xf16>>, i1) : i32 {
+    %9 = arith.addi %arg30, %c1_i32 : i32
+    %10 = arith.cmpi eq, %arg30, %c10_i32 : i32
+    %11 = arith.select %10, %c0_i32, %9 : i32
+    %12 = arith.cmpi eq, %11, %c0_i32 : i32
+
+    // This op is a distance 1 dependency of itself.
+    // CHECK: {_test_marker_0, loop.cluster = 4 : i32, loop.stage = 0 : i32}
+    %13 = arith.select %12, %c0_i32, %arg32 {_test_marker_0} : i32
+
+    %14 = arith.select %12, %arg31, %arg37 : !tt.tensordesc<tensor<64x256xf16>>
+    %15 = arith.select %12, %c10_i32, %arg35 : i32
+    %16 = scf.if %12 -> (tensor<128x1xi64, #blocked>) {
+      %32 = arith.muli %cst, %7 : tensor<128x1xi64, #blocked>
+      scf.yield %32 : tensor<128x1xi64, #blocked>
+    } else {
+      scf.yield %arg36 : tensor<128x1xi64, #blocked>
+    }
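+    // The A operand (%22) is loaded with a plain tensor-of-pointers tt.load and
+    // the B operand (%25) through the tensor descriptor; both are staged in
+    // shared memory before feeding the warp-group dot.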
+    %17 = tt.splat %arg33 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #blocked>
+    %18 = tt.addptr %17, %16 : tensor<128x1x!tt.ptr<f16>, #blocked>, tensor<128x1xi64, #blocked>
+    %19 = tt.broadcast %18 : tensor<128x1x!tt.ptr<f16>, #blocked> -> tensor<128x64x!tt.ptr<f16>, #blocked>
+    %20 = tt.addptr %19, %5 : tensor<128x64x!tt.ptr<f16>, #blocked>, tensor<128x64xi32, #blocked>
+    %21 = tt.addptr %arg33, %c64_i32 : !tt.ptr<f16>, i32
+    %22 = tt.load %20 : tensor<128x64x!tt.ptr<f16>, #blocked>
+    %23 = ttg.local_alloc %22 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
+    %24 = arith.muli %13, %c64_i32 : i32
+    %25 = tt.experimental_descriptor_load %14[%24, %15] : !tt.tensordesc<tensor<64x256xf16>> -> tensor<64x256xf16, #blocked1>
+    %26 = ttg.local_alloc %25 : (tensor<64x256xf16, #blocked1>) -> !ttg.memdesc<64x256xf16, #shared, #smem>
+    %27 = ttng.warp_group_dot %23, %26, %arg34, %arg38 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared, #smem> -> tensor<128x256xf32, #mma>
+    %28 = arith.addi %13, %c1_i32 : i32
+
+    // This op is in the backward slice of `_test_marker_2` and the epilogue.
+    // CHECK: {_test_marker_1, loop.cluster = 3 : i32, loop.stage = 1 : i32}
+    %29 = arith.cmpi eq, %11, %c10_i32 {_test_marker_1} : i32
+
+    // CHECK: {_test_marker_2, loop.cluster = 3 : i32, loop.stage = 1 : i32}
+    %30 = arith.select %29, %arg5, %21 {_test_marker_2} : !tt.ptr<f16>
+
+    %31 = arith.cmpi ne, %11, %c10_i32 : i32
+
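+    // Epilogue: the accumulator is consumed only on the iteration where the
+    // inner counter reaches its limit (%29).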
+    scf.if %29 {
+      "use"(%27) : (tensor<128x256xf32, #mma>) -> ()
+      // CHECK: {_test_marker_3, loop.cluster = 5 : i32, loop.stage = 2 : i32}
+    } {_test_marker_3}
+    scf.yield %11, %14, %28, %30, %27, %15, %16, %14, %31 : i32, !tt.tensordesc<tensor<64x256xf16>>, i32, !tt.ptr<f16>, tensor<128x256xf32, #mma>, i32, tensor<128x1xi64, #blocked>, !tt.tensordesc<tensor<64x256xf16>>, i1
+  }
+  tt.return
+}
+
+}
+
+// -----
+
 // CHECK-LABEL: @prologue_backward_slice
 tt.func @prologue_backward_slice(%ub: i32, %cond: i1) {
   %c0_i32 = arith.constant 0 : i32