Commit 34fe67e

[LinalgToXeGPU] Remove redundant linalg.broadcasts (#419)
Signed-off-by: dchigarev <[email protected]>
1 parent fb51bb4 commit 34fe67e

File tree: 4 files changed, +307 −0 lines

lib/gc/Transforms/GPU/LinalgToXeGPU.cpp

Lines changed: 102 additions & 0 deletions
@@ -1312,6 +1312,105 @@ static FailureOr<Value> findAndReplaceTranspose(const Value &matmulOperand,
       defOp, "No transpose operation producing the operand was found");
 }
 
+// Checks whether the given linalgOp operand is produced by a
+// `linalg::BroadcastOp` that can be replaced by a simple subview
+// (for example broadcast: 7x128 -> 1x1x7x128) and ensures that
+// the broadcast result is only used by the linalgOp in question.
+//
+// If a valid `linalg::BroadcastOp` is found, the function removes it
+// and returns the operand of the `linalg::BroadcastOp` as the new
+// linalgOp operand. Otherwise it returns the original operand.
+static Value findAndReplaceBroadcast(linalg::LinalgOp linalgOp,
+                                     size_t operandIdx,
+                                     PatternRewriter &rewriter) {
+  auto operand = linalgOp.getDpsInputs()[operandIdx];
+  auto operandParent = operand;
+
+  // Walk over the 'Value' users and verify that it's only used by 'ops'.
+  std::function<bool(Value, SmallVector<linalg::LinalgOp> &)> onlyUsedByOp =
+      [&onlyUsedByOp](Value value, SmallVector<linalg::LinalgOp> &ops) -> bool {
+    bool result = true;
+    for (auto user : value.getUsers()) {
+      if (auto linalgOpUser = dyn_cast<linalg::LinalgOp>(user))
+        result &= std::find(ops.begin(), ops.end(), linalgOpUser) !=
+                  ops.end();
+      else if (isa<memref::DeallocOp>(user))
+        continue; // Allow deallocs as users.
+      else if (auto subview = dyn_cast<memref::SubViewOp>(user))
+        result &= onlyUsedByOp(subview.getResult(), ops);
+      else
+        return false;
+    }
+    return result;
+  };
+
+  linalg::BroadcastOp broadcastOp = nullptr;
+  while (auto defOp = operandParent.getDefiningOp()) {
+    for (auto x : defOp->getUsers()) {
+      if (!isa<linalg::BroadcastOp>(x))
+        continue;
+
+      if (broadcastOp) {
+        rewriter.notifyMatchFailure(broadcastOp,
+                                    "Only one broadcast operation is allowed");
+        return operand;
+      }
+
+      broadcastOp = dyn_cast<linalg::BroadcastOp>(x);
+      auto broadcastRes = broadcastOp.getDpsInits()[0];
+      SmallVector<linalg::LinalgOp> ops({linalgOp, broadcastOp});
+
+      // Verify that there are no users of the broadcast result
+      // other than the linalgOp in question.
+      if (!onlyUsedByOp(broadcastRes, ops)) {
+        rewriter.notifyMatchFailure(
+            broadcastOp, "Broadcast result is used by more than one operation");
+        return operand;
+      }
+      break;
+    }
+
+    if (defOp->getOperands().size() == 0)
+      break;
+
+    operandParent = defOp->getOperand(0);
+  }
+  if (!broadcastOp) {
+    rewriter.notifyMatchFailure(
+        linalgOp, "No broadcast operation producing the operand was found");
+    return operand;
+  }
+
+  auto brInp = broadcastOp.getDpsInputs()[0];
+  auto brOut = broadcastOp.getDpsInits()[0];
+
+  auto inpType = dyn_cast<MemRefType>(brInp.getType());
+  auto outType = dyn_cast<MemRefType>(brOut.getType());
+  if (!inpType || !outType)
+    return operand;
+
+  auto inpShape = inpType.getShape();
+  auto outShape = outType.getShape();
+
+  if (inpShape.size() < 2) {
+    rewriter.notifyMatchFailure(broadcastOp, "Only nD broadcast is supported");
+    return operand;
+  }
+
+  if (!utils::canSqueezeDims(inpShape) || !utils::canSqueezeDims(outShape)) {
+    rewriter.notifyMatchFailure(broadcastOp,
+                                "Can't squeeze broadcast operands to 2D");
+    return operand;
+  }
+
+  auto res = utils::reduceMemrefDims(rewriter, broadcastOp.getLoc(), brInp);
+  if (failed(res))
+    return operand;
+
+  rewriter.eraseOp(broadcastOp);
+  return res.value();
+}
+
 // Create XeGPU DPAS kernel out of GEMM-like operation.
 static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
                                       ArrayRef<int64_t> dpasTile, int kTile,
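
In effect, the pattern rewrites IR like the following minimal sketch (hypothetical value names; shapes borrowed from the tests below). A broadcast that only prepends unit dimensions leaves the underlying data untouched, so the pass erases it and the eltwise lowering builds its xegpu.create_nd_tdesc descriptors on the broadcast source directly:

  // Before: %dst holds the same 7x128 data as %src, just with two
  // leading unit dimensions added by the broadcast.
  %src = memref.alloc() : memref<7x128xf16>
  %dst = memref.alloc() : memref<1x1x7x128xf16>
  linalg.broadcast ins(%src : memref<7x128xf16>)
      outs(%dst : memref<1x1x7x128xf16>) dimensions = [0, 1]
  linalg.add ins(%dst, %other : memref<1x1x7x128xf16>, memref<1x1x7x128xf16>)
      outs(%out : memref<1x1x7x128xf16>)

  // After the rewrite, the consumer reads %src in place of %dst, and the
  // XeGPU kernel tiles a memref<7x128xf16> instead of the 4D buffer.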
@@ -1690,7 +1789,9 @@ LogicalResult createEltwiseKernel(linalg::LinalgOp linalgOp,
 
   // Create descriptors and load values for all inputs.
   SmallVector<SmallVector<Value>> loadedInputs;
+  size_t operandIdx = 0;
   for (auto input : linalgOp.getDpsInputs()) {
+    input = findAndReplaceBroadcast(linalgOp, operandIdx, rewriter);
     SmallVector<Value> inputTiles =
         createDescriptorTiles(rewriter, loc, input, outputShape, tileShape);
 
@@ -1699,6 +1800,7 @@ LogicalResult createEltwiseKernel(linalg::LinalgOp linalgOp,
                         /*vnniConf=*/std::nullopt,
                         /*transpose=*/nullptr, /*transpose_bit=*/nullptr);
     loadedInputs.push_back(loadedVals);
+    operandIdx++;
   }
 
   // Extract SIMD sized sub-tiles from loaded tiles.
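
When the broadcast source itself carries unit dimensions (the 1x7x128 case in the second test below), the squeeze performed by utils::reduceMemrefDims amounts to a rank-reducing subview. A sketch inferred from the test expectations, with a hypothetical %src:

  %squeezed = memref.subview %src[0, 0, 0] [1, 7, 128] [1, 1, 1]
      : memref<1x7x128xf16> to memref<7x128xf16, strided<[128, 1]>>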

lib/gc/Transforms/GPU/Pipeline.cpp

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@ void populateGPUPipeline(OpPassManager &pm,
 
   pm.addPass(createDecomposeTensorOperation());
   pm.addNestedPass<func::FuncOp>(createGpuTilingAndFusion());
+  pm.addPass(createCanonicalizerPass());
 
   pm.addPass(bufferization::createEmptyTensorEliminationPass());
   pm.addPass(bufferization::createEmptyTensorToAllocTensorPass());
Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
// RUN: gc-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 k-tile=16" -canonicalize -split-input-file | FileCheck %s

// CHECK-LABEL: func.func @broadcast_eliminate_2d
func.func @broadcast_eliminate_2d() {
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c4 = arith.constant 4 : index

  // CHECK: %[[MEMREF_0:.*]] = memref.alloc() : memref<7x128xf16>
  %0 = memref.alloc() : memref<7x128xf16>
  // CHECK: %[[MEMREF_2:.*]] = memref.alloc() : memref<1x1x7x128xf16>
  %2 = memref.alloc() : memref<1x1x7x128xf16>
  // CHECK: %[[MEMREF_3:.*]] = memref.alloc() : memref<1x1x7x128xf16>
  %3 = memref.alloc() : memref<1x1x7x128xf16>

  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg6, %arg7, %arg8) in (%arg14 = %c4, %arg15 = %c1, %arg16 = %c1) {
    // CHECK-NOT: memref.alloc() : memref<4x1x7x128xf16, 3>
    %slm_base = memref.alloc() : memref<4x1x7x128xf16, 3>
    %1 = memref.subview %slm_base[%arg6, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : memref<4x1x7x128xf16, 3> to memref<1x1x7x128xf16, strided<[896, 896, 128, 1], offset: ?>, 3>

    // CHECK-NOT: linalg.broadcast
    linalg.broadcast ins(%0 : memref<7x128xf16>) outs(%1 : memref<1x1x7x128xf16, strided<[896, 896, 128, 1], offset: ?>, 3>) dimensions = [0, 1]
    // CHECK: xegpu.create_nd_tdesc %[[MEMREF_0]]
    linalg.add ins(%1, %2 : memref<1x1x7x128xf16, strided<[896, 896, 128, 1], offset: ?>, 3>, memref<1x1x7x128xf16>) outs(%3 : memref<1x1x7x128xf16>)
    gpu.terminator
  }
  return
}

// -----

// CHECK-LABEL: func.func @broadcast_eliminate_3d
func.func @broadcast_eliminate_3d() {
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c4 = arith.constant 4 : index

  // CHECK: %[[MEMREF_0:.*]] = memref.alloc() : memref<1x7x128xf16>
  %0 = memref.alloc() : memref<1x7x128xf16>
  // CHECK: %[[MEMREF_2:.*]] = memref.alloc() : memref<1x1x7x128xf16>
  %2 = memref.alloc() : memref<1x1x7x128xf16>
  // CHECK: %[[MEMREF_3:.*]] = memref.alloc() : memref<1x1x7x128xf16>
  %3 = memref.alloc() : memref<1x1x7x128xf16>

  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg6, %arg7, %arg8) in (%arg14 = %c4, %arg15 = %c1, %arg16 = %c1) {
    // CHECK-NOT: memref.alloc() : memref<4x1x7x128xf16, 3>
    %slm_base = memref.alloc() : memref<4x1x7x128xf16, 3>
    %1 = memref.subview %slm_base[%arg6, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : memref<4x1x7x128xf16, 3> to memref<1x1x7x128xf16, strided<[896, 896, 128, 1], offset: ?>, 3>

    // CHECK-NOT: linalg.broadcast
    linalg.broadcast ins(%0 : memref<1x7x128xf16>) outs(%1 : memref<1x1x7x128xf16, strided<[896, 896, 128, 1], offset: ?>, 3>) dimensions = [0]
    // %0 is squeezed to 2D before being passed to 'linalg.add'.
    // CHECK: %[[MEMREF0_SQUEEZ:.+]] = memref.subview %[[MEMREF_0]][0, 0, 0] [1, 7, 128] [1, 1, 1] :
    // CHECK-SAME: memref<1x7x128xf16> to memref<7x128xf16, strided<[128, 1]>>
    // CHECK: xegpu.create_nd_tdesc %[[MEMREF0_SQUEEZ]]
    linalg.add ins(%1, %2 : memref<1x1x7x128xf16, strided<[896, 896, 128, 1], offset: ?>, 3>, memref<1x1x7x128xf16>) outs(%3 : memref<1x1x7x128xf16>)
    gpu.terminator
  }
  return
}

// -----

// CHECK-LABEL: func.func @complex_broadcast_3d
func.func @complex_broadcast_3d() {
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c4 = arith.constant 4 : index

  // CHECK: %[[MEMREF_0:.*]] = memref.alloc() : memref<7x128xf16>
  %0 = memref.alloc() : memref<7x128xf16>
  // CHECK: %[[MEMREF_1:.*]] = memref.alloc() : memref<7x7x128xf16>
  %1 = memref.alloc() : memref<7x7x128xf16>
  // CHECK: %[[MEMREF_2:.*]] = memref.alloc() : memref<7x7x128xf16>
  %2 = memref.alloc() : memref<7x7x128xf16>
  // CHECK: %[[MEMREF_3:.*]] = memref.alloc() : memref<7x7x128xf16>
  %3 = memref.alloc() : memref<7x7x128xf16>

  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg6, %arg7, %arg8) in (%arg14 = %c4, %arg15 = %c1, %arg16 = %c1) {
    // This broadcast can't be replaced by a single memref.subview, so it can't be removed.
    // CHECK: linalg.broadcast
    linalg.broadcast ins(%0 : memref<7x128xf16>) outs(%1 : memref<7x7x128xf16>) dimensions = [0]
    linalg.add ins(%1, %2 : memref<7x7x128xf16>, memref<7x7x128xf16>) outs(%3 : memref<7x7x128xf16>)
    gpu.terminator
  }
  return
}

// -----

// CHECK-LABEL: func.func @single_broadcast
func.func @single_broadcast() {
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c4 = arith.constant 4 : index

  // CHECK: %[[MEMREF_0:.*]] = memref.alloc() : memref<7x128xf16>
  %0 = memref.alloc() : memref<7x128xf16>
  // CHECK: %[[MEMREF_1:.*]] = memref.alloc() : memref<1x1x7x128xf16>
  %1 = memref.alloc() : memref<1x1x7x128xf16>

  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg6, %arg7, %arg8) in (%arg14 = %c4, %arg15 = %c1, %arg16 = %c1) {
    // The broadcast result is not an input of any xegpu operation, so we can't lower it.
    // CHECK: linalg.broadcast
    linalg.broadcast ins(%0 : memref<7x128xf16>) outs(%1 : memref<1x1x7x128xf16>) dimensions = [0, 1]
    gpu.terminator
  }
  return
}
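
A note on the CHECK-NOT lines for the shared-local-memory buffer: once the broadcast is erased, the %slm_base allocation (address space 3) and its subview have no remaining users after the eltwise lowering, so the trailing -canonicalize in the RUN line is expected to fold them away entirely (an inference from the test expectations).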
Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
// RUN: gc-gpu-runner --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils %s | FileCheck %s

!dtype=f16
!input_memref_type=memref<2x7x32x128x!dtype>
!input_tensor_type=tensor<2x7x32x128x!dtype>
!output_memref_type=memref<2x32x7x128x!dtype>
!output_tensor_type=tensor<2x32x7x128x!dtype>
!cos_sin_cache_memref_type=memref<1x1x2048x128x!dtype>
!cos_sin_cache_tensor_type=tensor<1x1x2048x128x!dtype>
!cos_sin_cache_tensor_shrink_type=tensor<1x1x7x128x!dtype>
!pos_ids_memref_type=memref<1x7xindex>
!pos_ids_tensor_type=tensor<1x7xindex>
#map = affine_map<(xi, yi, zi) -> ((xi * 3 * 4 + yi * 4 + zi) * 2)>
module @fragment_name {
  memref.global "private" constant @_cos_cache : !cos_sin_cache_memref_type = dense<3.000000e+00>
  memref.global "private" constant @_sin_cache : !cos_sin_cache_memref_type = dense<2.000000e+00>
  memref.global "private" constant @_iinput_const : !input_memref_type = dense<3.000000e+00>
  memref.global "private" constant @_ipos_ids_const : !pos_ids_memref_type = dense<1>
  memref.global "private" constant @_ipos_id_end_const : memref<1xindex> = dense<1>
  func.func @RoPE(%iinput: !input_memref_type, %ipos_ids: !pos_ids_memref_type, %ipos_id_end: memref<1xindex>, %out: !output_memref_type) {
    %input = bufferization.to_tensor %iinput restrict : !input_memref_type
    %cos_cache = memref.get_global @_cos_cache : !cos_sin_cache_memref_type
    %sin_cache = memref.get_global @_sin_cache : !cos_sin_cache_memref_type
    %cos_cache_tensor = bufferization.to_tensor %cos_cache restrict : !cos_sin_cache_memref_type
    %sin_cache_tensor = bufferization.to_tensor %sin_cache restrict : !cos_sin_cache_memref_type
    %pos_ids = bufferization.to_tensor %ipos_ids restrict : !pos_ids_memref_type
    %pos_id_end = bufferization.to_tensor %ipos_id_end restrict : memref<1xindex>
    %3 = tensor.empty() : !output_tensor_type

    %transpose_in = linalg.transpose ins(%input : !input_tensor_type) outs(%3 : !output_tensor_type) permutation = [0, 2, 1, 3]

    %c0 = arith.constant 0 : index
    %c3 = arith.constant 3 : index
    %cos_cache_slice = tensor.extract_slice %cos_cache_tensor[0, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : !cos_sin_cache_tensor_type to !cos_sin_cache_tensor_shrink_type
    %cos_cache_slice2 = tensor.collapse_shape %cos_cache_slice [[0, 1], [2], [3]] : tensor<1x1x7x128x!dtype> into tensor<1x7x128x!dtype>
    %cos_cache_slice3 = tensor.collapse_shape %cos_cache_slice2 [[0, 1], [2]] : tensor<1x7x128x!dtype> into tensor<7x128x!dtype>
    %pos_ids_index = tensor.expand_shape %pos_ids [[0], [1, 2]] output_shape [1, 7, 1] : tensor<1x7xindex> into tensor<1x7x1xindex>

    %cos_cache_slice4 = tensor.gather %cos_cache_slice3[%pos_ids_index] gather_dims([0]) : (tensor<7x128x!dtype>, tensor<1x7x1xindex>) -> tensor<1x7x128x!dtype>

    %cos_cache_slice5 = tensor.expand_shape %cos_cache_slice4 [[0, 1], [2], [3]] output_shape [1, 1, 7, 128] : tensor<1x7x128x!dtype> into tensor<1x1x7x128x!dtype>
    %cos_cache_slice6 = tensor.collapse_shape %cos_cache_slice5 [[0, 1, 2], [3]] : tensor<1x1x7x128x!dtype> into tensor<7x128x!dtype>

    %cos_cache_slice7 = linalg.broadcast ins(%cos_cache_slice6 : tensor<7x128x!dtype>) outs(%3 : !output_tensor_type) dimensions = [0, 1]
    %input_apply_cos_cache = linalg.mul ins(%transpose_in, %cos_cache_slice7 : !output_tensor_type, !output_tensor_type) outs(%3 : !output_tensor_type) -> !output_tensor_type

    %head_dim = tensor.dim %transpose_in, %c3 : !output_tensor_type
    %c2 = arith.constant 2 : index
    %half_head_dim = arith.floordivsi %head_dim, %c2 : index
    %transpose_input_first_half = tensor.extract_slice %transpose_in[0, 0, 0, 0] [2, 32, 7, 64] [1, 1, 1, 1] : !output_tensor_type to tensor<2x32x7x64x!dtype>
    %transpose_input_second_half = tensor.extract_slice %transpose_in[0, 0, 0, %half_head_dim] [2, 32, 7, 64] [1, 1, 1, 1] : !output_tensor_type to tensor<2x32x7x64x!dtype>
    %cnegative1 = arith.constant dense<-1.000000e+00> : tensor<2x32x7x64x!dtype>
    %empty_tensor = tensor.empty() : tensor<2x32x7x64x!dtype>
    %transpose_input_second_half_opposite = linalg.mul ins(%transpose_input_second_half, %cnegative1 : tensor<2x32x7x64x!dtype>, tensor<2x32x7x64x!dtype>) outs(%empty_tensor : tensor<2x32x7x64x!dtype>) -> tensor<2x32x7x64x!dtype>

    %transformed_input = tensor.concat dim(3) %transpose_input_second_half_opposite, %transpose_input_first_half : (tensor<2x32x7x64x!dtype>, tensor<2x32x7x64x!dtype>) -> !output_tensor_type

    %sin_cache_slice = tensor.extract_slice %sin_cache_tensor[0, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : !cos_sin_cache_tensor_type to !cos_sin_cache_tensor_shrink_type
    %sin_cache_slice2 = tensor.collapse_shape %sin_cache_slice [[0, 1], [2], [3]] : tensor<1x1x7x128x!dtype> into tensor<1x7x128x!dtype>
    %sin_cache_slice3 = tensor.collapse_shape %sin_cache_slice2 [[0, 1], [2]] : tensor<1x7x128x!dtype> into tensor<7x128x!dtype>
    %sin_cache_slice4 = tensor.gather %sin_cache_slice3[%pos_ids_index] gather_dims([0]) : (tensor<7x128x!dtype>, tensor<1x7x1xindex>) -> tensor<1x7x128x!dtype>

    %sin_cache_slice5 = tensor.expand_shape %sin_cache_slice4 [[0, 1], [2], [3]] output_shape [1, 1, 7, 128] : tensor<1x7x128x!dtype> into tensor<1x1x7x128x!dtype>
    %sin_cache_slice6 = tensor.collapse_shape %sin_cache_slice5 [[0, 1, 2], [3]] : tensor<1x1x7x128x!dtype> into tensor<7x128x!dtype>

    %sin_cache_slice7 = linalg.broadcast ins(%sin_cache_slice6 : tensor<7x128x!dtype>) outs(%3 : !output_tensor_type) dimensions = [0, 1]
    %input_apply_sin_cache = linalg.mul ins(%transformed_input, %sin_cache_slice7 : !output_tensor_type, !output_tensor_type) outs(%3 : !output_tensor_type) -> !output_tensor_type

    %result = linalg.add ins(%input_apply_cos_cache, %input_apply_sin_cache : !output_tensor_type, !output_tensor_type) outs(%3 : !output_tensor_type) -> !output_tensor_type
    bufferization.materialize_in_destination %result in restrict writable %out : (!output_tensor_type, !output_memref_type) -> ()
    return
  }

  func.func @main() {
    %inp = memref.get_global @_iinput_const : !input_memref_type
    %ipos_ids = memref.get_global @_ipos_ids_const : !pos_ids_memref_type
    %ipos_id_end = memref.get_global @_ipos_id_end_const : memref<1xindex>

    %out = memref.alloc() {alignment = 64 : i64} : !output_memref_type

    func.call @RoPE(%inp, %ipos_ids, %ipos_id_end, %out) : (!input_memref_type, !pos_ids_memref_type, memref<1xindex>, !output_memref_type) -> ()

    %out_subview = memref.subview %out[0, 0, 0, 0] [2, 1, 1, 1] [1, 1, 1, 1] : !output_memref_type to memref<2xf16, strided<[28672]>>
    %cast = memref.cast %out_subview : memref<2xf16, strided<[28672]>> to memref<*xf16>
    call @printMemrefF16(%cast) : (memref<*xf16>) -> ()

    return
  }

  func.func private @printMemrefF16(%ptr : memref<*xf16>)
}

// CHECK: Unranked Memref base@{{(0x)?[-0-9a-fA-F]*}}
// CHECK-SAME: rank = 1 offset = 0 sizes = [2] strides = [28672] data =
// CHECK-NEXT: [3, 3]
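
The expected [3, 3] follows from the constant inputs: the two checked elements lie at index 0 of the head dimension, where the rotated operand is the negated second half (-3), so each element evaluates to 3 * 3 + (-3) * 2 = 3.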
