 #include <imex/Conversion/XeGPUToVC/XeGPUToVC.h>
 
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 
@@ -155,14 +158,17 @@ struct CreateNdDescPattern
       // compute surface width
       auto bytesPerElem = createIntConstant(bitWidth / 8);
       auto one = createIntConstant(1);
-      surfaceW = rewriter.create<arith::ExtUIOp>(loc, i32Type,
-                                                 adaptor.getShape()[1]);
-      surfaceW = rewriter.create<arith::MulIOp>(loc, surfaceW, bytesPerElem);
+      auto surfaceWCast = rewriter.create<arith::IndexCastUIOp>(
+          loc, i32Type, adaptor.getShape()[1]);
+
+      surfaceW =
+          rewriter.create<arith::MulIOp>(loc, surfaceWCast, bytesPerElem);
       surfaceW = rewriter.create<arith::SubIOp>(loc, surfaceW, one);
       // compute surface height
-      surfaceH = rewriter.create<arith::ExtUIOp>(loc, i32Type,
-                                                 adaptor.getShape()[0]);
-      surfaceH = rewriter.create<arith::SubIOp>(loc, surfaceH, one);
+
+      auto surfaceHCast = rewriter.create<arith::IndexCastUIOp>(
+          loc, i32Type, adaptor.getShape()[0]);
+      surfaceH = rewriter.create<arith::SubIOp>(loc, surfaceHCast, one);
       // fixme: pitch = width for now
       surfaceP = surfaceW;
     }
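arith.extui only accepts integer operands, while the tile shape values used here are index-typed, so the lowering switches to arith.index_castui, which converts an index value to the requested integer width (zero-extending or truncating as needed). A minimal sketch of the same idiom, assuming index-typed dimensions; the helper names toI32 and surfaceWidthInBytes are illustrative, not taken from the pass:

    // Sketch only: turn an index-typed dimension into a 32-bit byte count.
    #include "mlir/Dialect/Arith/IR/Arith.h"
    #include "mlir/IR/Builders.h"

    // arith.index_castui accepts an index operand; arith.extui does not.
    static mlir::Value toI32(mlir::OpBuilder &b, mlir::Location loc,
                             mlir::Value indexVal) {
      return b.create<mlir::arith::IndexCastUIOp>(loc, b.getI32Type(), indexVal);
    }

    // Multiplies the element count by the element size in bytes, all in i32.
    static mlir::Value surfaceWidthInBytes(mlir::OpBuilder &b, mlir::Location loc,
                                           mlir::Value widthInElems,
                                           unsigned bitWidth) {
      auto elemBytes =
          b.create<mlir::arith::ConstantIntOp>(loc, bitWidth / 8, /*width=*/32);
      return b.create<mlir::arith::MulIOp>(loc, toI32(b, loc, widthInElems),
                                           elemBytes);
    }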
@@ -210,15 +216,15 @@ class UpdateNDOffsetToVCPattern
 
     auto loc = op.getLoc();
     auto i32Type = rewriter.getI32Type();
-    auto offsets = adaptor.getOffsets();
+    auto offsets = op.getOffsets();
 
     // Get Payload
     auto desc = adaptor.getTensorDesc();
-
     for (size_t i = 0; i < offsets.size(); i++) {
       auto offset = offsets[i];
-      if (auto cst = dyn_cast<arith::ConstantOp>(offset.getDefiningOp()))
-        if (auto attr = dyn_cast<mlir::IntegerAttr>(cst.getValue());
+      if (auto cst =
+              dyn_cast_if_present<arith::ConstantOp>(offset.getDefiningOp()))
+        if (auto attr = dyn_cast_if_present<mlir::IntegerAttr>(cst.getValue());
             attr && attr.getInt() == 0)
           continue;
 
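Reading the offsets from the original op rather than the adaptor means they may be plain index SSA values, including block arguments, for which Value::getDefiningOp() returns null; dyn_cast_if_present tolerates that null, whereas dyn_cast would assert. A small stand-alone illustration of the casting idiom, assuming nothing about this pass (isZeroConstant is a made-up helper name):

    // Sketch only: null-tolerant casting when a Value may lack a defining op.
    #include "mlir/Dialect/Arith/IR/Arith.h"
    #include "mlir/IR/Value.h"
    #include "llvm/Support/Casting.h"

    static bool isZeroConstant(mlir::Value v) {
      // For a block argument, getDefiningOp() is null; dyn_cast_if_present
      // simply returns null in that case instead of asserting.
      if (auto cst = llvm::dyn_cast_if_present<mlir::arith::ConstantOp>(
              v.getDefiningOp()))
        if (auto attr =
                llvm::dyn_cast_if_present<mlir::IntegerAttr>(cst.getValue()))
          return attr.getInt() == 0;
      return false;
    }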
@@ -227,7 +233,8 @@ class UpdateNDOffsetToVCPattern
       // offset.
       int32_t idx = i == 0 ? 6 : 5;
       auto oldOffset = rewriter.create<vector::ExtractOp>(loc, desc, idx);
-      offset = rewriter.create<arith::TruncIOp>(loc, i32Type, offset);
+      offset = rewriter.create<arith::IndexCastUIOp>(loc, i32Type, offset);
+
       auto newOffset = rewriter.create<arith::AddIOp>(loc, oldOffset, offset);
 
       // Update new 2D Block OffsetX/OffsetY in Payload descriptor.
@@ -630,6 +637,7 @@ struct DpasPattern : public OpConversionPattern<::mlir::xegpu::DpasOp> {
     auto infoAttr = rewriter.getIntegerAttr(rewriter.getI32Type(), infoVal);
     auto info = rewriter.create<arith::ConstantOp>(loc, rewriter.getI32Type(),
                                                    infoAttr);
+
     auto newResultType = encodeVectorType(rewriter, resultType).second;
     SmallVector<Value, 4> args{adaptor.getRhs(), adaptor.getLhs(), info};
     std::string funcName = "llvm.genx.dpas.nosrc0.";
@@ -1106,7 +1114,7 @@ class FenceToVCPattern : public OpConversionPattern<::mlir::xegpu::FenceOp> {
   }
 };
 
-struct VectorShapeCast final
+struct VectorShapeCastVC final
     : public OpConversionPattern<mlir::vector::ShapeCastOp> {
   using OpConversionPattern<mlir::vector::ShapeCastOp>::OpConversionPattern;
 
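This hunk and the ones below rename the local vector patterns with a VC suffix, presumably to keep them distinct from similarly named patterns in other conversions and to mark them as part of the VC-intrinsics lowering path; their bodies are unchanged. For reference, a generic skeleton of the OpConversionPattern shape these patterns share, with a placeholder name and an illustrative body only:

    // Sketch only: the general structure of a pattern like VectorShapeCastVC.
    #include "mlir/Dialect/Vector/IR/VectorOps.h"
    #include "mlir/Transforms/DialectConversion.h"

    struct ExampleShapeCastVC final
        : public mlir::OpConversionPattern<mlir::vector::ShapeCastOp> {
      using mlir::OpConversionPattern<
          mlir::vector::ShapeCastOp>::OpConversionPattern;

      mlir::LogicalResult
      matchAndRewrite(mlir::vector::ShapeCastOp op, OpAdaptor adaptor,
                      mlir::ConversionPatternRewriter &rewriter) const override {
        // Convert the result type through the pass's TypeConverter and fail
        // gracefully if it has no conversion.
        auto dstTy = getTypeConverter()->convertType(op.getType());
        if (!dstTy)
          return mlir::failure();
        // Illustrative rewrite: if source and result lower to the same flat
        // vector type, the shape cast is a no-op and the converted source
        // can be forwarded directly.
        if (adaptor.getSource().getType() != dstTy)
          return mlir::failure();
        rewriter.replaceOp(op, adaptor.getSource());
        return mlir::success();
      }
    };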
@@ -1130,7 +1138,7 @@ struct VectorShapeCast final
   }
 };
 
-struct VectorExtract final
+struct VectorExtractVC final
     : public OpConversionPattern<mlir::vector::ExtractOp> {
   using OpConversionPattern<mlir::vector::ExtractOp>::OpConversionPattern;
 
@@ -1139,6 +1147,7 @@ struct VectorExtract final
                   ConversionPatternRewriter &rewriter) const override {
 
     auto *converter = getTypeConverter();
+
     auto dstTy = converter->convertType(extractOp.getType());
     if (!dstTy)
       return failure();
@@ -1200,7 +1209,7 @@ static uint64_t getFirstIntValue(mlir::ArrayAttr attr) {
   return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
 };
 
-struct VectorExtractStridedSlice final
+struct VectorExtractStridedSliceVC final
     : public OpConversionPattern<vector::ExtractStridedSliceOp> {
   using OpConversionPattern<vector::ExtractStridedSliceOp>::OpConversionPattern;
   LogicalResult
@@ -1298,7 +1307,7 @@ struct VectorExtractStridedSlice final
   }
 };
 
-struct VectorShuffle final
+struct VectorShuffleVC final
     : public OpConversionPattern<mlir::vector::ShuffleOp> {
   using OpConversionPattern<mlir::vector::ShuffleOp>::OpConversionPattern;
 
@@ -1387,7 +1396,7 @@ struct SCFForOpBlockVCPattern final
 
     rewriter.applySignatureConversion(&op.getRegion(), signatureConverter);
 
-    newOp.getBody()->erase();
+    rewriter.eraseBlock(newOp.getBody());
     rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
                                 newOp.getRegion().end());
     rewriter.replaceOp(op, newOp.getResults());
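A ConversionPatternRewriter stages IR mutations so they can be rolled back if the overall conversion fails; erasing the block directly with Block::erase() bypasses that bookkeeping, which is why the erase now goes through rewriter.eraseBlock. A minimal sketch of the surrounding region move, assuming op and newOp are scf.for ops as in this pattern (the helper name moveLoopBody is hypothetical):

    // Sketch only: replace newOp's auto-created body with op's original
    // region, using rewriter-tracked mutations throughout.
    #include "mlir/Dialect/SCF/IR/SCF.h"
    #include "mlir/Transforms/DialectConversion.h"

    static void moveLoopBody(mlir::scf::ForOp op, mlir::scf::ForOp newOp,
                             mlir::ConversionPatternRewriter &rewriter) {
      // Erase the placeholder entry block through the rewriter so the change
      // can be undone if the conversion later fails.
      rewriter.eraseBlock(newOp.getBody());
      // Splice the original region (with its converted block signature) into
      // the new op, then replace the old op with the new results.
      rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
                                  newOp.getRegion().end());
      rewriter.replaceOp(op, newOp.getResults());
    }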
@@ -1453,8 +1462,8 @@ struct XeGPUToVCPass : public ::imex::ConvertXeGPUToVCBase<XeGPUToVCPass> {
     target.addDynamicallyLegalDialect<mlir::scf::SCFDialect>(
         [&](mlir::Operation *op) { return isLegalXeGPUSCFOp(op); });
 
-    target.addIllegalOp<::mlir::vector::ShapeCastOp>();
-    target.addIllegalOp<::mlir::vector::ExtractStridedSliceOp>();
+    target.addIllegalOp<::mlir::vector::ShapeCastOp,
+                        ::mlir::vector::ExtractStridedSliceOp>();
 
     typeConverter.addConversion(
         [&](xegpu::TensorDescType type) -> ::mlir::Type {
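ConversionTarget::addIllegalOp is variadic, so the two calls fold into one; every op marked illegal must be rewritten by some pattern or the conversion fails. A short sketch of the idiom, under an assumed helper name (configureTarget):

    // Sketch only: mark several ops illegal in a single variadic call.
    #include "mlir/Dialect/Vector/IR/VectorOps.h"
    #include "mlir/Transforms/DialectConversion.h"

    static void configureTarget(mlir::ConversionTarget &target) {
      // Any op listed here must be converted away, otherwise
      // applyPartialConversion / applyFullConversion reports failure.
      target.addIllegalOp<mlir::vector::ShapeCastOp,
                          mlir::vector::ExtractStridedSliceOp>();
    }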
@@ -1507,11 +1516,12 @@ struct XeGPUToVCPass : public ::imex::ConvertXeGPUToVCBase<XeGPUToVCPass> {
                  CompilerHintToVCPattern, FenceToVCPattern,
                  UpdateNDOffsetToVCPattern, SCFYieldOpVCPattern>(
         patterns.getContext());
-    patterns.add<GatherScatterToRawSend<xegpu::LoadGatherOp>,
-                 GatherScatterToRawSend<xegpu::StoreScatterOp>, AtomicToLsc,
-                 VectorShapeCast, VectorExtract, VectorExtractStridedSlice,
-                 VectorShuffle, SCFForOpBlockVCPattern>(typeConverter,
-                                                        patterns.getContext());
+    patterns
+        .add<GatherScatterToRawSend<xegpu::LoadGatherOp>,
+             GatherScatterToRawSend<xegpu::StoreScatterOp>, AtomicToLsc,
+             VectorShapeCastVC, VectorExtractVC, VectorExtractStridedSliceVC,
+             VectorShuffleVC, SCFForOpBlockVCPattern>(typeConverter,
+                                                      patterns.getContext());
 
     if (this->useRawSend) {
       patterns.add<LoadStorePrefetchNdToRawSendPattern<xegpu::LoadNdOp>,