|
7 | 7 | //===----------------------------------------------------------------------===// |
8 | 8 |
|
9 | 9 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| 10 | +#include "mlir/Dialect/Index/IR/IndexDialect.h" |
10 | 11 | #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" |
11 | 12 | #include "mlir/Dialect/XeGPU/IR/XeGPU.h" |
12 | 13 | #include "mlir/Dialect/XeGPU/Transforms/Transforms.h" |
| 14 | +#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" |
13 | 15 | #include "mlir/Pass/Pass.h" |
14 | 16 | #include "mlir/Pass/PassManager.h" |
| 17 | +#include "mlir/Transforms/DialectConversion.h" |
15 | 18 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
16 | 19 |
|
17 | 20 | using namespace mlir; |
@@ -149,12 +152,116 @@ struct TestXeGPUUnrollingPatterns |
149 | 152 | } |
150 | 153 | }; |
151 | 154 |
|
// Redefine DEBUG_TYPE for the code that follows, so LLVM_DEBUG output emitted
// through the helpers below is tagged "test-xegpu-layout-interface".
#undef DEBUG_TYPE
#define DEBUG_TYPE "test-xegpu-layout-interface"
// DBGS(): llvm::dbgs() with a "[<DEBUG_TYPE>]: " prefix already streamed in.
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// LDBG(X): streams X followed by a newline into DBGS(), wrapped in LLVM_DEBUG.
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
| 159 | + |
| 160 | +class TestStepOpPattern : public OpConversionPattern<vector::StepOp> { |
| 161 | + using OpConversionPattern<vector::StepOp>::OpConversionPattern; |
| 162 | + |
| 163 | + LogicalResult |
| 164 | + matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor, |
| 165 | + ConversionPatternRewriter &rewriter) const override { |
| 166 | + |
| 167 | + auto layoutName = xegpu::getLayoutName(op->getResult(0)); |
| 168 | + auto sliceAttr = op->getAttrOfType<xegpu::SliceAttr>(layoutName); |
| 169 | + if (!sliceAttr || sliceAttr.getRank() != 1) |
| 170 | + return failure(); |
| 171 | + |
| 172 | + std::optional<SmallVector<int64_t>> sgShape = |
| 173 | + sliceAttr.getEffectiveSgData(); |
| 174 | + if (!sgShape) |
| 175 | + return failure(); |
| 176 | + |
| 177 | + Location loc = op.getLoc(); |
| 178 | + VectorType type = op.getResult().getType(); |
| 179 | + auto wgShape = type.getShape(); |
| 180 | + |
| 181 | + Value sgId = |
| 182 | + gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); |
| 183 | + auto maybeOffsets = sliceAttr.getOffsets(rewriter, loc, sgId, wgShape); |
| 184 | + if (failed(maybeOffsets)) |
| 185 | + return failure(); |
| 186 | + |
| 187 | + VectorType newTy = type.cloneWith(*sgShape, type.getElementType()); |
| 188 | + Value base = vector::StepOp::create(rewriter, loc, newTy); |
| 189 | + SmallVector<Value> newOps; |
| 190 | + for (auto offsets : *maybeOffsets) { |
| 191 | + Value bcast = |
| 192 | + vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]); |
| 193 | + Value add = arith::AddIOp::create(rewriter, loc, base, bcast); |
| 194 | + newOps.push_back(add); |
| 195 | + } |
| 196 | + rewriter.replaceOpWithMultiple(op, {newOps}); |
| 197 | + return success(); |
| 198 | + } |
| 199 | +}; |
| 200 | + |
| 201 | +struct TestXeGPULayoutInterface |
| 202 | + : public PassWrapper<TestXeGPULayoutInterface, |
| 203 | + OperationPass<gpu::GPUModuleOp>> { |
| 204 | + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPULayoutInterface) |
| 205 | + |
| 206 | + StringRef getArgument() const final { return "test-xegpu-layout-interface"; } |
| 207 | + |
| 208 | + StringRef getDescription() const final { |
| 209 | + return "Test the implementation of XeGPU Layout interfaces"; |
| 210 | + } |
| 211 | + |
| 212 | + void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
| 213 | + registry.insert<arith::ArithDialect>(); |
| 214 | + registry.insert<memref::MemRefDialect>(); |
| 215 | + registry.insert<xegpu::XeGPUDialect>(); |
| 216 | + registry.insert<vector::VectorDialect>(); |
| 217 | + registry.insert<index::IndexDialect>(); |
| 218 | + } |
| 219 | + |
| 220 | + TestXeGPULayoutInterface() = default; |
| 221 | + TestXeGPULayoutInterface(const TestXeGPULayoutInterface &pass) |
| 222 | + : PassWrapper(pass) {} |
| 223 | + |
| 224 | + void runOnOperation() override { |
| 225 | + MLIRContext *ctx = &getContext(); |
| 226 | + |
| 227 | + TypeConverter typeConverter; |
| 228 | + auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type, |
| 229 | + mlir::ValueRange inputs, |
| 230 | + mlir::Location loc) -> mlir::Value { |
| 231 | + return builder.create<UnrealizedConversionCastOp>(loc, type, inputs) |
| 232 | + .getResult(0); |
| 233 | + }; |
| 234 | + typeConverter.addSourceMaterialization(materializeCast); |
| 235 | + typeConverter.addTargetMaterialization(materializeCast); |
| 236 | + |
| 237 | + RewritePatternSet patterns(ctx); |
| 238 | + patterns.add<TestStepOpPattern>(typeConverter, ctx); |
| 239 | + |
| 240 | + ConversionTarget target(*ctx); |
| 241 | + auto isLegal = [&](xegpu::SliceAttr layout) -> bool { |
| 242 | + return !layout || !layout.isWgLayout(); |
| 243 | + }; |
| 244 | + |
| 245 | + target.addDynamicallyLegalOp<vector::StepOp>( |
| 246 | + [&](vector::StepOp op) -> bool { |
| 247 | + auto layoutName = xegpu::getLayoutName(op->getResult(0)); |
| 248 | + auto sliceAttr = op->getAttrOfType<xegpu::SliceAttr>(layoutName); |
| 249 | + return isLegal(sliceAttr); |
| 250 | + }); |
| 251 | + |
| 252 | + target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; }); |
| 253 | + |
| 254 | + (void)applyPartialConversion(getOperation(), target, std::move(patterns)); |
| 255 | + } |
| 256 | +}; |
| 257 | + |
152 | 258 | } // namespace |
153 | 259 |
|
154 | 260 | namespace mlir { |
155 | 261 | namespace test { |
156 | 262 | void registerTestXeGPULowerings() { |
157 | 263 | PassRegistration<TestXeGPUUnrollingPatterns>(); |
| 264 | + PassRegistration<TestXeGPULayoutInterface>(); |
158 | 265 | } |
159 | 266 | } // namespace test |
160 | 267 | } // namespace mlir |
0 commit comments