From a808983eac590930a17863152cb2a6f2aa855141 Mon Sep 17 00:00:00 2001
From: Chao Chen
Date: Tue, 3 Jun 2025 14:44:03 +0000
Subject: [PATCH 01/14] add scf support

---
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     |   3 +
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 106 ++++++++++++++++--
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   |   4 +-
 .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir |   2 +-
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   |   6 +-
 5 files changed, 104 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index f9327d63869c0..6fea10185402a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -26,6 +26,9 @@ class TensorDescType;
 
 namespace xegpu {
 
+/// Flatten a set of ValueRange into a single SmallVector
+SmallVector flattenValues(ArrayRef values);
+
 /// If tensor descriptor has a layout attribute it is used in SIMT mode.
 /// In this mode, the distributed vector shape is determined as follows:
 /// Definitions:
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 3bf76af674ba0..ad12cf34ca7b3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
@@ -29,6 +30,30 @@ using namespace mlir;
 
 namespace {
 
+static std::pair, int>
+computeTileShapeAndCount(ArrayRef shape, xegpu::LayoutAttr layout) {
+  // init count and subShape to the default value. If the LayoutAttr
+  // is not present, it will return a VectorType with original shape.
+  int count = 1;
+  SmallVector tileShape(shape);
+
+  if (layout) {
+    if (DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout()) {
+      auto sgLayout = llvm::to_vector_of(sgLayoutAttr.asArrayRef());
+      if (DenseI32ArrayAttr sgDataAttr = layout.getSgData())
+        tileShape = llvm::to_vector_of(sgDataAttr.asArrayRef());
+      else
+        tileShape = computeShapeRatio(shape, sgLayout).value_or(tileShape);
+      SmallVector distUnit =
+          computeElementwiseMul(sgLayout, tileShape);
+      for (size_t i = 0; i < distUnit.size(); ++i)
+        distUnit[i] = std::min(shape[i], distUnit[i]);
+      count = computeProduct(shape) / computeProduct(distUnit);
+    }
+  }
+  return std::make_pair(tileShape, count);
+}
+
 /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
 /// from a workgroup descriptor. It replaces the offsets and sizes with
 /// appropriate values for the subgroup.
@@ -266,15 +291,15 @@ struct WgToSgDpasOp : public OpConversionPattern { if (resultTy.getRank() != 2) return failure(); - auto originalLayout = - llvm::dyn_cast_or_null(op->getAttr("layout")); + auto originalLayout = xegpu::getLayoutAttr(op.getResult()); if (!originalLayout) return failure(); - SmallVector newDpasOps; size_t i = 0; + SmallVector newDpasOps; for (auto aVec : adaptor.getLhs()) { for (auto bVec : adaptor.getRhs()) { + llvm::SmallVector operands({aVec, bVec}); Value tmpC; if (op.getAcc()) { @@ -288,10 +313,9 @@ struct WgToSgDpasOp : public OpConversionPattern { llvm::cast(bVec.getType()).getShape(); VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]}, resultTy.getElementType()); - tmpC = rewriter.create( - loc, resTy, operands, - llvm::ArrayRef( - {"layout_result_0", originalLayout.dropSgLayoutAndData()})); + tmpC = rewriter.create(loc, resTy, operands); + xegpu::setLayoutAttr(cast(tmpC), originalLayout.dropSgLayoutAndData()); + newDpasOps.push_back(tmpC); } } @@ -314,14 +338,30 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern { } }; +struct UnrealizedConversionCastOpPattern + : public OpConversionPattern { + using OpConversionPattern< + mlir::UnrealizedConversionCastOp>::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op.getNumOperands() == 1 && op.getNumResults() == 1) { + rewriter.replaceOpWithMultiple(op, xegpu::flattenValues(adaptor.getInputs())); + return mlir::success(); + } + return mlir::failure(); + } +}; + } // namespace namespace mlir { namespace xegpu { void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) { patterns.add( - patterns.getContext()); + WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp, + UnrealizedConversionCastOpPattern>(patterns.getContext()); } } // namespace xegpu } // namespace mlir @@ -353,7 +393,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() { }; auto isLegal = [&](xegpu::LayoutAttr layout) -> bool { - return !layout || layout.getSgLayout() == nullptr; + return !layout || !layout.isWgLayout(); }; target.addDynamicallyLegalOp([=](xegpu::DpasOp op) -> bool { - auto layout = dyn_cast_or_null(op->getAttr("layout")); + auto layout = xegpu::getLayoutAttr(op.getResult()); return isLegal(layout); }); + target.addIllegalOp(); + target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); + TypeConverter converter; + converter.addConversion([&](Type type) -> Type { return type; }); + converter.addConversion( + [&](RankedTensorType type, + SmallVectorImpl &result) -> std::optional { + Type elemTy = type.getElementType(); + ArrayRef shape = type.getShape(); + + int count; + SmallVector subShape; + std::tie(subShape, count) = computeTileShapeAndCount( + shape, dyn_cast(type.getEncoding())); + + auto newTy = VectorType::get(subShape, elemTy); + result.append(count, newTy); + return success(); + }); + + converter.addConversion( + [&](xegpu::TensorDescType type, + SmallVectorImpl &result) -> std::optional { + Type elemTy = type.getElementType(); + ArrayRef shape = type.getShape(); + + // init count and newTy to the default value. If the layout + // attribute is not present, it will return the original type. 
+ int count; + SmallVector subShape; + xegpu::LayoutAttr layout = type.getLayoutAttr(); + std::tie(subShape, count) = computeTileShapeAndCount(shape, layout); + + if (layout) + layout = layout.dropSgLayoutAndData(); + + auto newTy = xegpu::TensorDescType::get( + type.getContext(), subShape, elemTy, type.getEncoding(), layout); + result.append(count, newTy); + return success(); + }); + + xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(), converter); + xegpu::populateXeGPUWgToSgDistributePatterns(patterns); if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp index dcaf4e85a82c5..ed48e3cc13117 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp @@ -27,7 +27,7 @@ using namespace mlir; /// convert ArrayRef into SmallVector -static SmallVector flattenValues(ArrayRef values) { +SmallVector xegpu::flattenValues(ArrayRef values) { SmallVector result; for (const auto &vals : values) llvm::append_range(result, vals); @@ -342,7 +342,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType( } if (isa(inputTy) && isa(outputTy)) { - SmallVector values = flattenValues(adaptor.getInputs()); + SmallVector values = xegpu::flattenValues(adaptor.getInputs()); auto newOp = rewriter.create( op.getLoc(), outputTy, values); rewriter.replaceOp(op, newOp); diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir index bee026eb2084d..fa1e5fbae0954 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir @@ -85,7 +85,7 @@ gpu.module @test_round_robin_assignment { %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32> -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout> %dpas = xegpu.dpas %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32> gpu.return } diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir index 7e89ada934071..22374f74b133e 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir @@ -5,7 +5,7 @@ gpu.module @test_1_1_assignment { // CHECK-LABEL: test_create_nd_tdesc // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) { + gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) { // CHECK: %[[SGID:.*]] = gpu.subgroup_id // CHECK: %[[C12:.*]] = arith.constant 12 : index // CHECK: %[[C4:.*]] = arith.constant 4 : index @@ -108,7 +108,7 @@ gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { : !xegpu.tensor_desc<32x24xf32, #xegpu.layout> -> vector<32x24xf32> %dpas = xegpu.dpas %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32> gpu.return } @@ -142,7 +142,7 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { : !xegpu.tensor_desc<32x24xf32, #xegpu.layout> -> vector<32x24xf32> %dpas = xegpu.dpas %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32> gpu.return } From bc3b74b7df81dd3eddfdb4c243985f49771ef8ca Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Tue, 3 Jun 2025 14:54:26 +0000 Subject: [PATCH 02/14] fix format --- 
mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index ad12cf34ca7b3..f09e77273e9e6 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -314,7 +314,8 @@ struct WgToSgDpasOp : public OpConversionPattern { VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]}, resultTy.getElementType()); tmpC = rewriter.create(loc, resTy, operands); - xegpu::setLayoutAttr(cast(tmpC), originalLayout.dropSgLayoutAndData()); + xegpu::setLayoutAttr(cast(tmpC), + originalLayout.dropSgLayoutAndData()); newDpasOps.push_back(tmpC); } @@ -347,7 +348,8 @@ struct UnrealizedConversionCastOpPattern matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (op.getNumOperands() == 1 && op.getNumResults() == 1) { - rewriter.replaceOpWithMultiple(op, xegpu::flattenValues(adaptor.getInputs())); + rewriter.replaceOpWithMultiple(op, + xegpu::flattenValues(adaptor.getInputs())); return mlir::success(); } return mlir::failure(); From 392dfb072db6c273d255396247ac5dd3bf526ed9 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Wed, 4 Jun 2025 17:16:06 +0000 Subject: [PATCH 03/14] refine UnrealizedConversionCastOpPattern --- .../Transforms/XeGPUWgToSgDistribute.cpp | 174 +++++++++++------- mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 2 +- .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 24 +++ mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 64 +++++++ 4 files changed, 194 insertions(+), 70 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index f09e77273e9e6..5bf8d7975a131 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -29,31 +29,6 @@ namespace xegpu { using namespace mlir; namespace { - -static std::pair, int> -computeTileShapeAndCount(ArrayRef shape, xegpu::LayoutAttr layout) { - // init count and subShape to the default value. If the LayoutAttr - // is not present, it will return a VectorType with original shape. - int count = 1; - SmallVector tileShape(shape); - - if (layout) { - if (DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout()) { - auto sgLayout = llvm::to_vector_of(sgLayoutAttr.asArrayRef()); - if (DenseI32ArrayAttr sgDataAttr = layout.getSgData()) - tileShape = llvm::to_vector_of(sgDataAttr.asArrayRef()); - else - tileShape = computeShapeRatio(shape, sgLayout).value_or(tileShape); - SmallVector distUnit = - computeElementwiseMul(sgLayout, tileShape); - for (size_t i = 0; i < distUnit.size(); ++i) - distUnit[i] = std::min(shape[i], distUnit[i]); - count = computeProduct(shape) / computeProduct(distUnit); - } - } - return std::make_pair(tileShape, count); -} - /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor /// from a workgroup descriptor. It replaces the offsets and sizes with /// appropriate values for the subgroup. @@ -339,6 +314,15 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern { } }; +// Handles UnrealizedConversionCastOp generated during +// SCFStructuralTypeConversions (step 1). This op may appear as either a +// target or source materialization for Vector or TensorDesc, e.g.: +// 1. 
unrealized_conversion_cast %1 : tensor_desc<16xf16> to +// tensor_desc<128xf16, ...> +// 2. unrealized_conversion_cast %1 : vector<256xf32> to vector<16xf32>, ... +// 3. unrealized_conversion_cast %1 : vector<16xf32>, ... to vector<256xf32> +// In all cases, the pattern simply forwards the inputs to the outputs with +// one-to-one or one-to-n patterns. struct UnrealizedConversionCastOpPattern : public OpConversionPattern { using OpConversionPattern< @@ -347,11 +331,40 @@ struct UnrealizedConversionCastOpPattern mlir::LogicalResult matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (op.getNumOperands() == 1 && op.getNumResults() == 1) { - rewriter.replaceOpWithMultiple(op, - xegpu::flattenValues(adaptor.getInputs())); - return mlir::success(); + SmallVector inputs = xegpu::flattenValues(adaptor.getInputs()); + + // Handles the case where cast %1 : tensor_desc<16xf16> to + // tensor_desc<128xf16, ...> The input values provided by the adaptor should + // already be distributed. + if (op.getNumOperands() == 1 && op.getNumResults() == 1 && + isa(op->getOperand(0).getType()) && + isa(op->getResult(0).getType())) { + rewriter.replaceOp(op, inputs); + return success(); } + + // Handles the case where cast %1 : vector<256xf32> to vector<16xf32>, ... + // the input values provided by the adaptor should already be distributed, + // and their types should correspond exactly to the result types of the + // operation. + if (op.getNumOperands() == 1 && + llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) { + rewriter.replaceOp(op, inputs); + return success(); + } + + // Handles the case where cast %1 : vector<16xf32>, ... to vector<256xf32>. + // All input values must have the same vector type, and their shape must be + // evenly divisible by the output vector's shape. 
+ auto inputTy = dyn_cast(inputs[0].getType()); + auto outputTy = dyn_cast(op->getOpResult(0).getType()); + if (op.getNumResults() == 1 && inputTy && outputTy && + llvm::all_equal(ValueRange(inputs).getTypes()) && + computeShapeRatio(outputTy.getShape(), inputTy.getShape())) { + rewriter.replaceOpWithMultiple(op, {inputs}); + return success(); + } + return mlir::failure(); } }; @@ -376,45 +389,29 @@ struct XeGPUWgToSgDistributePass } // namespace void XeGPUWgToSgDistributePass::runOnOperation() { - MLIRContext *ctx = &getContext(); - RewritePatternSet patterns(ctx); - ConversionTarget target(*ctx); + auto getSgShapeAndCount = [](ArrayRef shape, + xegpu::LayoutAttr layout) { + int count = 1; + SmallVector sgShape(shape); - auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType { - if (auto createOp = dyn_cast(op)) - return createOp.getType(); - if (auto loadOp = dyn_cast(op)) - return loadOp.getTensorDescType(); - if (auto storeOp = dyn_cast(op)) - return storeOp.getTensorDescType(); - if (auto updateOp = dyn_cast(op)) - return updateOp.getType(); - if (auto prefetchOp = dyn_cast(op)) - return prefetchOp.getTensorDescType(); - return xegpu::TensorDescType(); - }; - - auto isLegal = [&](xegpu::LayoutAttr layout) -> bool { - return !layout || !layout.isWgLayout(); + if (layout && layout.isWgLayout()) { + DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout(); + auto sgLayout = llvm::to_vector_of(sgLayoutAttr.asArrayRef()); + if (DenseI32ArrayAttr sgDataAttr = layout.getSgData()) + sgShape = llvm::to_vector_of(sgDataAttr.asArrayRef()); + else + sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape); + SmallVector distUnit = computeElementwiseMul(sgLayout, sgShape); + // Clamp distUnit to the original shape to handle cases where data is + // shared among subgroups, which may cause distUnit to exceed the original + // shape. + for (size_t i = 0; i < distUnit.size(); ++i) + distUnit[i] = std::min(shape[i], distUnit[i]); + count = computeProduct(shape) / computeProduct(distUnit); + } + return std::make_pair(sgShape, count); }; - target.addDynamicallyLegalOp([=](Operation *op) -> bool { - auto tdescTy = getTensorDescType(op); - auto layout = dyn_cast_or_null(tdescTy.getLayout()); - return isLegal(layout); - }); - - target.addDynamicallyLegalOp([=](xegpu::DpasOp op) -> bool { - auto layout = xegpu::getLayoutAttr(op.getResult()); - return isLegal(layout); - }); - - target.addIllegalOp(); - - target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - TypeConverter converter; converter.addConversion([&](Type type) -> Type { return type; }); converter.addConversion( @@ -425,7 +422,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() { int count; SmallVector subShape; - std::tie(subShape, count) = computeTileShapeAndCount( + std::tie(subShape, count) = getSgShapeAndCount( shape, dyn_cast(type.getEncoding())); auto newTy = VectorType::get(subShape, elemTy); @@ -439,12 +436,10 @@ void XeGPUWgToSgDistributePass::runOnOperation() { Type elemTy = type.getElementType(); ArrayRef shape = type.getShape(); - // init count and newTy to the default value. If the layout - // attribute is not present, it will return the original type. 
int count; SmallVector subShape; xegpu::LayoutAttr layout = type.getLayoutAttr(); - std::tie(subShape, count) = computeTileShapeAndCount(shape, layout); + std::tie(subShape, count) = getSgShapeAndCount(shape, layout); if (layout) layout = layout.dropSgLayoutAndData(); @@ -455,8 +450,49 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return success(); }); + // step1: perform SCFStructuralTypeConversions on SCF ops xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(), converter); + MLIRContext *ctx = &getContext(); + RewritePatternSet patterns(ctx); + ConversionTarget target(*ctx); + + auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType { + if (auto createOp = dyn_cast(op)) + return createOp.getType(); + if (auto loadOp = dyn_cast(op)) + return loadOp.getTensorDescType(); + if (auto storeOp = dyn_cast(op)) + return storeOp.getTensorDescType(); + if (auto updateOp = dyn_cast(op)) + return updateOp.getType(); + if (auto prefetchOp = dyn_cast(op)) + return prefetchOp.getTensorDescType(); + return xegpu::TensorDescType(); + }; + + auto isLegal = [&](xegpu::LayoutAttr layout) -> bool { + return !layout || !layout.isWgLayout(); + }; + + target.addDynamicallyLegalOp([=](Operation *op) -> bool { + auto tdescTy = getTensorDescType(op); + auto layout = dyn_cast_if_present(tdescTy.getLayout()); + return isLegal(layout); + }); + + target.addDynamicallyLegalOp([=](xegpu::DpasOp op) -> bool { + auto layout = xegpu::getLayoutAttr(op.getResult()); + return isLegal(layout); + }); + + target.addIllegalOp(); + + target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); + + // step2: Perform for workgroup to subgroup distribution for rest ops xegpu::populateXeGPUWgToSgDistributePatterns(patterns); if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp index ed48e3cc13117..6b85a66a8bd36 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp @@ -271,7 +271,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType( auto resultTy = dyn_cast(result.getType()); // Only look at ops casting from VectorType to RankedTensorType - if (!isa(inputTy) || !isa(resultTy)) + if (!inputTy || !resultTy) return WalkResult::skip(); xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input); diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir index fa1e5fbae0954..ff86e65300bb8 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir @@ -102,4 +102,28 @@ gpu.module @test_round_robin_assignment { : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> gpu.return } + + gpu.func @test_scf_while_and_condition(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) { + %c1_i32 = arith.constant 1 : i32 + %c10_i32 = arith.constant 10 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout> + %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<256xf32, #xegpu.layout> -> vector<256xf32> + %2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout> + //CHECK: scf.while ({{.*}}) : (vector<16xf32>, vector<16xf32>, i32) -> (vector<16xf32>, vector<16xf32>, i32) + %3:2 = scf.while (%arg2 = %1, %arg3 = %c0_i32) : (vector<256xf32>, i32) -> (vector<256xf32>, i32) { + %4 = arith.cmpi slt, %arg3, %c10_i32 : i32 + 
//CHECK: scf.condition{{.*}} : vector<16xf32>, vector<16xf32>, i32 + scf.condition(%4) %arg2, %arg3 : vector<256xf32>, i32 + } do { + // CHECK: ([[arg2:%.+]]: vector<16xf32>, [[arg3:%.+]]: vector<16xf32>, [[arg4:%.+]]: i32) + ^bb0(%arg2: vector<256xf32>, %arg3: i32): + xegpu.store_nd %arg2, %2 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout> + %4 = arith.addi %arg3, %c1_i32 : i32 + %5 = xegpu.update_nd_offset %0, [256] : !xegpu.tensor_desc<256xf32, #xegpu.layout> + %6 = xegpu.load_nd %5 : !xegpu.tensor_desc<256xf32, #xegpu.layout> -> vector<256xf32> + scf.yield %6, %4 : vector<256xf32>, i32 + } + gpu.return + } } diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir index 22374f74b133e..d016d3a30a339 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir @@ -169,4 +169,68 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32> gpu.return } + + gpu.func @test_scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) { + //CHECK: [[c0:%.+]] = arith.constant 0 : index + //CHECK: [[c128:%.+]] = arith.constant 128 : index + //CHECK: [[c1024:%.+]] = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c128 : index + %1 = arith.muli %block_id_y, %c128 : index + %2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout> + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<128x128xf32, #xegpu.layout> -> vector<128x128xf32> + %4 = xegpu.create_nd_tdesc %arg0[%0, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout> + %5 = xegpu.create_nd_tdesc %arg1[%c0, %1] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout> + + //CHECK: [[scf:%.+]]:3 = scf.for [[arg3:%.+]] = [[c0]] to [[c1024]] step [[c128]] iter_args([[arg4:%.+]] = {{.*}}, [[arg5:%.+]] = {{.*}}, [[arg6:%.+]] = {{.*}}) -> (!xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>) + //CHECK: [[a:%.+]] = xegpu.load_nd [[arg4]] : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16> + //CHECK: [[b:%.+]] = xegpu.load_nd [[arg5]] : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16> + //CHECK: [[c:%.+]] = xegpu.dpas [[a]], [[b]], [[arg6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32> + //CHECK: [[at:%.+]] = xegpu.update_nd_offset [[arg4]], [[[c0]], [[c128]]] : !xegpu.tensor_desc<16x128xf16> + //CHECK: [[bt:%.+]] = xegpu.update_nd_offset [[arg5]], [[[c128]], [[c0]]] : !xegpu.tensor_desc<128x16xf16> + //CHECK: scf.yield [[at]], [[bt]], [[c]] : !xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32> + %6:3 = scf.for %arg3 = %c0 to %c1024 step %c128 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xegpu.tensor_desc<128x128xf16, #xegpu.layout>, !xegpu.tensor_desc<128x128xf16, #xegpu.layout>, vector<128x128xf32>) { + %8 = xegpu.load_nd %arg4 : !xegpu.tensor_desc<128x128xf16, #xegpu.layout> -> vector<128x128xf16> + %9 = xegpu.load_nd %arg5 : !xegpu.tensor_desc<128x128xf16, #xegpu.layout> -> vector<128x128xf16> + %10 = xegpu.dpas %8, %9, %arg6 {layout_result_0 = #xegpu.layout} : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32> + %11 = 
xegpu.update_nd_offset %arg4, [%c0, %c128] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout> + %12 = xegpu.update_nd_offset %arg5, [%c128, %c0] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout> + scf.yield %11, %12, %10 : !xegpu.tensor_desc<128x128xf16, #xegpu.layout>, !xegpu.tensor_desc<128x128xf16, #xegpu.layout>, vector<128x128xf32> + } + %7 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout> + xegpu.store_nd %6#2, %7 : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout> + gpu.return + } + + + gpu.func @test_scf_while_and_condition(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) { + %c1_i32 = arith.constant 1 : i32 + %c10_i32 = arith.constant 10 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout> + %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<256xf32, #xegpu.layout> -> vector<256xf32> + %2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout> + + // CHECK: scf.while {{.*}} : (vector<16xf32>, i32) -> (vector<16xf32>, i32) + %3:2 = scf.while (%arg2 = %1, %arg3 = %c0_i32) : (vector<256xf32>, i32) -> (vector<256xf32>, i32) { + %4 = arith.cmpi slt, %arg3, %c10_i32 : i32 + // CHECK: scf.condition{{.*}} : vector<16xf32>, i32 + scf.condition(%4) %arg2, %arg3 : vector<256xf32>, i32 + } do { + // CHECK: ([[arg2:%.+]]: vector<16xf32>, [[arg3:%.+]]: i32) + ^bb0(%arg2: vector<256xf32>, %arg3: i32): + xegpu.store_nd %arg2, %2 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout> + %4 = arith.addi %arg3, %c1_i32 : i32 + %5 = xegpu.update_nd_offset %0, [256] : !xegpu.tensor_desc<256xf32, #xegpu.layout> + %6 = xegpu.load_nd %5 : !xegpu.tensor_desc<256xf32, #xegpu.layout> -> vector<256xf32> + scf.yield %6, %4 : vector<256xf32>, i32 + } + gpu.return + } + + } From 449a2edea22369ae160a38dcb15cfe3b3bf4d2e7 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Wed, 4 Jun 2025 17:27:54 +0000 Subject: [PATCH 04/14] refactor --- .../Transforms/XeGPUWgToSgDistribute.cpp | 60 ++++++++----------- 1 file changed, 25 insertions(+), 35 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 5bf8d7975a131..e29da76898c58 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -29,6 +29,30 @@ namespace xegpu { using namespace mlir; namespace { + +static std::pair, int> +getSgShapeAndCount(ArrayRef shape, xegpu::LayoutAttr layout) { + int count = 1; + SmallVector sgShape(shape); + + if (layout && layout.isWgLayout()) { + DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout(); + auto sgLayout = llvm::to_vector_of(sgLayoutAttr.asArrayRef()); + if (DenseI32ArrayAttr sgDataAttr = layout.getSgData()) + sgShape = llvm::to_vector_of(sgDataAttr.asArrayRef()); + else + sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape); + SmallVector distUnit = computeElementwiseMul(sgLayout, sgShape); + // Clamp distUnit to the original shape to handle cases where data is + // shared among subgroups, which may cause distUnit to exceed the original + // shape. 
+ for (size_t i = 0; i < distUnit.size(); ++i) + distUnit[i] = std::min(shape[i], distUnit[i]); + count = computeProduct(shape) / computeProduct(distUnit); + } + return std::make_pair(sgShape, count); +}; + /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor /// from a workgroup descriptor. It replaces the offsets and sizes with /// appropriate values for the subgroup. @@ -129,18 +153,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern { return rewriter.notifyMatchFailure( op, "sgLayout attribute is required in layout"); - SmallVector sgShape; - if (auto sgDataAttr = layout.getSgData()) { - sgShape = llvm::to_vector_of(sgDataAttr.asArrayRef()); - } else { - assert(wgShape.size() == sgLayout.size() && - "sgLayout and wgShape must have the same rank"); - sgShape.reserve(wgShape.size()); - for (size_t i = 0; i < wgShape.size(); ++i) { - assert(sgLayout[i] != 0 && "sgLayout elements must be non-zero"); - sgShape.push_back(wgShape[i] / sgLayout[i]); - } - } + SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; // TODO : Handle order attribute // Get the subgroup ID @@ -389,29 +402,6 @@ struct XeGPUWgToSgDistributePass } // namespace void XeGPUWgToSgDistributePass::runOnOperation() { - auto getSgShapeAndCount = [](ArrayRef shape, - xegpu::LayoutAttr layout) { - int count = 1; - SmallVector sgShape(shape); - - if (layout && layout.isWgLayout()) { - DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout(); - auto sgLayout = llvm::to_vector_of(sgLayoutAttr.asArrayRef()); - if (DenseI32ArrayAttr sgDataAttr = layout.getSgData()) - sgShape = llvm::to_vector_of(sgDataAttr.asArrayRef()); - else - sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape); - SmallVector distUnit = computeElementwiseMul(sgLayout, sgShape); - // Clamp distUnit to the original shape to handle cases where data is - // shared among subgroups, which may cause distUnit to exceed the original - // shape. - for (size_t i = 0; i < distUnit.size(); ++i) - distUnit[i] = std::min(shape[i], distUnit[i]); - count = computeProduct(shape) / computeProduct(distUnit); - } - return std::make_pair(sgShape, count); - }; - TypeConverter converter; converter.addConversion([&](Type type) -> Type { return type; }); converter.addConversion( From b2032a4a8ab66226738a0bf32a6e8f63dda2f39c Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Wed, 4 Jun 2025 18:00:31 +0000 Subject: [PATCH 05/14] format --- mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index e29da76898c58..3bec5724c764d 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -51,7 +51,7 @@ getSgShapeAndCount(ArrayRef shape, xegpu::LayoutAttr layout) { count = computeProduct(shape) / computeProduct(distUnit); } return std::make_pair(sgShape, count); -}; +} /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor /// from a workgroup descriptor. 
It replaces the offsets and sizes with From bf37af142849d7c92256d659f3e8a7d41c8f05d0 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Wed, 4 Jun 2025 20:44:50 +0000 Subject: [PATCH 06/14] add one more unit tests --- .../Transforms/XeGPUWgToSgDistribute.cpp | 40 ++++++++----------- .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 21 ++++++++++ 2 files changed, 38 insertions(+), 23 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 3bec5724c764d..11d723ec0aabe 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -329,13 +329,13 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern { // Handles UnrealizedConversionCastOp generated during // SCFStructuralTypeConversions (step 1). This op may appear as either a -// target or source materialization for Vector or TensorDesc, e.g.: -// 1. unrealized_conversion_cast %1 : tensor_desc<16xf16> to -// tensor_desc<128xf16, ...> -// 2. unrealized_conversion_cast %1 : vector<256xf32> to vector<16xf32>, ... -// 3. unrealized_conversion_cast %1 : vector<16xf32>, ... to vector<256xf32> -// In all cases, the pattern simply forwards the inputs to the outputs with -// one-to-one or one-to-n patterns. +// target or source materialization for Vector or TensorDesc values, e.g.: +// 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ... +// 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32> +// it could be either 1:1, 1:N or N:1 cast. In all cases, the pattern +// simply forwards the inputs to the outputs using 1:1 or 1:N interface. +// TODO: remove it when context-aware type converter is ready. +// It is safe only when input codes don't contain UnrealizedConversionCastOp. struct UnrealizedConversionCastOpPattern : public OpConversionPattern { using OpConversionPattern< @@ -346,22 +346,20 @@ struct UnrealizedConversionCastOpPattern ConversionPatternRewriter &rewriter) const override { SmallVector inputs = xegpu::flattenValues(adaptor.getInputs()); - // Handles the case where cast %1 : tensor_desc<16xf16> to - // tensor_desc<128xf16, ...> The input values provided by the adaptor should - // already be distributed. - if (op.getNumOperands() == 1 && op.getNumResults() == 1 && - isa(op->getOperand(0).getType()) && - isa(op->getResult(0).getType())) { - rewriter.replaceOp(op, inputs); - return success(); - } + auto inputTy = inputs[0].getType(); + auto outputTy = op->getOpResult(0).getType(); + + if (!llvm::all_equal(op->getResultTypes()) || + !llvm::all_equal(ValueRange(inputs).getTypes()) || + !isa(inputTy) || + !isa(outputTy)) + return failure(); // Handles the case where cast %1 : vector<256xf32> to vector<16xf32>, ... // the input values provided by the adaptor should already be distributed, // and their types should correspond exactly to the result types of the // operation. - if (op.getNumOperands() == 1 && - llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) { + if (op.getNumOperands() == 1) { rewriter.replaceOp(op, inputs); return success(); } @@ -369,11 +367,7 @@ struct UnrealizedConversionCastOpPattern // Handles the case where cast %1 : vector<16xf32>, ... to vector<256xf32>. // All input values must have the same vector type, and their shape must be // evenly divisible by the output vector's shape. 
- auto inputTy = dyn_cast(inputs[0].getType()); - auto outputTy = dyn_cast(op->getOpResult(0).getType()); - if (op.getNumResults() == 1 && inputTy && outputTy && - llvm::all_equal(ValueRange(inputs).getTypes()) && - computeShapeRatio(outputTy.getShape(), inputTy.getShape())) { + if (op.getNumResults() == 1) { rewriter.replaceOpWithMultiple(op, {inputs}); return success(); } diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir index ff86e65300bb8..64e5388a2c30f 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir @@ -103,6 +103,27 @@ gpu.module @test_round_robin_assignment { gpu.return } + gpu.func @test_scf_for(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) { + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout> + %1 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout> + // CHECK-LABEL: scf.for + // CHECK-SAME: (!xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>) + %2:2 = scf.for %arg2 = %c0 to %c1024 step %c256 iter_args(%arg3 = %0, %arg4 = %1) + -> (!xegpu.tensor_desc<256xf32, #xegpu.layout>, !xegpu.tensor_desc<256xf32, #xegpu.layout>) { + %3 = xegpu.load_nd %0 : !xegpu.tensor_desc<256xf32, #xegpu.layout> -> vector<256xf32> + xegpu.store_nd %3, %arg3 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout> + %4 = xegpu.update_nd_offset %arg3, [256] : !xegpu.tensor_desc<256xf32, #xegpu.layout> + %5 = xegpu.update_nd_offset %arg4, [256] : !xegpu.tensor_desc<256xf32, #xegpu.layout> + scf.yield %4, %5 : !xegpu.tensor_desc<256xf32, #xegpu.layout>, !xegpu.tensor_desc<256xf32, #xegpu.layout> + } + gpu.return + } + gpu.func @test_scf_while_and_condition(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) { %c1_i32 = arith.constant 1 : i32 %c10_i32 = arith.constant 10 : i32 From 605eee03bc029648c5812b229af3f6f1802e9fc1 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Wed, 4 Jun 2025 21:27:06 +0000 Subject: [PATCH 07/14] refine --- .../Transforms/XeGPUWgToSgDistribute.cpp | 47 +++++++++++-------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 11d723ec0aabe..efdeff2615198 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/Dialect/XeGPU/Transforms/Transforms.h" @@ -329,13 +330,12 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern { // Handles UnrealizedConversionCastOp generated during // SCFStructuralTypeConversions (step 1). This op may appear as either a -// target or source materialization for Vector or TensorDesc values, e.g.: +// target or source materialization for Vector values, e.g.: // 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ... // 2. unrealized_cast %1 : vector<16xf32>, ... 
to vector<256xf32> -// it could be either 1:1, 1:N or N:1 cast. In all cases, the pattern +// it could be either 1:N or N:1 cast. In both cases, the pattern // simply forwards the inputs to the outputs using 1:1 or 1:N interface. // TODO: remove it when context-aware type converter is ready. -// It is safe only when input codes don't contain UnrealizedConversionCastOp. struct UnrealizedConversionCastOpPattern : public OpConversionPattern { using OpConversionPattern< @@ -346,20 +346,19 @@ struct UnrealizedConversionCastOpPattern ConversionPatternRewriter &rewriter) const override { SmallVector inputs = xegpu::flattenValues(adaptor.getInputs()); - auto inputTy = inputs[0].getType(); - auto outputTy = op->getOpResult(0).getType(); + auto inputTy = dyn_cast(inputs[0].getType()); + auto outputTy = dyn_cast(op->getOpResult(0).getType()); - if (!llvm::all_equal(op->getResultTypes()) || - !llvm::all_equal(ValueRange(inputs).getTypes()) || - !isa(inputTy) || - !isa(outputTy)) + if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) || + !llvm::all_equal(ValueRange(inputs).getTypes())) return failure(); // Handles the case where cast %1 : vector<256xf32> to vector<16xf32>, ... // the input values provided by the adaptor should already be distributed, // and their types should correspond exactly to the result types of the // operation. - if (op.getNumOperands() == 1) { + if (op.getNumOperands() == 1 && + llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) { rewriter.replaceOp(op, inputs); return success(); } @@ -367,7 +366,10 @@ struct UnrealizedConversionCastOpPattern // Handles the case where cast %1 : vector<16xf32>, ... to vector<256xf32>. // All input values must have the same vector type, and their shape must be // evenly divisible by the output vector's shape. - if (op.getNumResults() == 1) { + // TODO: it is not safe to do such forward, since such N:1 cast could be + // from others + if (op.getNumResults() == 1 && + computeShapeRatio(outputTy.getShape(), inputTy.getShape())) { rewriter.replaceOpWithMultiple(op, {inputs}); return success(); } @@ -396,6 +398,7 @@ struct XeGPUWgToSgDistributePass } // namespace void XeGPUWgToSgDistributePass::runOnOperation() { + TypeConverter converter; converter.addConversion([&](Type type) -> Type { return type; }); converter.addConversion( @@ -414,6 +417,16 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return success(); }); + // Step 1: Apply SCFStructuralTypeConversions to SCF operations with + // VectorType operands. This first converts such operands to RankedTensorType, + // propagates the layout attribute into the encoding attribute, and finally + // converts the RankedTensorType to VectorType based on the encoding. 
+ xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(), converter); + + MLIRContext *ctx = &getContext(); + RewritePatternSet patterns(ctx); + ConversionTarget target(*ctx); + converter.addConversion( [&](xegpu::TensorDescType type, SmallVectorImpl &result) -> std::optional { @@ -434,13 +447,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return success(); }); - // step1: perform SCFStructuralTypeConversions on SCF ops - xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(), converter); - - MLIRContext *ctx = &getContext(); - RewritePatternSet patterns(ctx); - ConversionTarget target(*ctx); - auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType { if (auto createOp = dyn_cast(op)) return createOp.getType(); @@ -476,7 +482,10 @@ void XeGPUWgToSgDistributePass::runOnOperation() { target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - // step2: Perform for workgroup to subgroup distribution for rest ops + // Step 2: Perform workgroup to subgroup distribution for TensorDesc values, + // as well as XeGPU, Arith, and Vector operations. + scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns, + target); xegpu::populateXeGPUWgToSgDistributePatterns(patterns); if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) From 87934c4461cee6274c22475be2fcad81cf5efee6 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Thu, 5 Jun 2025 14:30:36 +0000 Subject: [PATCH 08/14] refine --- .../Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index efdeff2615198..8f400e525d480 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -398,6 +398,11 @@ struct XeGPUWgToSgDistributePass } // namespace void XeGPUWgToSgDistributePass::runOnOperation() { + // Track existing UnrealizedConversionCastOps + SmallVector existingCastOps; + getOperation()->walk([&](UnrealizedConversionCastOp castOp) { + existingCastOps.push_back(castOp.getOperation()); + }); TypeConverter converter; converter.addConversion([&](Type type) -> Type { return type; }); @@ -478,7 +483,10 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return isLegal(layout); }); - target.addIllegalOp(); + target.addDynamicallyLegalOp( + [=](UnrealizedConversionCastOp op) { + return llvm::is_contained(existingCastOps, op.getOperation()); + }); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); From 1ca7b30540cf007bae260acd59d29bd6cede061f Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Thu, 5 Jun 2025 20:42:37 +0000 Subject: [PATCH 09/14] add unit tests for scf.if --- mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 59 ++++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir index d016d3a30a339..c6103408cb467 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir @@ -206,7 +206,6 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { gpu.return } - gpu.func @test_scf_while_and_condition(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) { %c1_i32 = arith.constant 1 : i32 %c10_i32 = arith.constant 10 : i32 @@ -232,5 +231,63 @@ gpu.func @test_dpas_no_sg_data(%a: 
memref<24x32xf32>, %b: memref<32x24xf32>) { gpu.return } + gpu.func @test_scf_if(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) { + %c10 = arith.constant 10 : index + %id = gpu.subgroup_id : index + + %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout> + %1 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout> + + %4 = arith.cmpi eq, %id, %c10 : index + // CHECK-LABEL: scf.if + // CHECK-SAME: (vector<16xf32>) + %5 = scf.if %4 -> (vector<256xf32>) { + // CHECK-LABEL: xegpu.load_nd + // CHECK-SAME: !xegpu.tensor_desc<16xf32> -> vector<16xf32> + %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<256xf32, #xegpu.layout> -> vector<256xf32> + // CHECK-LABEL: scf.yield + // CHECK-SAME: vector<16xf32> + scf.yield %2 : vector<256xf32> + } else { + // CHECK-LABEL: xegpu.load_nd + // CHECK-SAME: !xegpu.tensor_desc<16xf32> -> vector<16xf32> + %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<256xf32, #xegpu.layout> -> vector<256xf32> + // CHECK-LABEL: scf.yield + // CHECK-SAME: vector<16xf32> + scf.yield %3 : vector<256xf32> + } {layout_result_0 = #xegpu.layout} + xegpu.store_nd %5, %0 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout> + gpu.return + } + + gpu.func @test_scf_if_tensor_desc(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) { + %c10 = arith.constant 10 : index + %id = gpu.subgroup_id : index + + %t = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout> + %d = xegpu.load_nd %t : !xegpu.tensor_desc<256xf32, #xegpu.layout> -> vector<256xf32> + + %0 = arith.cmpi eq, %id, %c10 : index + // CHECK-LABEL: scf.if + // CHECK-SAME: !xegpu.tensor_desc<16xf32> + %1 = scf.if %0 -> (!xegpu.tensor_desc<256xf32, #xegpu.layout>) { + // CHECK-LABEL: xegpu.create_nd_tdesc + // CHECK-SAME: memref<1024xf32> -> !xegpu.tensor_desc<16xf32> + %2 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout> + // CHECK-LABEL: scf.yield + // CHECK-SAME: !xegpu.tensor_desc<16xf32> + scf.yield %2 : !xegpu.tensor_desc<256xf32, #xegpu.layout> + } else { + // CHECK-LABEL: xegpu.create_nd_tdesc + // CHECK-SAME: memref<1024xf32> -> !xegpu.tensor_desc<16xf32> + %3 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout> + // CHECK-LABEL: scf.yield + // CHECK-SAME: !xegpu.tensor_desc<16xf32> + scf.yield %3 : !xegpu.tensor_desc<256xf32, #xegpu.layout> + } + xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout> + gpu.return + } + } From e650862dd4ecf8fb3c9f4775d9c975eb10355682 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Thu, 5 Jun 2025 20:52:31 +0000 Subject: [PATCH 10/14] add round-robin unit tests for scf.if --- .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 49 +++++++++++++++++++ mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 2 +- 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir index 64e5388a2c30f..a7b48f51c4c49 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir @@ -147,4 +147,53 @@ gpu.module @test_round_robin_assignment { } gpu.return } + + gpu.func @test_scf_if(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) { + %c10 = arith.constant 10 : index + %0 = gpu.subgroup_id : index + %1 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout> + %2 = xegpu.create_nd_tdesc 
%arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout> + %3 = arith.cmpi eq, %0, %c10 : index + // CHECK-LABEL: scf.if + // CHECK-SAME: (vector<16xf32>, vector<16xf32>) + %4 = scf.if %3 -> (vector<256xf32>) { + %5 = xegpu.load_nd %1 : !xegpu.tensor_desc<256xf32, #xegpu.layout> -> vector<256xf32> + // CHECK-LABEL: scf.yield + // CHECK-SAME: vector<16xf32>, vector<16xf32> + scf.yield %5 : vector<256xf32> + } else { + %5 = xegpu.load_nd %2 : !xegpu.tensor_desc<256xf32, #xegpu.layout> -> vector<256xf32> + // CHECK-LABEL: scf.yield + // CHECK-SAME: vector<16xf32>, vector<16xf32> + scf.yield %5 : vector<256xf32> + } {layout_result_0 = #xegpu.layout} + xegpu.store_nd %4, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout> + gpu.return + } + + gpu.func @test_scf_if_tensor_desc(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) { + %c10 = arith.constant 10 : index + %id = gpu.subgroup_id : index + + %t = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout> + %d = xegpu.load_nd %t : !xegpu.tensor_desc<256xf32, #xegpu.layout> -> vector<256xf32> + + %0 = arith.cmpi eq, %id, %c10 : index + // CHECK-LABEL: scf.if + // CHECK-SAME: (!xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>) + %1 = scf.if %0 -> (!xegpu.tensor_desc<256xf32, #xegpu.layout>) { + %2 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout> + // CHECK-LABEL: scf.yield + // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32> + scf.yield %2 : !xegpu.tensor_desc<256xf32, #xegpu.layout> + } else { + %3 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout> + // CHECK-LABEL: scf.yield + // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32> + scf.yield %3 : !xegpu.tensor_desc<256xf32, #xegpu.layout> + } + xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout> + gpu.return + } + } diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir index c6103408cb467..e80bb065db230 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir @@ -269,7 +269,7 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { %0 = arith.cmpi eq, %id, %c10 : index // CHECK-LABEL: scf.if - // CHECK-SAME: !xegpu.tensor_desc<16xf32> + // CHECK-SAME: (!xegpu.tensor_desc<16xf32>) %1 = scf.if %0 -> (!xegpu.tensor_desc<256xf32, #xegpu.layout>) { // CHECK-LABEL: xegpu.create_nd_tdesc // CHECK-SAME: memref<1024xf32> -> !xegpu.tensor_desc<16xf32> From 3544073b041f2bf960ad68009556a856a1214e05 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Thu, 5 Jun 2025 21:34:57 +0000 Subject: [PATCH 11/14] update layout info --- .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 8f400e525d480..981fb368d3fea 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -498,4 +498,20 @@ void XeGPUWgToSgDistributePass::runOnOperation() { if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); + + // Remove sg_layout and sg_data attributes from the Layout + // attribute for each VectorType result of the operation. 
+ // For Structured Control Flow ops, the layout is simply removed, + // since in 1:N case, the layout for new results are missing. + // Layout propagation pass will activated. + getOperation()->walk([](Operation *op) { + for (OpResult result : op->getOpResults()) { + std::string name = xegpu::getLayoutName(result); + if (auto layout = op->getAttrOfType(name)) { + op->removeAttr(name); + if (!isa(op)) + op->setAttr(name, layout.dropInstData()); + } + } + }); } From 8c91b39b2cc807d1b9e6ea728ec43ea288ddf1a5 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Tue, 10 Jun 2025 21:06:28 +0000 Subject: [PATCH 12/14] address comments --- .../Transforms/XeGPUWgToSgDistribute.cpp | 40 +++++++++++++++---- .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 2 + mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 29 +++++++++----- 3 files changed, 53 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 981fb368d3fea..3ec5cf9e587bd 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -335,6 +335,29 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern { // 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32> // it could be either 1:N or N:1 cast. In both cases, the pattern // simply forwards the inputs to the outputs using 1:1 or 1:N interface. +// for example, the following scf::forOp +// ``` +// %for = scf.for ... iter_args(%arg1 = %0)->(vector<128x128xf16>) { +// %n = use(%arg1): vector<128x128xf16> +// scf.yield %n : vector<128x128xf16> +// } +// ``` +// Could be converted to: +// ``` +// %1 = unrealized_conversion_cast %0 +// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16> +// %for:2 = scf.for ... iter_args(%arg1 = %1#1, %arg2 = %1#2) +// -> (vector<16x16xf16>, vector<16x16xf16) { +// %m = unrealized_conversion_cast %arg1, %arg2 +// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16> +// %n = use(%m): vector<128x128xf16> +// %b = unrealized_conversion_cast %n +// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16> +// scf.yield %b#1, %b#2 : vector<16x16xf16>, vector<16x16xf16> +// } +// %cast = unrealized_conversion_cast %for:2 +// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16> +// ``` // TODO: remove it when context-aware type converter is ready. struct UnrealizedConversionCastOpPattern : public OpConversionPattern { @@ -353,8 +376,9 @@ struct UnrealizedConversionCastOpPattern !llvm::all_equal(ValueRange(inputs).getTypes())) return failure(); - // Handles the case where cast %1 : vector<256xf32> to vector<16xf32>, ... - // the input values provided by the adaptor should already be distributed, + // Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...". + // It is generated by source materialization (e.g., inits to scf forOp). + // The input values provided by the adaptor should already be distributed, // and their types should correspond exactly to the result types of the // operation. if (op.getNumOperands() == 1 && @@ -363,11 +387,13 @@ struct UnrealizedConversionCastOpPattern return success(); } - // Handles the case where cast %1 : vector<16xf32>, ... to vector<256xf32>. - // All input values must have the same vector type, and their shape must be - // evenly divisible by the output vector's shape. + // Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>". 
+    // It is generated by target materialization (e.g., arguments/results
+    // of scf forOp). All input values must have the same vector type, and
+    // their shape must be evenly divisible by the output vector's shape
+    // (determined by the nature of the workgroup to subgroup distribution).
     // TODO: it is not safe to do such forward, since such N:1 cast could be
-    // from others
+    // from others.
     if (op.getNumResults() == 1 &&
         computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
      rewriter.replaceOpWithMultiple(op, {inputs});
      return success();
    }
@@ -510,7 +536,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
      if (auto layout = op->getAttrOfType(name)) {
        op->removeAttr(name);
        if (!isa(op))
-          op->setAttr(name, layout.dropInstData());
+          op->setAttr(name, layout.dropSgLayoutAndData());
      }
    }
  });
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index a7b48f51c4c49..35ad16d8cd9a9 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -119,6 +119,8 @@ gpu.module @test_round_robin_assignment {
     xegpu.store_nd %3, %arg3 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout>
     %4 = xegpu.update_nd_offset %arg3, [256] : !xegpu.tensor_desc<256xf32, #xegpu.layout>
     %5 = xegpu.update_nd_offset %arg4, [256] : !xegpu.tensor_desc<256xf32, #xegpu.layout>
+    // CHECK-LABEL: scf.yield
+    // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
     scf.yield %4, %5 : !xegpu.tensor_desc<256xf32, #xegpu.layout>, !xegpu.tensor_desc<256xf32, #xegpu.layout>
   }
   gpu.return
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index e80bb065db230..466842c968448 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -186,22 +186,29 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     %4 = xegpu.create_nd_tdesc %arg0[%0, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout>
     %5 = xegpu.create_nd_tdesc %arg1[%c0, %1] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout>
-    //CHECK: [[scf:%.+]]:3 = scf.for [[arg3:%.+]] = [[c0]] to [[c1024]] step [[c128]] iter_args([[arg4:%.+]] = {{.*}}, [[arg5:%.+]] = {{.*}}, [[arg6:%.+]] = {{.*}}) -> (!xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>)
-    //CHECK: [[a:%.+]] = xegpu.load_nd [[arg4]] : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16>
-    //CHECK: [[b:%.+]] = xegpu.load_nd [[arg5]] : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16>
-    //CHECK: [[c:%.+]] = xegpu.dpas [[a]], [[b]], [[arg6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32>
-    //CHECK: [[at:%.+]] = xegpu.update_nd_offset [[arg4]], [[[c0]], [[c128]]] : !xegpu.tensor_desc<16x128xf16>
-    //CHECK: [[bt:%.+]] = xegpu.update_nd_offset [[arg5]], [[[c128]], [[c0]]] : !xegpu.tensor_desc<128x16xf16>
-    //CHECK: scf.yield [[at]], [[bt]], [[c]] : !xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>
-    %6:3 = scf.for %arg3 = %c0 to %c1024 step %c128 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xegpu.tensor_desc<128x128xf16, #xegpu.layout>, !xegpu.tensor_desc<128x128xf16, #xegpu.layout>, vector<128x128xf32>) {
+    // CHECK: [[scf:%.+]]:3 = scf.for [[arg3:%.+]] = [[c0]] to [[c1024]] step [[c128]]
+    // CHECK-SAME: iter_args([[arg4:%.+]] = {{.*}}, [[arg5:%.+]] = {{.*}}, [[arg6:%.+]] = {{.*}}) ->
+    // CHECK-SAME: (!xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>)
+    // CHECK: [[a:%.+]] = xegpu.load_nd [[arg4]] : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16>
+    // CHECK: [[b:%.+]] = xegpu.load_nd [[arg5]] : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16>
+    // CHECK: [[c:%.+]] = xegpu.dpas [[a]], [[b]], [[arg6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32>
+    // CHECK: [[at:%.+]] = xegpu.update_nd_offset [[arg4]], [[[c0]], [[c128]]] : !xegpu.tensor_desc<16x128xf16>
+    // CHECK: [[bt:%.+]] = xegpu.update_nd_offset [[arg5]], [[[c128]], [[c0]]] : !xegpu.tensor_desc<128x16xf16>
+    // CHECK: scf.yield [[at]], [[bt]], [[c]] : !xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>
+    %6:3 = scf.for %arg3 = %c0 to %c1024 step %c128 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3)
+        -> (!xegpu.tensor_desc<128x128xf16, #xegpu.layout>,
+            !xegpu.tensor_desc<128x128xf16, #xegpu.layout>, vector<128x128xf32>) {
       %8 = xegpu.load_nd %arg4 : !xegpu.tensor_desc<128x128xf16, #xegpu.layout> -> vector<128x128xf16>
       %9 = xegpu.load_nd %arg5 : !xegpu.tensor_desc<128x128xf16, #xegpu.layout> -> vector<128x128xf16>
-      %10 = xegpu.dpas %8, %9, %arg6 {layout_result_0 = #xegpu.layout} : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32>
+      %10 = xegpu.dpas %8, %9, %arg6 {layout_result_0 = #xegpu.layout}
+          : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32>
       %11 = xegpu.update_nd_offset %arg4, [%c0, %c128] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout>
       %12 = xegpu.update_nd_offset %arg5, [%c128, %c0] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout>
-      scf.yield %11, %12, %10 : !xegpu.tensor_desc<128x128xf16, #xegpu.layout>, !xegpu.tensor_desc<128x128xf16, #xegpu.layout>, vector<128x128xf32>
+      scf.yield %11, %12, %10 : !xegpu.tensor_desc<128x128xf16, #xegpu.layout>,
+          !xegpu.tensor_desc<128x128xf16, #xegpu.layout>, vector<128x128xf32>
     }
-    %7 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout>
+    %7 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32>
+        -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout>
     xegpu.store_nd %6#2, %7 : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout>
     gpu.return
   }

From 89cac4d4a4fa1c3c52c92522f17aa802a04f0880 Mon Sep 17 00:00:00 2001
From: Chao Chen
Date: Tue, 10 Jun 2025 23:03:23 +0000
Subject: [PATCH 13/14] fix

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 3ec5cf9e587bd..b3717143cdf0d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -441,7 +441,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
       int count;
       SmallVector subShape;
       std::tie(subShape, count) = getSgShapeAndCount(
-          shape, dyn_cast(type.getEncoding()));
+          shape, dyn_cast_if_present(type.getEncoding()));
 
       auto newTy = VectorType::get(subShape, elemTy);
       result.append(count, newTy);

From f39fe3c22993f56ac4b4f67493296105bd59de7a Mon Sep 17 00:00:00 2001
From: Chao Chen
Date: Fri, 13 Jun 2025 16:01:06 +0000
Subject: [PATCH 14/14] address comments

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 59 +++++++++++--------
 1 file changed, 33 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index b3717143cdf0d..a26c6b52f0ddc 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -430,34 +430,43 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
     existingCastOps.push_back(castOp.getOperation());
   });
 
-  TypeConverter converter;
-  converter.addConversion([&](Type type) -> Type { return type; });
-  converter.addConversion(
-      [&](RankedTensorType type,
-          SmallVectorImpl &result) -> std::optional {
-        Type elemTy = type.getElementType();
-        ArrayRef shape = type.getShape();
-
-        int count;
-        SmallVector subShape;
-        std::tie(subShape, count) = getSgShapeAndCount(
-            shape, dyn_cast_if_present(type.getEncoding()));
-
-        auto newTy = VectorType::get(subShape, elemTy);
-        result.append(count, newTy);
-        return success();
-      });
-
-  // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
-  // VectorType operands. This first converts such operands to RankedTensorType,
-  // propagates the layout attribute into the encoding attribute, and finally
-  // converts the RankedTensorType to VectorType based on the encoding.
-  xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(), converter);
+  {
+    // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
+    // VectorType operands. This first converts such operands to
+    // RankedTensorType, propagates the layout attribute into the encoding
+    // attribute, and finally converts the RankedTensorType to VectorType based
+    // on the encoding.
+
+    TypeConverter converter;
+    converter.addConversion([&](Type type) -> Type { return type; });
+    converter.addConversion(
+        [&](RankedTensorType type,
+            SmallVectorImpl &result) -> std::optional {
+          Type elemTy = type.getElementType();
+          ArrayRef shape = type.getShape();
+
+          int count;
+          SmallVector subShape;
+          std::tie(subShape, count) = getSgShapeAndCount(
+              shape,
+              dyn_cast_if_present(type.getEncoding()));
+
+          auto newTy = VectorType::get(subShape, elemTy);
+          result.append(count, newTy);
+          return success();
+        });
+
+    xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
+                                                       converter);
+  }
+  // Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
+  // as well as XeGPU, Arith, and Vector operations.
   MLIRContext *ctx = &getContext();
   RewritePatternSet patterns(ctx);
   ConversionTarget target(*ctx);
-
+  TypeConverter converter;
+  converter.addConversion([&](Type type) -> Type { return type; });
   converter.addConversion(
       [&](xegpu::TensorDescType type,
           SmallVectorImpl &result) -> std::optional {
@@ -516,8 +525,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
 
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
-  // Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
-  // as well as XeGPU, Arith, and Vector operations.
   scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                        target);
   xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
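
Note: the following is a minimal standalone C++ sketch, not part of the patch series. It mirrors the shape/count arithmetic behind getSgShapeAndCount that the RankedTensorType-to-VectorType one-to-N conversion above relies on, stripped of MLIR types; the function name and the sg_layout/sg_data values are illustrative assumptions, chosen only to show how a workgroup-level shape splits into a per-subgroup tile shape plus a round-robin count.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Given a workgroup-level shape, an sg_layout, and an optional sg_data,
// compute the per-subgroup tile shape and how many such tiles each
// subgroup handles (the round-robin count used to expand the type 1:N).
static std::pair<std::vector<int64_t>, int64_t>
sgShapeAndCount(const std::vector<int64_t> &shape,
                const std::vector<int64_t> &sgLayout,
                const std::vector<int64_t> &sgData) {
  // Tile shape defaults to shape / sg_layout when sg_data is absent.
  std::vector<int64_t> tile = sgData.empty() ? shape : sgData;
  if (sgData.empty())
    for (size_t i = 0; i < shape.size(); ++i)
      tile[i] = shape[i] / sgLayout[i];

  int64_t shapeElems = 1, unitElems = 1;
  for (size_t i = 0; i < shape.size(); ++i) {
    // One distribution unit covers sg_layout[i] * tile[i] elements per dim,
    // clamped to the original shape; the count is how many units are needed
    // to tile the whole workgroup shape.
    int64_t distUnit = std::min(shape[i], sgLayout[i] * tile[i]);
    shapeElems *= shape[i];
    unitElems *= distUnit;
  }
  return {tile, shapeElems / unitElems};
}

int main() {
  // A 256x128 workgroup shape with sg_layout = [8, 4] and sg_data = [16, 16]:
  // each distribution unit covers min(256, 8*16) x min(128, 4*16) = 128x64
  // elements, so count = (256*128) / (128*64) = 4 tiles per subgroup.
  auto [tile, count] = sgShapeAndCount({256, 128}, {8, 4}, {16, 16});
  std::cout << "tile = " << tile[0] << "x" << tile[1]
            << ", count = " << count << "\n"; // prints "tile = 16x16, count = 4"
  return 0;
}

With these values the type converter would append four copies of the subgroup VectorType (e.g. vector<16x16xf32>) for a single workgroup-level value, which is the 1:N expansion that the scf structural conversion and the UnrealizedConversionCastOpPattern then stitch back together.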