Skip to content

Commit 4623a97

Browse files
authored
[aievec] Update the lowering strategy of aievec.srs to llvm (#2632)
1 parent f4b1d8c commit 4623a97

File tree

5 files changed

+121
-49
lines changed

5 files changed

+121
-49
lines changed

lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp

Lines changed: 57 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -72,15 +72,23 @@ static Value forceCastValueToType(OpBuilder &builder, Location loc, Value val,
7272
if (valTy == type)
7373
return val;
7474
auto srcVecTy = dyn_cast<VectorType>(valTy);
75+
auto dstVecTy = dyn_cast<VectorType>(type);
76+
7577
if (srcVecTy) {
76-
auto dstVecTy = dyn_cast<VectorType>(type);
7778
assert(dstVecTy && "vector values cannot be forced into a non-vector type");
78-
assert(srcVecTy.getRank() == 1 && dstVecTy.getRank() == 1 &&
79-
"only flat 1D vectors can be force casted");
79+
80+
// Flatten source vector if it's not rank-1
81+
auto flatSrcVecTy = getFlattenedVectorType(srcVecTy);
82+
if (srcVecTy != flatSrcVecTy)
83+
val = builder.create<vector::ShapeCastOp>(loc, flatSrcVecTy, val);
84+
85+
// Flatten destination type if it's not rank-1
86+
auto flatDstVecTy = getFlattenedVectorType(dstVecTy);
87+
8088
int64_t dstVecLength =
81-
dstVecTy.getElementTypeBitWidth() * dstVecTy.getShape()[0];
89+
flatDstVecTy.getElementTypeBitWidth() * flatDstVecTy.getShape()[0];
8290
int64_t srcVecLength =
83-
srcVecTy.getElementTypeBitWidth() * srcVecTy.getShape()[0];
91+
flatSrcVecTy.getElementTypeBitWidth() * flatSrcVecTy.getShape()[0];
8492
if (srcVecLength != dstVecLength) {
8593
assert(srcVecLength < dstVecLength &&
8694
"only widening forced casts are supported");
@@ -92,7 +100,19 @@ static Value forceCastValueToType(OpBuilder &builder, Location loc, Value val,
92100
else
93101
val = widen256bVectorValueTo512b(builder, loc, val);
94102
}
103+
104+
// Bitcast to flat destination type (bitcast only supports flat vectors)
105+
val = bitcastValueToType(builder, loc, val, flatDstVecTy);
106+
107+
// Reshape back to original destination shape if needed
108+
if (flatDstVecTy != dstVecTy)
109+
val = builder.create<vector::ShapeCastOp>(loc, dstVecTy, val);
110+
111+
return val;
95112
}
113+
114+
// Non-vector types can be bitcast directly
115+
assert(!dstVecTy && "cannot force cast scalar to vector type");
96116
return bitcastValueToType(builder, loc, val, type);
97117
}
98118

@@ -280,9 +300,10 @@ class AddElemOpConversion
280300
return failure();
281301
}
282302

283-
// create bitcast for result
284-
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
285-
addElemOp);
303+
// create bitcast/shape_cast for result
304+
auto resultVal = forceCastValueToType(rewriter, loc, addElemOp,
305+
op.getResult().getType());
306+
rewriter.replaceOp(op, resultVal);
286307
return success();
287308
}
288309
};
@@ -643,9 +664,10 @@ class MulElemOpConversion
643664
/*variant=*/2, /*zero_acc=*/0, /*shift16=*/1,
644665
/*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0));
645666

646-
// create bitcast for result
647-
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
648-
acc64Val);
667+
// create bitcast/shape_cast for result
668+
auto resultVal =
669+
forceCastValueToType(rewriter, loc, acc64Val, op.getResult().getType());
670+
rewriter.replaceOp(op, resultVal);
649671
return success();
650672
}
651673

@@ -828,9 +850,10 @@ class MulElemOpConversion
828850
createMacOps(c, e, cfMul))))))));
829851
}
830852

831-
// create bitcast for result
832-
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
833-
finalMacVal);
853+
// create bitcast/shape_cast for result
854+
auto resultVal = forceCastValueToType(rewriter, loc, finalMacVal,
855+
op.getResult().getType());
856+
rewriter.replaceOp(op, resultVal);
834857
return success();
835858
}
836859

@@ -881,9 +904,10 @@ class MulElemOpConversion
881904
rewriter.getI32Type()}));
882905
}
883906

884-
// create bitcast for result
885-
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
886-
mulElemOp);
907+
// create bitcast/shape_cast for result
908+
auto resultVal = forceCastValueToType(rewriter, loc, mulElemOp,
909+
op.getResult().getType());
910+
rewriter.replaceOp(op, resultVal);
887911
return success();
888912
}
889913
};
@@ -1186,13 +1210,10 @@ class SRSOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::SRSOp> {
11861210
return failure();
11871211
}
11881212

1189-
// create bitcast for result if needed
1190-
if (op.getResult().getType() != srsIntrOp.getType()) {
1191-
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
1192-
srsIntrOp);
1193-
} else {
1194-
rewriter.replaceOp(op, srsIntrOp);
1195-
}
1213+
// create bitcast/shape_cast for result if needed
1214+
auto resultVal = forceCastValueToType(rewriter, loc, srsIntrOp,
1215+
op.getResult().getType());
1216+
rewriter.replaceOp(op, resultVal);
11961217

11971218
return success();
11981219
}
@@ -1388,9 +1409,10 @@ class ConcatOpConversion
13881409
return failure();
13891410
}
13901411

1391-
// create bitcast for result
1392-
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
1393-
concatOp);
1412+
// create bitcast/shape_cast for result
1413+
auto resultVal =
1414+
forceCastValueToType(rewriter, loc, concatOp, op.getResult().getType());
1415+
rewriter.replaceOp(op, resultVal);
13941416

13951417
return success();
13961418
}
@@ -1484,13 +1506,10 @@ class ExtOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::ExtOp> {
14841506
return failure();
14851507
}
14861508

1487-
// create bitcast for result
1488-
if (op.getResult().getType() != extOp.getType()) {
1489-
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
1490-
extOp);
1491-
} else {
1492-
rewriter.replaceOp(op, extOp);
1493-
}
1509+
// create bitcast/shape_cast for result
1510+
auto resultVal =
1511+
forceCastValueToType(rewriter, loc, extOp, op.getResult().getType());
1512+
rewriter.replaceOp(op, resultVal);
14941513

14951514
return success();
14961515
}
@@ -1964,9 +1983,10 @@ class ShiftOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::ShiftOp> {
19641983
rewriter.getI32Type(), rewriter.getI32Type()}));
19651984
}
19661985

1967-
// create bitcast for result
1968-
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
1969-
shiftOp);
1986+
// create bitcast/shape_cast for result
1987+
auto resultVal =
1988+
forceCastValueToType(rewriter, loc, shiftOp, op.getResult().getType());
1989+
rewriter.replaceOp(op, resultVal);
19701990

19711991
return success();
19721992
}

test/Conversion/AIEVecToLLVM/mul_elem.mlir

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,7 @@ func.func @i32_i32_i32_mul_elem(%arg0 : vector<16xi32>, %arg1 : vector<16xi32>)
7979
// CHECK-NEXT: %[[CST8:.*]] = llvm.mlir.constant(1114 : i32) : i32
8080
// CHECK-NEXT: %[[BITCAST3:.*]] = llvm.bitcast %[[SHUFF0]] : vector<16xi32> to vector<64xi8>
8181
// CHECK-NEXT: %[[ACC3:.*]] = "xllvm.intr.aie2.I512.I512.ACC1024.acc64.mac.conf"(%[[BITCAST3]], %[[SHUFF2]], %[[ACC2]], %[[CST8]]) : (vector<64xi8>, vector<16xi32>, vector<16xi64>, i32) -> vector<16xi64>
82-
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[ACC3]] : vector<16xi64> to vector<16xi64>
83-
// CHECK-NEXT: return %[[RES]] : vector<16xi64>
82+
// CHECK-NEXT: return %[[ACC3]] : vector<16xi64>
8483

8584
// -----
8685

test/Conversion/AIEVecToLLVM/shift.mlir

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ func.func @i32_shift(%arg0 : vector<16xi32>, %shift : i32) -> vector<16xi32> {
5252
// CHECK-NEXT: %[[VSHIFT:.*]] = "xllvm.intr.aie2.vshift.I512.I512"(
5353
// CHECK-SAME: %[[ARG0]], %[[ARG0]], %[[CST]], %[[SHIFT]]) :
5454
// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32, i32) -> vector<16xi32>
55-
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[VSHIFT]] : vector<16xi32> to vector<16xi32>
56-
// CHECK-NEXT: return %[[RES]] : vector<16xi32>
55+
// CHECK-NEXT: return %[[VSHIFT]] : vector<16xi32>
5756

5857
// -----
5958

@@ -69,5 +68,4 @@ func.func @bf16_shift(%arg0 : vector<32xbf16>, %shift : i32) -> vector<32xbf16>
6968
// CHECK-NEXT: %[[VSHIFT:.*]] = "xllvm.intr.aie2.vshift.bf512.bf512"(
7069
// CHECK-SAME: %[[ARG0]], %[[ARG0]], %[[CST]], %[[SHIFT]]) :
7170
// CHECK-SAME: (vector<32xbf16>, vector<32xbf16>, i32, i32) -> vector<32xbf16>
72-
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[VSHIFT]] : vector<32xbf16> to vector<32xbf16>
73-
// CHECK-NEXT: return %[[RES]] : vector<32xbf16>
71+
// CHECK-NEXT: return %[[VSHIFT]] : vector<32xbf16>

test/Conversion/AIEVecToLLVM/test-concat.mlir

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,7 @@ func.func @v16i32_concat_v8i32(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>) ->
123123
// CHECK: %[[CONCAT:.*]] = "xllvm.intr.aie2.concat.I512.I256"(
124124
// CHECK-SAME: %[[ARG0]], %[[ARG1]]) :
125125
// CHECK-SAME: (vector<8xi32>, vector<8xi32>) -> vector<16xi32>
126-
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[CONCAT]] : vector<16xi32> to vector<16xi32>
127-
// CHECK-NEXT: return %[[RES]] : vector<16xi32>
126+
// CHECK-NEXT: return %[[CONCAT]] : vector<16xi32>
128127

129128
// -----
130129

@@ -139,8 +138,7 @@ func.func @v32i32_concat_v8i32(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>,
139138
// CHECK: %[[CONCAT:.*]] = "xllvm.intr.aie2.concat.I1024.I256"(
140139
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]]) :
141140
// CHECK-SAME: (vector<8xi32>, vector<8xi32>, vector<8xi32>, vector<8xi32>) -> vector<32xi32>
142-
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[CONCAT]] : vector<32xi32> to vector<32xi32>
143-
// CHECK-NEXT: return %[[RES]] : vector<32xi32>
141+
// CHECK-NEXT: return %[[CONCAT]] : vector<32xi32>
144142

145143
// -----
146144

@@ -155,8 +153,7 @@ func.func @v32i32_concat_v16i32(%arg0 : vector<16xi32>, %arg1 : vector<16xi32>)
155153
// CHECK: %[[CONCAT:.*]] = "xllvm.intr.aie2.concat.I1024.I512"(
156154
// CHECK-SAME: %[[ARG0]], %[[ARG1]]) :
157155
// CHECK-SAME: (vector<16xi32>, vector<16xi32>) -> vector<32xi32>
158-
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[CONCAT]] : vector<32xi32> to vector<32xi32>
159-
// CHECK-NEXT: return %[[RES]] : vector<32xi32>
156+
// CHECK-NEXT: return %[[CONCAT]] : vector<32xi32>
160157

161158
// -----
162159

test/Conversion/AIEVecToLLVM/test-srs.mlir

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,61 @@ func.func @v32bf16_srs_v32f32(%arg0 : vector<32xf32>) {
200200
// CHECK-SAME: %[[BITCAST4]], %[[BITCAST5]]) :
201201
// CHECK-SAME: (vector<8xi32>, vector<8xi32>) -> vector<16xi32>
202202
// CHECK-NEXT: %[[BITCAST6:.*]] = llvm.bitcast %[[CONCAT]] : vector<16xi32> to vector<32xbf16>
203+
204+
// -----
205+
206+
func.func @v4x4bf16_srs_v4x4f32(%arg0 : vector<4x4xf32>) {
207+
%c0 = arith.constant 0 : i32
208+
%0 = aievec.srs %arg0, %c0 : vector<4x4xf32>, i32, vector<4x4xbf16>
209+
return
210+
}
211+
212+
// CHECK-LABEL: @v4x4bf16_srs_v4x4f32
213+
// CHECK-SAME: %[[ARG0:.*]]: vector<4x4xf32>
214+
// CHECK-NEXT: %[[SHIFT0:.*]] = arith.constant 0 : i32
215+
// CHECK-NEXT: %[[FLATTEN0:.*]] = vector.shape_cast %[[ARG0]] : vector<4x4xf32> to vector<16xf32>
216+
// CHECK-NEXT: %[[BITCAST0:.*]] = llvm.bitcast %[[FLATTEN0]] : vector<16xf32> to vector<8xi64>
217+
// CHECK-NEXT: %[[SRS0:.*]] = "xllvm.intr.aie2.v16accfloat.to.v16bf16"(
218+
// CHECK-SAME: %[[BITCAST0]]) :
219+
// CHECK-SAME: (vector<8xi64>) -> vector<16xbf16>
220+
// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[SRS0]] : vector<16xbf16> to vector<16xbf16>
221+
// CHECK-NEXT: %[[RESHAPE0:.*]] = vector.shape_cast %[[BITCAST1]] : vector<16xbf16> to vector<4x4xbf16>
222+
223+
// -----
224+
225+
func.func @v1x1x4x4bf16_srs_v1x1x4x4f32(%arg0 : vector<1x1x4x4xf32>) {
226+
%c0 = arith.constant 0 : i32
227+
%0 = aievec.srs %arg0, %c0 : vector<1x1x4x4xf32>, i32, vector<1x1x4x4xbf16>
228+
return
229+
}
230+
231+
// CHECK-LABEL: @v1x1x4x4bf16_srs_v1x1x4x4f32
232+
// CHECK-SAME: %[[ARG0:.*]]: vector<1x1x4x4xf32>
233+
// CHECK-NEXT: %[[SHIFT0:.*]] = arith.constant 0 : i32
234+
// CHECK-NEXT: %[[FLATTEN0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4x4xf32> to vector<16xf32>
235+
// CHECK-NEXT: %[[BITCAST0:.*]] = llvm.bitcast %[[FLATTEN0]] : vector<16xf32> to vector<8xi64>
236+
// CHECK-NEXT: %[[SRS0:.*]] = "xllvm.intr.aie2.v16accfloat.to.v16bf16"(
237+
// CHECK-SAME: %[[BITCAST0]]) :
238+
// CHECK-SAME: (vector<8xi64>) -> vector<16xbf16>
239+
// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[SRS0]] : vector<16xbf16> to vector<16xbf16>
240+
// CHECK-NEXT: %[[RESHAPE0:.*]] = vector.shape_cast %[[BITCAST1]] : vector<16xbf16> to vector<1x1x4x4xbf16>
241+
242+
// -----
243+
244+
func.func @v2x8i16_srs_v2x8i32(%arg0 : vector<2x8xi32>) {
245+
%c0 = arith.constant 0 : i32
246+
%0 = aievec.srs %arg0, %c0 : vector<2x8xi32>, i32, vector<2x8xi16>
247+
return
248+
}
249+
250+
// CHECK-LABEL: @v2x8i16_srs_v2x8i32
251+
// CHECK-SAME: %[[ARG0:.*]]: vector<2x8xi32>
252+
// CHECK-NEXT: %[[SHIFT0:.*]] = arith.constant 0 : i32
253+
// CHECK-NEXT: %[[SIGN0:.*]] = llvm.mlir.constant(1 : i32) : i32
254+
// CHECK-NEXT: %[[FLATTEN0:.*]] = vector.shape_cast %[[ARG0]] : vector<2x8xi32> to vector<16xi32>
255+
// CHECK-NEXT: %[[BITCAST0:.*]] = llvm.bitcast %[[FLATTEN0]] : vector<16xi32> to vector<8xi64>
256+
// CHECK-NEXT: %[[SRS0:.*]] = "xllvm.intr.aie2.I256.v16.acc32.srs"(
257+
// CHECK-SAME: %[[BITCAST0]], %[[SHIFT0]], %[[SIGN0]]) :
258+
// CHECK-SAME: (vector<8xi64>, i32, i32) -> vector<16xi16>
259+
// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[SRS0]] : vector<16xi16> to vector<16xi16>
260+
// CHECK-NEXT: %[[RESHAPE0:.*]] = vector.shape_cast %[[BITCAST1]] : vector<16xi16> to vector<2x8xi16>

0 commit comments

Comments
 (0)