diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index f9327d63869c0..6fea10185402a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -26,6 +26,9 @@ class TensorDescType;
 namespace xegpu {
 
+/// Flatten a set of ValueRange into a single SmallVector<Value>.
+SmallVector<Value> flattenValues(ArrayRef<ValueRange> values);
+
 /// If tensor descriptor has a layout attribute it is used in SIMT mode.
 /// In this mode, the distributed vector shape is determined as follows:
 /// Definitions:
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 3bf76af674ba0..a26c6b52f0ddc 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -13,9 +13,11 @@
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
@@ -29,6 +31,29 @@ using namespace mlir;
 
 namespace {
 
+static std::pair<SmallVector<int64_t>, int>
+getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
+  int count = 1;
+  SmallVector<int64_t> sgShape(shape);
+
+  if (layout && layout.isWgLayout()) {
+    DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout();
+    auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+    if (DenseI32ArrayAttr sgDataAttr = layout.getSgData())
+      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
+    else
+      sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape);
+    SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
+    // Clamp distUnit to the original shape to handle cases where data is
+    // shared among subgroups, which may cause distUnit to exceed the original
+    // shape.
+    for (size_t i = 0; i < distUnit.size(); ++i)
+      distUnit[i] = std::min(shape[i], distUnit[i]);
+    count = computeProduct(shape) / computeProduct(distUnit);
+  }
+  return std::make_pair(sgShape, count);
+}
+
 /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
 /// from a workgroup descriptor. It replaces the offsets and sizes with
 /// appropriate values for the subgroup.
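For intuition, here is a hedged worked example of the shape/count arithmetic in the new `getSgShapeAndCount` helper above; the concrete layout values are hypothetical, chosen only to match the round-robin tests later in this patch:

```mlir
// Assume shape = [256] and a workgroup layout of
//   #xegpu.layout<sg_layout = [8], sg_data = [16]>   (hypothetical values)
// Then:
//   sgShape  = [16]
//   distUnit = [min(256, 8 * 16)] = [128]
//   count    = 256 / 128 = 2
// i.e. each subgroup is assigned two 16-element slices in round-robin
// fashion, which is why the scf.for test below expects two subgroup-level
// values per workgroup-level iter_arg.
```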
@@ -129,18 +154,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
       return rewriter.notifyMatchFailure(
           op, "sgLayout attribute is required in layout");
 
-    SmallVector<int64_t> sgShape;
-    if (auto sgDataAttr = layout.getSgData()) {
-      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
-    } else {
-      assert(wgShape.size() == sgLayout.size() &&
-             "sgLayout and wgShape must have the same rank");
-      sgShape.reserve(wgShape.size());
-      for (size_t i = 0; i < wgShape.size(); ++i) {
-        assert(sgLayout[i] != 0 && "sgLayout elements must be non-zero");
-        sgShape.push_back(wgShape[i] / sgLayout[i]);
-      }
-    }
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
 
     // TODO : Handle order attribute
     // Get the subgroup ID
@@ -266,15 +280,15 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
     if (resultTy.getRank() != 2)
       return failure();
 
-    auto originalLayout =
-        llvm::dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+    auto originalLayout = xegpu::getLayoutAttr(op.getResult());
     if (!originalLayout)
       return failure();
 
-    SmallVector<Value> newDpasOps;
     size_t i = 0;
+    SmallVector<Value> newDpasOps;
     for (auto aVec : adaptor.getLhs()) {
       for (auto bVec : adaptor.getRhs()) {
+
         llvm::SmallVector<Value> operands({aVec, bVec});
         Value tmpC;
         if (op.getAcc()) {
@@ -288,10 +302,10 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
             llvm::cast<VectorType>(bVec.getType()).getShape();
         VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
                                            resultTy.getElementType());
-        tmpC = rewriter.create<xegpu::DpasOp>(
-            loc, resTy, operands,
-            llvm::ArrayRef<NamedAttribute>(
-                {"layout_result_0", originalLayout.dropSgLayoutAndData()}));
+        tmpC = rewriter.create<xegpu::DpasOp>(loc, resTy, operands);
+        xegpu::setLayoutAttr(cast<OpResult>(tmpC),
+                             originalLayout.dropSgLayoutAndData());
+
         newDpasOps.push_back(tmpC);
       }
     }
@@ -314,14 +328,90 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+// Handles UnrealizedConversionCastOp generated during
+// SCFStructuralTypeConversions (step 1). This op may appear as either a
+// target or source materialization for Vector values, e.g.:
+// 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
+// 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
+// It can be either a 1:N or an N:1 cast. In both cases, the pattern
+// simply forwards the inputs to the outputs via the 1:1 or 1:N interface.
+// For example, the following scf.for:
+// ```
+// %for = scf.for ... iter_args(%arg1 = %0) -> (vector<128x128xf16>) {
+//     %n = use(%arg1): vector<128x128xf16>
+//     scf.yield %n : vector<128x128xf16>
+// }
+// ```
+// could be converted to:
+// ```
+// %1 = unrealized_conversion_cast %0
+//          : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
+// %for:2 = scf.for ... iter_args(%arg1 = %1#0, %arg2 = %1#1)
+//          -> (vector<16x16xf16>, vector<16x16xf16>) {
+//     %m = unrealized_conversion_cast %arg1, %arg2
+//          : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
+//     %n = use(%m): vector<128x128xf16>
+//     %b = unrealized_conversion_cast %n
+//          : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
+//     scf.yield %b#0, %b#1 : vector<16x16xf16>, vector<16x16xf16>
+// }
+// %cast = unrealized_conversion_cast %for#0, %for#1
+//          : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
+// ```
+// TODO: remove this pattern once a context-aware type converter is ready.
+struct UnrealizedConversionCastOpPattern
+    : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
+  using OpConversionPattern<
+      mlir::UnrealizedConversionCastOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());
+
+    auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
+    auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
+
+    if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
+        !llvm::all_equal(ValueRange(inputs).getTypes()))
+      return failure();
+
+    // Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...".
+    // It is generated by source materialization (e.g., inits to scf forOp).
+    // The input values provided by the adaptor should already be distributed,
+    // and their types should correspond exactly to the result types of the
+    // operation.
+    if (op.getNumOperands() == 1 &&
+        llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
+      rewriter.replaceOp(op, inputs);
+      return success();
+    }
+
+    // Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>".
+    // It is generated by target materialization (e.g., arguments/results
+    // of scf forOp). All input values must have the same vector type, and
+    // their shape must be evenly divisible by the output vector's shape
+    // (guaranteed by the nature of the workgroup to subgroup distribution).
+    // TODO: it is not safe to blindly forward the inputs, since such an N:1
+    // cast could also originate from other sources.
+    if (op.getNumResults() == 1 &&
+        computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
+      rewriter.replaceOpWithMultiple(op, {inputs});
+      return success();
+    }
+
+    return mlir::failure();
+  }
+};
+
 } // namespace
 
 namespace mlir {
 namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
-               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
-      patterns.getContext());
+               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
+               UnrealizedConversionCastOpPattern>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -334,9 +424,68 @@ struct XeGPUWgToSgDistributePass
 } // namespace
 
 void XeGPUWgToSgDistributePass::runOnOperation() {
+  // Track the UnrealizedConversionCastOps that already exist before this pass
+  // runs, so they can be marked legal and left untouched below.
+  SmallVector<Operation *> existingCastOps;
+  getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
+    existingCastOps.push_back(castOp.getOperation());
+  });
+
+  {
+    // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
+    // VectorType operands. This first converts such operands to
+    // RankedTensorType, propagates the layout attribute into the encoding
+    // attribute, and finally converts the RankedTensorType to VectorType based
+    // on the encoding.
+
+    TypeConverter converter;
+    converter.addConversion([&](Type type) -> Type { return type; });
+    converter.addConversion(
+        [&](RankedTensorType type,
+            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+          Type elemTy = type.getElementType();
+          ArrayRef<int64_t> shape = type.getShape();
+
+          int count;
+          SmallVector<int64_t> subShape;
+          std::tie(subShape, count) = getSgShapeAndCount(
+              shape,
+              dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));
+
+          auto newTy = VectorType::get(subShape, elemTy);
+          result.append(count, newTy);
+          return success();
+        });
+
+    xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
+                                                       converter);
+  }
+
+  // Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
+  // as well as XeGPU, Arith, and Vector operations.
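As a sketch of what the Step 2 type converter defined below produces (the layout values are again hypothetical, merely consistent with the tests added in this patch):

```mlir
// Workgroup-level type (hypothetical layout):
//   !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
// converts to count = 2 copies of the subgroup-level type:
//   !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
// with sg_layout/sg_data dropped from the layout; any lane_* fields, if
// present, would be preserved on the new type.
```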
   MLIRContext *ctx = &getContext();
   RewritePatternSet patterns(ctx);
   ConversionTarget target(*ctx);
+  TypeConverter converter;
+  converter.addConversion([&](Type type) -> Type { return type; });
+  converter.addConversion(
+      [&](xegpu::TensorDescType type,
+          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+        Type elemTy = type.getElementType();
+        ArrayRef<int64_t> shape = type.getShape();
+
+        int count;
+        SmallVector<int64_t> subShape;
+        xegpu::LayoutAttr layout = type.getLayoutAttr();
+        std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
+
+        if (layout)
+          layout = layout.dropSgLayoutAndData();
+
+        auto newTy = xegpu::TensorDescType::get(
+            type.getContext(), subShape, elemTy, type.getEncoding(), layout);
+        result.append(count, newTy);
+        return success();
+      });
 
   auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
     if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
@@ -353,26 +502,49 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
   };
 
   auto isLegal = [&](xegpu::LayoutAttr layout) -> bool {
-    return !layout || layout.getSgLayout() == nullptr;
+    return !layout || !layout.isWgLayout();
   };
 
   target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
                                xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
                                xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
     auto tdescTy = getTensorDescType(op);
-    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
+    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
     return isLegal(layout);
   });
 
   target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
-    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+    auto layout = xegpu::getLayoutAttr(op.getResult());
     return isLegal(layout);
   });
 
+  target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
+      [=](UnrealizedConversionCastOp op) {
+        return llvm::is_contained(existingCastOps, op.getOperation());
+      });
+  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+  scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+                                                       target);
   xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
     return signalPassFailure();
+
+  // Remove the sg_layout and sg_data fields from the LayoutAttr of each
+  // VectorType result of the operation. For structured control-flow ops the
+  // layout attribute is removed entirely, since in the 1:N case the layouts
+  // for the new results are missing; the layout propagation pass will have
+  // to be run afterwards to recover them.
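A hedged illustration of this cleanup, with hypothetical attribute values:

```mlir
// Before: %r = xegpu.dpas ... {layout_result_0 = #xegpu.layout<
//             sg_layout = [8, 8], sg_data = [16, 16],
//             lane_layout = [1, 16], lane_data = [1, 1]>} ...
// After:  %r = xegpu.dpas ... {layout_result_0 = #xegpu.layout<
//             lane_layout = [1, 16], lane_data = [1, 1]>} ...
// whereas on scf.for / scf.while / scf.if results the layout_result_*
// attribute is erased entirely.
```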
+  getOperation()->walk([](Operation *op) {
+    for (OpResult result : op->getOpResults()) {
+      std::string name = xegpu::getLayoutName(result);
+      if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
+        op->removeAttr(name);
+        if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op))
+          op->setAttr(name, layout.dropSgLayoutAndData());
+      }
+    }
+  });
 }
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index dcaf4e85a82c5..6b85a66a8bd36 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -27,7 +27,7 @@ using namespace mlir;
 
 /// convert ArrayRef<ValueRange> into SmallVector<Value>
-static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+SmallVector<Value> xegpu::flattenValues(ArrayRef<ValueRange> values) {
   SmallVector<Value> result;
   for (const auto &vals : values)
     llvm::append_range(result, vals);
@@ -271,7 +271,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
       auto resultTy = dyn_cast<RankedTensorType>(result.getType());
 
       // Only look at ops casting from VectorType to RankedTensorType
-      if (!isa<VectorType>(inputTy) || !isa<RankedTensorType>(resultTy))
+      if (!inputTy || !resultTy)
         return WalkResult::skip();
 
       xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input);
@@ -342,7 +342,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
       }
 
       if (isa<RankedTensorType>(inputTy) && isa<VectorType>(outputTy)) {
-        SmallVector<Value> values = flattenValues(adaptor.getInputs());
+        SmallVector<Value> values = xegpu::flattenValues(adaptor.getInputs());
         auto newOp = rewriter.create<UnrealizedConversionCastOp>(
             op.getLoc(), outputTy, values);
         rewriter.replaceOp(op, newOp);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index bee026eb2084d..35ad16d8cd9a9 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -85,7 +85,7 @@ gpu.module @test_round_robin_assignment {
     %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32>
       -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout>
     %dpas = xegpu.dpas %load_a, %load_b
-      {layout = #xegpu.layout}
+      {layout_result_0 = #xegpu.layout}
       : vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32>
     gpu.return
   }
@@ -102,4 +102,100 @@ gpu.module @test_round_robin_assignment {
       : !xegpu.tensor_desc<24x32xf32, #xegpu.layout>
     gpu.return
   }
+
+  gpu.func @test_scf_for(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
+    %c1 = arith.constant 1 : index
+    %c10 = arith.constant 10 : index
+    %c0 = arith.constant 0 : index
+    %c256 = arith.constant 256 : index
+    %c1024 = arith.constant 1024 : index
+    %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout>
+    %1 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout>
+    // CHECK-LABEL: scf.for
+    // CHECK-SAME: (!xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>)
+    %2:2 = scf.for %arg2 = %c0 to %c1024 step %c256 iter_args(%arg3 = %0, %arg4 = %1)
+        -> (!xegpu.tensor_desc<256xf32, #xegpu.layout>, !xegpu.tensor_desc<256xf32, #xegpu.layout>) {
+      %3 = xegpu.load_nd %0 : !xegpu.tensor_desc<256xf32, #xegpu.layout> -> vector<256xf32>
+      xegpu.store_nd %3, %arg3 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout>
+      %4 = xegpu.update_nd_offset %arg3, [256] : !xegpu.tensor_desc<256xf32, #xegpu.layout>
+      %5 = xegpu.update_nd_offset %arg4, [256] : !xegpu.tensor_desc<256xf32, #xegpu.layout>
+      // CHECK-LABEL: scf.yield
+      // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
+      scf.yield %4, %5 : !xegpu.tensor_desc<256xf32, #xegpu.layout>, !xegpu.tensor_desc<256xf32, #xegpu.layout>
+    }
+    gpu.return
+  }
+
+  gpu.func @test_scf_while_and_condition(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
+    %c1_i32 = arith.constant 1 : i32
+    %c10_i32 = arith.constant 10 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout>
+    %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<256xf32, #xegpu.layout> -> vector<256xf32>
+    %2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout>
+    //CHECK: scf.while ({{.*}}) : (vector<16xf32>, vector<16xf32>, i32) -> (vector<16xf32>, vector<16xf32>, i32)
+    %3:2 = scf.while (%arg2 = %1, %arg3 = %c0_i32) : (vector<256xf32>, i32) -> (vector<256xf32>, i32) {
+      %4 = arith.cmpi slt, %arg3, %c10_i32 : i32
+      //CHECK: scf.condition{{.*}} : vector<16xf32>, vector<16xf32>, i32
+      scf.condition(%4) %arg2, %arg3 : vector<256xf32>, i32
+    } do {
+      // CHECK: ([[arg2:%.+]]: vector<16xf32>, [[arg3:%.+]]: vector<16xf32>, [[arg4:%.+]]: i32)
+    ^bb0(%arg2: vector<256xf32>, %arg3: i32):
+      xegpu.store_nd %arg2, %2 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout>
+      %4 = arith.addi %arg3, %c1_i32 : i32
+      %5 = xegpu.update_nd_offset %0, [256] : !xegpu.tensor_desc<256xf32, #xegpu.layout>
+      %6 = xegpu.load_nd %5 : !xegpu.tensor_desc<256xf32, #xegpu.layout> -> vector<256xf32>
+      scf.yield %6, %4 : vector<256xf32>, i32
+    }
+    gpu.return
+  }
+
+  gpu.func @test_scf_if(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
+    %c10 = arith.constant 10 : index
+    %0 = gpu.subgroup_id : index
+    %1 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout>
+    %2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout>
+    %3 = arith.cmpi eq, %0, %c10 : index
+    // CHECK-LABEL: scf.if
+    // CHECK-SAME: (vector<16xf32>, vector<16xf32>)
+    %4 = scf.if %3 -> (vector<256xf32>) {
+      %5 = xegpu.load_nd %1 : !xegpu.tensor_desc<256xf32, #xegpu.layout> -> vector<256xf32>
+      // CHECK-LABEL: scf.yield
+      // CHECK-SAME: vector<16xf32>, vector<16xf32>
+      scf.yield %5 : vector<256xf32>
+    } else {
+      %5 = xegpu.load_nd %2 : !xegpu.tensor_desc<256xf32, #xegpu.layout> -> vector<256xf32>
+      // CHECK-LABEL: scf.yield
+      // CHECK-SAME: vector<16xf32>, vector<16xf32>
+      scf.yield %5 : vector<256xf32>
+    } {layout_result_0 = #xegpu.layout}
+    xegpu.store_nd %4, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout>
+    gpu.return
+  }
+
+  gpu.func @test_scf_if_tensor_desc(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
+    %c10 = arith.constant 10 : index
+    %id = gpu.subgroup_id : index
+
+    %t = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout>
+    %d = xegpu.load_nd %t : !xegpu.tensor_desc<256xf32, #xegpu.layout> -> vector<256xf32>
+
+    %0 = arith.cmpi eq, %id, %c10 : index
+    // CHECK-LABEL: scf.if
+    // CHECK-SAME: (!xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>)
+    %1 = scf.if %0 -> (!xegpu.tensor_desc<256xf32, #xegpu.layout>) {
+      %2 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout>
+      // CHECK-LABEL: scf.yield
+      // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
+      scf.yield %2 : !xegpu.tensor_desc<256xf32, #xegpu.layout>
+    } else {
+      %3 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout>
+      // CHECK-LABEL: scf.yield
+      // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
+      scf.yield %3 : !xegpu.tensor_desc<256xf32, #xegpu.layout>
+    }
+    xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout>
+    gpu.return
+  }
+
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 7e89ada934071..466842c968448 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -5,7 +5,7 @@ gpu.module @test_1_1_assignment {
   // CHECK-LABEL: test_create_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
-  gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) {
+  gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) {
   // CHECK: %[[SGID:.*]] = gpu.subgroup_id
   // CHECK: %[[C12:.*]] = arith.constant 12 : index
   // CHECK: %[[C4:.*]] = arith.constant 4 : index
@@ -108,7 +108,7 @@ gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     : !xegpu.tensor_desc<32x24xf32, #xegpu.layout>
     -> vector<32x24xf32>
   %dpas = xegpu.dpas %load_a, %load_b
-    {layout = #xegpu.layout}
+    {layout_result_0 = #xegpu.layout}
    : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
   gpu.return
   }
@@ -142,7 +142,7 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
    : !xegpu.tensor_desc<32x24xf32, #xegpu.layout>
    -> vector<32x24xf32>
   %dpas = xegpu.dpas %load_a, %load_b
-    {layout = #xegpu.layout}
+    {layout_result_0 = #xegpu.layout}
    : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
   gpu.return
   }
@@ -169,4 +169,132 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
    : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
   gpu.return
   }
+
+  gpu.func @test_scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
+    //CHECK: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK: [[c1024:%.+]] = arith.constant 1024 : index
+    %c0 = arith.constant 0 : index
+    %c128 = arith.constant 128 : index
+    %c1024 = arith.constant 1024 : index
+    %block_id_x = gpu.block_id x
+    %block_id_y = gpu.block_id y
+    %0 = arith.muli %block_id_x, %c128 : index
+    %1 = arith.muli %block_id_y, %c128 : index
+    %2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout>
+    %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<128x128xf32, #xegpu.layout> -> vector<128x128xf32>
+    %4 = xegpu.create_nd_tdesc %arg0[%0, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout>
+    %5 = xegpu.create_nd_tdesc %arg1[%c0, %1] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout>
+
+    // CHECK: [[scf:%.+]]:3 = scf.for [[arg3:%.+]] = [[c0]] to [[c1024]] step [[c128]]
+    // CHECK-SAME: iter_args([[arg4:%.+]] = {{.*}}, [[arg5:%.+]] = {{.*}}, [[arg6:%.+]] = {{.*}}) ->
+    // CHECK-SAME: (!xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>)
+    // CHECK: [[a:%.+]] = xegpu.load_nd [[arg4]] : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16>
+    // CHECK: [[b:%.+]] = xegpu.load_nd [[arg5]] : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16>
+    // CHECK: [[c:%.+]] = xegpu.dpas [[a]], [[b]], [[arg6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32>
+    // CHECK: [[at:%.+]] = xegpu.update_nd_offset [[arg4]], [[[c0]], [[c128]]] : !xegpu.tensor_desc<16x128xf16>
+    // CHECK: [[bt:%.+]] = xegpu.update_nd_offset [[arg5]], [[[c128]], [[c0]]] : !xegpu.tensor_desc<128x16xf16>
+    // CHECK: scf.yield [[at]], [[bt]], [[c]] : !xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>
+    %6:3 = scf.for %arg3 = %c0 to %c1024 step %c128 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3)
+        -> (!xegpu.tensor_desc<128x128xf16, #xegpu.layout>,
+            !xegpu.tensor_desc<128x128xf16, #xegpu.layout>, vector<128x128xf32>) {
+      %8 = xegpu.load_nd %arg4 : !xegpu.tensor_desc<128x128xf16, #xegpu.layout> -> vector<128x128xf16>
+      %9 = xegpu.load_nd %arg5 : !xegpu.tensor_desc<128x128xf16, #xegpu.layout> -> vector<128x128xf16>
+      %10 = xegpu.dpas %8, %9, %arg6 {layout_result_0 = #xegpu.layout}
+        : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32>
+      %11 = xegpu.update_nd_offset %arg4, [%c0, %c128] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout>
+      %12 = xegpu.update_nd_offset %arg5, [%c128, %c0] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout>
+      scf.yield %11, %12, %10 : !xegpu.tensor_desc<128x128xf16, #xegpu.layout>,
+                                !xegpu.tensor_desc<128x128xf16, #xegpu.layout>, vector<128x128xf32>
+    }
+    %7 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32>
+      -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout>
+    xegpu.store_nd %6#2, %7 : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout>
+    gpu.return
+  }
+
+  gpu.func @test_scf_while_and_condition(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
+    %c1_i32 = arith.constant 1 : i32
+    %c10_i32 = arith.constant 10 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout>
+    %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<256xf32, #xegpu.layout> -> vector<256xf32>
+    %2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout>
+
+    // CHECK: scf.while {{.*}} : (vector<16xf32>, i32) -> (vector<16xf32>, i32)
+    %3:2 = scf.while (%arg2 = %1, %arg3 = %c0_i32) : (vector<256xf32>, i32) -> (vector<256xf32>, i32) {
+      %4 = arith.cmpi slt, %arg3, %c10_i32 : i32
+      // CHECK: scf.condition{{.*}} : vector<16xf32>, i32
+      scf.condition(%4) %arg2, %arg3 : vector<256xf32>, i32
+    } do {
+      // CHECK: ([[arg2:%.+]]: vector<16xf32>, [[arg3:%.+]]: i32)
+    ^bb0(%arg2: vector<256xf32>, %arg3: i32):
+      xegpu.store_nd %arg2, %2 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout>
+      %4 = arith.addi %arg3, %c1_i32 : i32
+      %5 = xegpu.update_nd_offset %0, [256] : !xegpu.tensor_desc<256xf32, #xegpu.layout>
+      %6 = xegpu.load_nd %5 : !xegpu.tensor_desc<256xf32, #xegpu.layout> -> vector<256xf32>
+      scf.yield %6, %4 : vector<256xf32>, i32
+    }
+    gpu.return
+  }
+
+  gpu.func @test_scf_if(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
+    %c10 = arith.constant 10 : index
+    %id = gpu.subgroup_id : index
+
+    %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout>
+    %1 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout>
+
+    %4 = arith.cmpi eq, %id, %c10 : index
+    // CHECK-LABEL: scf.if
+    // CHECK-SAME: (vector<16xf32>)
+    %5 = scf.if %4 -> (vector<256xf32>) {
+      // CHECK-LABEL: xegpu.load_nd
+      // CHECK-SAME: !xegpu.tensor_desc<16xf32> -> vector<16xf32>
+      %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<256xf32, #xegpu.layout> -> vector<256xf32>
+      // CHECK-LABEL: scf.yield
+      // CHECK-SAME: vector<16xf32>
+      scf.yield %2 : vector<256xf32>
+    } else {
+      // CHECK-LABEL: xegpu.load_nd
+      // CHECK-SAME: !xegpu.tensor_desc<16xf32> -> vector<16xf32>
+      %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<256xf32, #xegpu.layout> -> vector<256xf32>
+      // CHECK-LABEL: scf.yield
+      // CHECK-SAME: vector<16xf32>
+      scf.yield %3 : vector<256xf32>
+    } {layout_result_0 = #xegpu.layout}
+    xegpu.store_nd %5, %0 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout>
+    gpu.return
+  }
+
+  gpu.func @test_scf_if_tensor_desc(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
+    %c10 = arith.constant 10 : index
+    %id = gpu.subgroup_id : index
+
+    %t = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout>
+    %d = xegpu.load_nd %t : !xegpu.tensor_desc<256xf32, #xegpu.layout> -> vector<256xf32>
+
+    %0 = arith.cmpi eq, %id, %c10 : index
+    // CHECK-LABEL: scf.if
+    // CHECK-SAME: (!xegpu.tensor_desc<16xf32>)
+    %1 = scf.if %0 -> (!xegpu.tensor_desc<256xf32, #xegpu.layout>) {
+      // CHECK-LABEL: xegpu.create_nd_tdesc
+      // CHECK-SAME: memref<1024xf32> -> !xegpu.tensor_desc<16xf32>
+      %2 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout>
+      // CHECK-LABEL: scf.yield
+      // CHECK-SAME: !xegpu.tensor_desc<16xf32>
+      scf.yield %2 : !xegpu.tensor_desc<256xf32, #xegpu.layout>
+    } else {
+      // CHECK-LABEL: xegpu.create_nd_tdesc
+      // CHECK-SAME: memref<1024xf32> -> !xegpu.tensor_desc<16xf32>
+      %3 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout>
+      // CHECK-LABEL: scf.yield
+      // CHECK-SAME: !xegpu.tensor_desc<16xf32>
+      scf.yield %3 : !xegpu.tensor_desc<256xf32, #xegpu.layout>
+    }
+    xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout>
+    gpu.return
+  }
+
 }