
Commit 793a3ae

[XPU][EWOpt] Enable loop block arguments optimization
Allow applying the same kind of optimizations to `scf.for` block arguments.

Signed-off-by: victor-eds <[email protected]>
1 parent 6bb173d commit 793a3ae
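As a rough sketch of what the change enables, distilled from the new @test_basic_loop case below (the names %lb, %ub, %step, %acc, %next and the #blocked alias are illustrative shorthand, not identifiers from the patch): a loop whose iter_arg is carried in a DPAS slice layout,

    %0 = scf.for %i = %lb to %ub step %step
        iter_args(%acc = %arg0) -> (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) {
      %sum = arith.addf %arg1, %acc : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
      scf.yield %sum : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
    }

is now rewritten so the block argument flows through the loop in the blocked layout instead, with triton_gpu.convert_layout ops inserted at the loop boundary and around the element-wise op:

    %init = triton_gpu.convert_layout %arg0 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<16xf32, #blocked>
    %1 = scf.for %i = %lb to %ub step %step
        iter_args(%acc = %init) -> (tensor<16xf32, #blocked>) {
      // ... element-wise body, bracketed by convert_layout ops ...
      scf.yield %next : tensor<16xf32, #blocked>
    }
    %res = triton_gpu.convert_layout %1 : tensor<16xf32, #blocked> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>

See the FileCheck test in the diff below for the exact expected IR.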

File tree

2 files changed (+265, -37 lines)


test/TritonIntelGPU/optimize-elementwise.mlir

Lines changed: 75 additions & 0 deletions
@@ -182,3 +182,78 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 tt.return %2 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
 }
 }
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}>
+// CHECK: #[[$ATTR_1:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
+
+#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+
+// CHECK-LABEL: tt.func @test_basic_loop(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>, %[[VAL_1:.*]]: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>,
+// CHECK-SAME: %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index
+tt.func @test_basic_loop(%arg0: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, %arg1: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, %arg2: index, %arg3: index, %arg4: index) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
+// CHECK: %[[VAL_5:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+// CHECK: %[[VAL_6:.*]] = scf.for %[[VAL_7:.*]] = %[[VAL_2]] to %[[VAL_3]] step %[[VAL_4]] iter_args(%[[VAL_8:.*]] = %[[VAL_5]]) -> (tensor<16xf32, #[[$ATTR_0]]>) {
+%0 = scf.for %arg5 = %arg2 to %arg3 step %arg4 iter_args(%arg6 = %arg0) -> (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) {
+// CHECK: %[[VAL_9:.*]] = triton_gpu.convert_layout %[[VAL_8]] : tensor<16xf32, #[[$ATTR_0]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+// CHECK: %[[VAL_10:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+// CHECK: %[[VAL_11:.*]] = triton_gpu.convert_layout %[[VAL_9]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+// CHECK: %[[VAL_12:.*]] = arith.addf %[[VAL_10]], %[[VAL_11]] : tensor<16xf32, #[[$ATTR_0]]>
+%1 = arith.addf %arg1, %arg6 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+// CHECK: %[[VAL_13:.*]] = triton_gpu.convert_layout %[[VAL_12]] : tensor<16xf32, #[[$ATTR_0]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+// CHECK: %[[VAL_14:.*]] = triton_gpu.convert_layout %[[VAL_13]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+// CHECK: scf.yield %[[VAL_14]] : tensor<16xf32, #[[$ATTR_0]]>
+scf.yield %1 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+}
+// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16xf32, #[[$ATTR_0]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+// CHECK: tt.return %[[VAL_15]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+tt.return %0 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+}
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}>
+// CHECK: #[[$ATTR_1:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
+
+#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+
+// CHECK-LABEL: tt.func @test_advanced_loop(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>, %[[VAL_1:.*]]: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>,
+// CHECK-SAME: %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index,
+// CHECK-SAME: %[[VAL_5:.*]]: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+tt.func @test_advanced_loop(%arg0: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, %arg1: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, %arg2: index, %arg3: index, %arg4: index, %arg5: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) {
+// CHECK: %[[VAL_6:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+// CHECK: %[[VAL_7:.*]] = triton_gpu.convert_layout %[[VAL_5]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+// CHECK: %[[VAL_8:.*]]:2 = scf.for %[[VAL_9:.*]] = %[[VAL_2]] to %[[VAL_3]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_6]], %[[VAL_11:.*]] = %[[VAL_7]]) -> (tensor<16xf32, #[[$ATTR_0]]>, tensor<16xf32, #[[$ATTR_0]]>) {
+%0:2 = scf.for %arg6 = %arg2 to %arg3 step %arg4 iter_args(%arg7 = %arg0, %arg8 = %arg5) -> (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) {
+// CHECK: %[[VAL_12:.*]] = triton_gpu.convert_layout %[[VAL_10]] : tensor<16xf32, #[[$ATTR_0]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+// CHECK: %[[VAL_13:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16xf32, #[[$ATTR_0]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+// CHECK: %[[VAL_14:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+// CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_12]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_14]], %[[VAL_15]] : tensor<16xf32, #[[$ATTR_0]]>
+%1 = arith.addf %arg1, %arg7 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+// CHECK: %[[VAL_17:.*]] = triton_gpu.convert_layout %[[VAL_16]] : tensor<16xf32, #[[$ATTR_0]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+// CHECK: %[[VAL_18:.*]] = triton_gpu.convert_layout %[[VAL_17]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+// CHECK: %[[VAL_19:.*]] = triton_gpu.convert_layout %[[VAL_13]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+// CHECK: %[[VAL_20:.*]] = arith.addf %[[VAL_18]], %[[VAL_19]] : tensor<16xf32, #[[$ATTR_0]]>
+// CHECK: %[[VAL_21:.*]] = triton_gpu.convert_layout %[[VAL_20]] : tensor<16xf32, #[[$ATTR_0]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+// CHECK: %[[VAL_22:.*]] = triton_gpu.convert_layout %[[VAL_17]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+// CHECK: %[[VAL_23:.*]] = triton_gpu.convert_layout %[[VAL_21]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+// CHECK: scf.yield %[[VAL_22]], %[[VAL_23]] : tensor<16xf32, #[[$ATTR_0]]>, tensor<16xf32, #[[$ATTR_0]]>
+%2 = arith.addf %1, %arg8 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+scf.yield %1, %2 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+}
+// CHECK: }
+// CHECK: %[[VAL_24:.*]] = triton_gpu.convert_layout %[[VAL_25:.*]]#0 : tensor<16xf32, #[[$ATTR_0]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+// CHECK: %[[VAL_26:.*]] = triton_gpu.convert_layout %[[VAL_25]]#1 : tensor<16xf32, #[[$ATTR_0]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+// CHECK: tt.return %[[VAL_24]], %[[VAL_26]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>, tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+tt.return %0#0, %0#1 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+}
+}
