From 08f7eb9752e682cb6b3b6c4f40fad613a0b0d940 Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Mon, 2 Jun 2025 18:05:59 +0000
Subject: [PATCH 01/13] Add support for elementwise ops in the Wg to Sg
 distribute pass

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 223 ++++
 .../XeGPU/xegpu-wg-to-sg-elemwise.mlir        | 971 ++++++++++++++++++
 2 files changed, 1194 insertions(+)
 create mode 100644 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 3bf76af674ba0..972394a7b40ad 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -8,15 +8,18 @@
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 
 #include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include <optional>
 
 namespace mlir {
 namespace xegpu {
@@ -314,6 +317,179 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+// This pattern matches elementwise ops (unary/binary) in math/arith dialects
+// with 1D or 2D vector types
+template <typename Op>
+struct WgToSgElementwiseOp : public OpConversionPattern<Op> {
+  using OpConversionPattern<Op>::OpConversionPattern;
+  using OneToNOpAdaptor = typename OpConversionPattern<Op>::OneToNOpAdaptor;
+
+  LogicalResult
+  matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // All operands/results must be 1D or 2D vectors
+    auto resultType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!resultType ||
+        (resultType.getRank() != 1 && resultType.getRank() != 2))
+      return rewriter.notifyMatchFailure(
+          op, "Result type is not a 1D or 2D vector");
+
+    ArrayRef<int64_t> shape = resultType.getShape();
+    for (Value operand : op->getOperands()) {
+      auto operandType = dyn_cast<VectorType>(operand.getType());
+      if (!operandType || operandType.getRank() != resultType.getRank() ||
+          operandType.getShape() != shape) {
+        return rewriter.notifyMatchFailure(
+            op, "Operand type is not a 1D or 2D vector with the same shape as "
+                "result type");
+      }
+    }
+
+    // Check for layout attribute with sgLayout
+    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+    if (!layout || !layout.getSgLayout())
+      return rewriter.notifyMatchFailure(
+          op, "Operation does not have a valid layout attribute for subgroup "
+              "distribution");
+
+    // Extract sgShape from layout
+    SmallVector<int64_t> sgShape;
+    if (auto sgDataAttr = layout.getSgData()) {
+      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
+    } else {
+      auto sgLayoutArr = layout.getSgLayout();
+      sgShape.reserve(shape.size());
+      for (size_t i = 0; i < shape.size(); ++i) {
+        assert(sgLayoutArr[i] != 0 && "sgLayout elements must be non-zero");
+        sgShape.push_back(shape[i] / sgLayoutArr[i]);
+      }
+    }
+
+    // Each operand is a list of values
+    size_t numVariants = adaptor.getOperands().empty()
+                             ? 0
+                             : adaptor.getOperands().front().size();
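+    // Note: each entry of adaptor.getOperands() is a ValueRange rather than a
+    // single Value, because producer patterns (e.g. the load_nd lowering) may
+    // already have split one workgroup-level value into several
+    // subgroup-local pieces. numVariants is the number of such pieces; with
+    // round-robin assignment, e.g. a 24x32 vector with sg_layout = [4, 4] and
+    // sg_data = [2, 2], each subgroup owns (24 / (4 * 2)) * (32 / (4 * 2)) =
+    // 12 pieces, so the checks and rewrite loop below run over 12 small ops.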
+    for (auto &operandVec : adaptor.getOperands())
+      if (operandVec.size() != numVariants)
+        return rewriter.notifyMatchFailure(
+            op, "Operand lists have mismatched sizes");
+
+    SmallVector<Value> newResults;
+
+    auto origResultType = dyn_cast<VectorType>(op->getResult(0).getType());
+    VectorType newResultType =
+        origResultType
+            ? VectorType::get(sgShape, origResultType.getElementType())
+            : VectorType::get(sgShape, resultType.getElementType());
+
+    for (size_t i = 0; i < numVariants; ++i) {
+      SmallVector<Value> operands;
+      for (auto &operandVec : adaptor.getOperands())
+        operands.push_back(operandVec[i]);
+
+      auto newOp = rewriter.create<Op>(op.getLoc(), newResultType, operands);
+
+      // Copy all attributes except "layout", and add "layout_result_0" with
+      // sgLayout/data dropped
+      for (auto attr : op->getAttrs()) {
+        if (attr.getName() != "layout")
+          newOp->setAttr(attr.getName(), attr.getValue());
+      }
+      newOp->setAttr("layout_result_0", layout.dropSgLayoutAndData());
+
+      newResults.push_back(newOp.getResult());
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newResults});
+    return success();
+  }
+};
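To make the 1:N expansion concrete, here is a small standalone sketch (not part of the patch) of the shape arithmetic the pattern performs. The sg_layout/sg_data values are illustrative assumptions, chosen to mirror the vector<12x8> and CHECK-COUNT-12/vector<2x2> expectations in the test file below:

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Fallback used when sg_data is absent: each subgroup gets an even chunk,
// shape[i] / sg_layout[i] (mirrors the sgShape computation in the pattern).
std::vector<int64_t> sgShape(const std::vector<int64_t> &shape,
                             const std::vector<int64_t> &sgLayout) {
  std::vector<int64_t> out;
  for (size_t i = 0; i < shape.size(); ++i) {
    assert(sgLayout[i] != 0 && "sgLayout elements must be non-zero");
    out.push_back(shape[i] / sgLayout[i]);
  }
  return out;
}

// Round-robin piece count per subgroup: the product over dimensions of
// shape[i] / (sg_layout[i] * sg_data[i]). This is what numVariants ends up
// being once producer patterns have split the workgroup-level values.
int64_t numVariants(const std::vector<int64_t> &shape,
                    const std::vector<int64_t> &sgLayout,
                    const std::vector<int64_t> &sgData) {
  int64_t n = 1;
  for (size_t i = 0; i < shape.size(); ++i)
    n *= shape[i] / (sgLayout[i] * sgData[i]);
  return n;
}

int main() {
  // One piece per subgroup: 24x32 over a 2x4 subgroup grid -> 12x8 chunks,
  // matching the vector<12x8xf32> CHECK lines in the first test function.
  std::vector<int64_t> s = sgShape({24, 32}, {2, 4});
  std::cout << s[0] << "x" << s[1] << "\n";                  // 12x8
  std::cout << numVariants({24, 32}, {2, 4}, {12, 8}) << "\n"; // 1
  // Round-robin: 2x2 pieces, 12 per subgroup, matching CHECK-COUNT-12.
  std::cout << numVariants({24, 32}, {4, 4}, {2, 2}) << "\n";  // 12
}
```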
+
+// ---- ARITH ops ----
+using WgToSgAddFOp = WgToSgElementwiseOp<arith::AddFOp>;
+using WgToSgSubFOp = WgToSgElementwiseOp<arith::SubFOp>;
+using WgToSgNegFOp = WgToSgElementwiseOp<arith::NegFOp>;
+using WgToSgAddIOp = WgToSgElementwiseOp<arith::AddIOp>;
+using WgToSgSubIOp = WgToSgElementwiseOp<arith::SubIOp>;
+using WgToSgMulFOp = WgToSgElementwiseOp<arith::MulFOp>;
+using WgToSgMulIOp = WgToSgElementwiseOp<arith::MulIOp>;
+using WgToSgShLIOp = WgToSgElementwiseOp<arith::ShLIOp>;
+using WgToSgShRSIOp = WgToSgElementwiseOp<arith::ShRSIOp>;
+using WgToSgShRUIOp = WgToSgElementwiseOp<arith::ShRUIOp>;
+using WgToSgDivFOp = WgToSgElementwiseOp<arith::DivFOp>;
+using WgToSgDivSIOp = WgToSgElementwiseOp<arith::DivSIOp>;
+using WgToSgDivUIOp = WgToSgElementwiseOp<arith::DivUIOp>;
+using WgToSgMaximumFOp = WgToSgElementwiseOp<arith::MaximumFOp>;
+using WgToSgMinimumFOp = WgToSgElementwiseOp<arith::MinimumFOp>;
+using WgToSgRemSIOp = WgToSgElementwiseOp<arith::RemSIOp>;
+using WgToSgRemUIOp = WgToSgElementwiseOp<arith::RemUIOp>;
+using WgToSgTruncFOp = WgToSgElementwiseOp<arith::TruncFOp>;
+using WgToSgTruncIOp = WgToSgElementwiseOp<arith::TruncIOp>;
+using WgToSgExtFOp = WgToSgElementwiseOp<arith::ExtFOp>;
+using WgToSgExtSIOp = WgToSgElementwiseOp<arith::ExtSIOp>;
+using WgToSgExtUIOp = WgToSgElementwiseOp<arith::ExtUIOp>;
+using WgToSgSIToFPOp = WgToSgElementwiseOp<arith::SIToFPOp>;
+using WgToSgUIToFPOp = WgToSgElementwiseOp<arith::UIToFPOp>;
+using WgToSgFPToSIOp = WgToSgElementwiseOp<arith::FPToSIOp>;
+using WgToSgFPToUIOp = WgToSgElementwiseOp<arith::FPToUIOp>;
+using WgToSgIndexCastUIOp = WgToSgElementwiseOp<arith::IndexCastUIOp>;
+using WgToSgIndexCastOp = WgToSgElementwiseOp<arith::IndexCastOp>;
+using WgToSgBitcastOp = WgToSgElementwiseOp<arith::BitcastOp>;
+using WgToSgCmpIOp = WgToSgElementwiseOp<arith::CmpIOp>;
+using WgToSgCmpFOp = WgToSgElementwiseOp<arith::CmpFOp>;
+using WgToSgAndIOp = WgToSgElementwiseOp<arith::AndIOp>;
+using WgToSgCeilDivSIOp = WgToSgElementwiseOp<arith::CeilDivSIOp>;
+using WgToSgCeilDivUIOp = WgToSgElementwiseOp<arith::CeilDivUIOp>;
+using WgToSgFloorDivSIOp = WgToSgElementwiseOp<arith::FloorDivSIOp>;
+using WgToSgMaxNumFOp = WgToSgElementwiseOp<arith::MaxNumFOp>;
+using WgToSgMaxSIOp = WgToSgElementwiseOp<arith::MaxSIOp>;
+using WgToSgMaxUIOp = WgToSgElementwiseOp<arith::MaxUIOp>;
+using WgToSgMinNumFOp = WgToSgElementwiseOp<arith::MinNumFOp>;
+using WgToSgMinSIOp = WgToSgElementwiseOp<arith::MinSIOp>;
+using WgToSgMinUIOp = WgToSgElementwiseOp<arith::MinUIOp>;
+using WgToSgOrIOp = WgToSgElementwiseOp<arith::OrIOp>;
+using WgToSgRemFOp = WgToSgElementwiseOp<arith::RemFOp>;
+using WgToSgSelectOp = WgToSgElementwiseOp<arith::SelectOp>;
+using WgToSgXOrIOp = WgToSgElementwiseOp<arith::XOrIOp>;
+
+// ---- MATH ops ----
+using WgToSgExpOp = WgToSgElementwiseOp<math::ExpOp>;
+using WgToSgSqrtOp = WgToSgElementwiseOp<math::SqrtOp>;
+using WgToSgAbsFOp = WgToSgElementwiseOp<math::AbsFOp>;
+using WgToSgCosOp = WgToSgElementwiseOp<math::CosOp>;
+using WgToSgCoshOp = WgToSgElementwiseOp<math::CoshOp>;
+using WgToSgAcosOp = WgToSgElementwiseOp<math::AcosOp>;
+using WgToSgAcoshOp = WgToSgElementwiseOp<math::AcoshOp>;
+using WgToSgSinOp = WgToSgElementwiseOp<math::SinOp>;
+using WgToSgSinhOp = WgToSgElementwiseOp<math::SinhOp>;
+using WgToSgAsinOp = WgToSgElementwiseOp<math::AsinOp>;
+using WgToSgAsinhOp = WgToSgElementwiseOp<math::AsinhOp>;
+using WgToSgTanOp = WgToSgElementwiseOp<math::TanOp>;
+using WgToSgTanhOp = WgToSgElementwiseOp<math::TanhOp>;
+using WgToSgAtanOp = WgToSgElementwiseOp<math::AtanOp>;
+using WgToSgAtan2Op = WgToSgElementwiseOp<math::Atan2Op>;
+using WgToSgAtanhOp = WgToSgElementwiseOp<math::AtanhOp>;
+using WgToSgErfOp = WgToSgElementwiseOp<math::ErfOp>;
+using WgToSgLogOp = WgToSgElementwiseOp<math::LogOp>;
+using WgToSgLog2Op = WgToSgElementwiseOp<math::Log2Op>;
+using WgToSgFloorOp = WgToSgElementwiseOp<math::FloorOp>;
+using WgToSgCeilOp = WgToSgElementwiseOp<math::CeilOp>;
+using WgToSgPowFOp = WgToSgElementwiseOp<math::PowFOp>;
+using WgToSgRsqrtOp = WgToSgElementwiseOp<math::RsqrtOp>;
+using WgToSgAbsIOp = WgToSgElementwiseOp<math::AbsIOp>;
+using WgToSgCbrtOp = WgToSgElementwiseOp<math::CbrtOp>;
+using WgToSgCopySignOp = WgToSgElementwiseOp<math::CopySignOp>;
+using WgToSgCtPopOp = WgToSgElementwiseOp<math::CtPopOp>;
+using WgToSgErfcOp = WgToSgElementwiseOp<math::ErfcOp>;
+using WgToSgExp2Op = WgToSgElementwiseOp<math::Exp2Op>;
+using WgToSgExpM1Op = WgToSgElementwiseOp<math::ExpM1Op>;
+using WgToSgFPowIOp = WgToSgElementwiseOp<math::FPowIOp>;
+using WgToSgIPowIOp = WgToSgElementwiseOp<math::IPowIOp>;
+using WgToSgLog10Op = WgToSgElementwiseOp<math::Log10Op>;
+using WgToSgLog1pOp = WgToSgElementwiseOp<math::Log1pOp>;
+using WgToSgRoundOp = WgToSgElementwiseOp<math::RoundOp>;
+using WgToSgRoundEvenOp = WgToSgElementwiseOp<math::RoundEvenOp>;
+using WgToSgTruncOp = WgToSgElementwiseOp<math::TruncOp>;
+
 } // namespace
 
 namespace mlir {
@@ -322,6 +498,27 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgUpdateNdOffsetOp, WgToSgLoadNdOp,
                WgToSgStoreNdOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
       patterns.getContext());
+  // Add elementwise operations that can be distributed to subgroups
+  patterns.add<
+      WgToSgAddFOp, WgToSgSubFOp, WgToSgExpOp, WgToSgSqrtOp, WgToSgAbsFOp,
+      WgToSgCosOp, WgToSgCoshOp, WgToSgAcosOp, WgToSgAcoshOp, WgToSgSinOp,
+      WgToSgSinhOp, WgToSgAsinOp, WgToSgAsinhOp, WgToSgTanOp, WgToSgTanhOp,
+      WgToSgAtanOp, WgToSgAtan2Op, WgToSgAtanhOp, WgToSgErfOp, WgToSgLogOp,
+      WgToSgLog2Op, WgToSgFloorOp, WgToSgCeilOp, WgToSgPowFOp, WgToSgRsqrtOp,
+      WgToSgNegFOp, WgToSgAddIOp, WgToSgSubIOp, WgToSgMulFOp, WgToSgMulIOp,
+      WgToSgShLIOp, WgToSgShRSIOp, WgToSgShRUIOp, WgToSgDivFOp, WgToSgDivSIOp,
+      WgToSgDivUIOp, WgToSgMaximumFOp, WgToSgMinimumFOp, WgToSgRemSIOp,
+      WgToSgRemUIOp, WgToSgTruncFOp, WgToSgTruncIOp, WgToSgExtFOp,
+      WgToSgExtSIOp, WgToSgExtUIOp, WgToSgSIToFPOp, WgToSgUIToFPOp,
+      WgToSgFPToSIOp, WgToSgFPToUIOp, WgToSgIndexCastUIOp, WgToSgIndexCastOp,
+      WgToSgBitcastOp, WgToSgCmpIOp, WgToSgCmpFOp, WgToSgAndIOp,
+      WgToSgCeilDivSIOp, WgToSgCeilDivUIOp, WgToSgFloorDivSIOp,
+      WgToSgMaxNumFOp, WgToSgMaxSIOp, WgToSgMaxUIOp, WgToSgMinNumFOp,
+      WgToSgMinSIOp, WgToSgMinUIOp, WgToSgOrIOp, WgToSgRemFOp, WgToSgSelectOp,
+      WgToSgXOrIOp, WgToSgAbsIOp, WgToSgCbrtOp, WgToSgCopySignOp,
+      WgToSgCtPopOp, WgToSgErfcOp, WgToSgExp2Op, WgToSgExpM1Op, WgToSgFPowIOp,
+      WgToSgIPowIOp, WgToSgLog10Op, WgToSgLog1pOp, WgToSgRoundOp,
+      WgToSgRoundEvenOp, WgToSgTruncOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
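The conversion-target change in the next hunk defers to an isLegal helper that is defined earlier in this pass and is not visible in the diff. A minimal sketch of the contract that helper presumably implements, given that the patterns above strip sg_layout/sg_data from the attribute they attach, is the following (sketch only, not the patch's actual code):

```cpp
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"

// Sketch of the legality contract: an op is "legal" (needs no further
// rewriting) once its layout attribute no longer carries subgroup-level
// distribution info. The real isLegal lives elsewhere in
// XeGPUWgToSgDistribute.cpp and is shared with the xegpu-op callbacks.
static bool isLegalSketch(mlir::xegpu::LayoutAttr layout) {
  // No layout at all: nothing left to distribute.
  if (!layout)
    return true;
  // A layout without sg_layout is already at subgroup granularity.
  return !layout.getSgLayout();
}
```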
@@ -368,6 +565,32 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         auto layout =
             dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
         return isLegal(layout);
       });
+  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
+      [=](Operation *op) -> std::optional<bool> {
+        // Handle unary and binary operations
+        if (op->getNumOperands() < 1 || op->getNumOperands() > 2)
+          return true;
+
+        // check if input and output are vectors
+        VectorType resultType =
+            dyn_cast<VectorType>(op->getResult(0).getType());
+        if (!resultType || resultType.getRank() != 2)
+          return true;
+
+        // Check if all operands are vectors
+        for (Value operand : op->getOperands()) {
+          VectorType operandType = dyn_cast<VectorType>(operand.getType());
+          if (!operandType || operandType.getRank() != 2 ||
+              operandType.getShape() != resultType.getShape()) {
+            return true;
+          }
+        }
+
+        // check layout attribute
+        auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(
+            op->getAttrOfType<xegpu::LayoutAttr>("layout"));
+        return isLegal(layout);
+      });
 
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
new file mode 100644
index 0000000000000..c45312e4c2d74
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
@@ -0,0 +1,971 @@
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+
+gpu.module @test_elementwise_ops {
+  // CHECK-LABEL: test_elemwise_ops
+  gpu.func @test_elemwise_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout>
+    %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32>
+      -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout>
+    %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32>
+      -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout>
+
+    %load_a = xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout>
+      -> vector<24x32xf32>
+    %load_b = xegpu.load_nd %tdesc_b
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout>
+      -> vector<24x32xf32>
+    %load_c = xegpu.load_nd %tdesc_c
+      : !xegpu.tensor_desc<24x32xi32, #xegpu.layout>
+      -> vector<24x32xi32>
+    %load_d = xegpu.load_nd %tdesc_d
+      : !xegpu.tensor_desc<24x32xi32, #xegpu.layout>
+      -> vector<24x32xi32>
+
+    // Floating point ops
+    // CHECK: arith.addf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32>
+    // CHECK: arith.subf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32>
+    // CHECK: math.exp {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32>
+    // CHECK: math.sqrt {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32>
+    // CHECK: math.absf {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32>
+    // CHECK: math.cos {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32>
+    // CHECK: math.cosh {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32>
+    // CHECK: math.acos {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32>
+    // CHECK: math.acosh {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32>
+    // CHECK: math.sin {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32>
+    // CHECK: math.sinh {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32>
+    // CHECK: math.asin {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32>
+    // CHECK: math.asinh {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32>
+    // CHECK: math.tan {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32>
+    // CHECK: math.tanh {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32>
+    // CHECK: math.atan {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32>
+    // CHECK: math.atan2 {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32>
+    // CHECK: math.atanh {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32>
+    // CHECK: math.erf {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32>
+    // CHECK: math.log {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32>
+    // CHECK: math.log2 {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32>
+    // CHECK: math.floor {{.*}} {layout_result_0 = #xegpu.layout} : 
vector<12x8xf32> + // CHECK: math.ceil {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: math.rsqrt {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: arith.negf {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: arith.mulf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: arith.divf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: arith.maximumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: arith.minimumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: arith.addi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.subi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.muli {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.shli {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.shrsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.shrui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.divsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.divui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.remsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.remui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + %addf = arith.addf %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %subf = arith.subf %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %exp = math.exp %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %sqrt = math.sqrt %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %absf = math.absf %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %cos = math.cos %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %cosh = math.cosh %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %acos = math.acos %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %acosh = math.acosh %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %sin = math.sin %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %sinh = math.sinh %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %asin = math.asin %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %asinh = math.asinh %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %tan = math.tan %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %tanh = math.tanh %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %atan = math.atan %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %atan2 = math.atan2 %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %atanh = math.atanh %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %erf = math.erf %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %log = math.log %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %log2 = math.log2 %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %floor = math.floor %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %ceil = math.ceil %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %powf = math.powf %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %rsqrt = math.rsqrt 
%load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %negf = arith.negf %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %mulf = arith.mulf %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %divf = arith.divf %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %maximumf = arith.maximumf %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %minimumf = arith.minimumf %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + + // Integer ops + %addi = arith.addi %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %subi = arith.subi %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %muli = arith.muli %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %shli = arith.shli %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %shrsi = arith.shrsi %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %shrui = arith.shrui %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %divsi = arith.divsi %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %divui = arith.divui %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %remsi = arith.remsi %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %remui = arith.remui %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + + gpu.return + } + + // 1 to N decomposition of elementwise operations + // CHECK-LABEL: test_elemwise_ops_sg_rr_assignment + gpu.func @test_elemwise_ops_sg_rr_assignment(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) { + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> + -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32> + -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32> + -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32> + -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + + %load_a = xegpu.load_nd %tdesc_a + : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + -> vector<24x32xf32> + %load_b = xegpu.load_nd %tdesc_b + : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + -> vector<24x32xf32> + %load_c = xegpu.load_nd %tdesc_c + : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + -> vector<24x32xi32> + %load_d = xegpu.load_nd %tdesc_d + : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + -> vector<24x32xi32> + + // Floating point ops + // CHECK-COUNT-12: arith.addf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: arith.subf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: math.exp {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: math.sqrt {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: math.absf {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: math.cos {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: math.cosh {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: math.acos {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: math.acosh {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: math.sin {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: math.sinh {{.*}} {layout_result_0 = 
#xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: math.asin {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: math.asinh {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: math.tan {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: math.tanh {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: math.atan {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: math.atan2 {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: math.atanh {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: math.erf {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: math.log {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: math.log2 {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: math.floor {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: math.ceil {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: math.rsqrt {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: arith.mulf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: arith.divf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: arith.maximumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: arith.minimumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: arith.addi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-COUNT-12: arith.subi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-COUNT-12: arith.muli {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-COUNT-12: arith.shli {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-COUNT-12: arith.shrsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-COUNT-12: arith.shrui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-COUNT-12: arith.divsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-COUNT-12: arith.divui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-COUNT-12: arith.remsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-COUNT-12: arith.remui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + %addf = arith.addf %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %subf = arith.subf %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %exp = math.exp %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %sqrt = math.sqrt %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %absf = math.absf %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %cos = math.cos %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %cosh = math.cosh %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %acos = math.acos %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %acosh = math.acosh %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %sin = math.sin %load_a + {layout = 
#xegpu.layout} + : vector<24x32xf32> + %sinh = math.sinh %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %asin = math.asin %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %asinh = math.asinh %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %tan = math.tan %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %tanh = math.tanh %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %atan = math.atan %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %atan2 = math.atan2 %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %atanh = math.atanh %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %erf = math.erf %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %log = math.log %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %log2 = math.log2 %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %floor = math.floor %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %ceil = math.ceil %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %powf = math.powf %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %rsqrt = math.rsqrt %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %negf = arith.negf %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %mulf = arith.mulf %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %divf = arith.divf %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %maximumf = arith.maximumf %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %minimumf = arith.minimumf %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + + // Integer ops + %addi = arith.addi %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %subi = arith.subi %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %muli = arith.muli %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %shli = arith.shli %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %shrsi = arith.shrsi %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %shrui = arith.shrui %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %divsi = arith.divsi %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %divui = arith.divui %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %remsi = arith.remsi %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %remui = arith.remui %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + + gpu.return + } + + // CHECK-LABEL: test_all_type_conversion_ops + gpu.func @test_all_type_conversion_ops( + %a: memref<24x32xf32>, %b: memref<24x32xi32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) { + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> + -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xi32> + -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32> + -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32> + -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + + %load_a = xegpu.load_nd %tdesc_a + : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + -> vector<24x32xf32> + %load_b = xegpu.load_nd %tdesc_b + : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + -> vector<24x32xi32> + %load_c = xegpu.load_nd %tdesc_c + : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + -> vector<24x32xi32> + %load_d = 
xegpu.load_nd %tdesc_d + : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + -> vector<24x32xi32> + + // CHECK: arith.truncf {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> to vector<12x8xf16> + // CHECK: arith.trunci {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> to vector<12x8xi16> + // CHECK: arith.extf {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf16> to vector<12x8xf32> + // CHECK: arith.extsi {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi16> to vector<12x8xi32> + // CHECK: arith.extui {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi16> to vector<12x8xi32> + // CHECK: arith.sitofp {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> to vector<12x8xf32> + // CHECK: arith.uitofp {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> to vector<12x8xf32> + // CHECK: arith.fptosi {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> to vector<12x8xi32> + // CHECK: arith.fptoui {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> to vector<12x8xi32> + // CHECK: arith.index_castui {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> to vector<12x8xindex> + // CHECK: arith.index_cast {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xindex> to vector<12x8xi32> + // CHECK: arith.bitcast {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> to vector<12x8xf32> + // TruncFOp: f32 -> f16 + %truncf = arith.truncf %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> to vector<24x32xf16> + // TruncIOp: i32 -> i16 + %trunci = arith.trunci %load_b + {layout = #xegpu.layout} + : vector<24x32xi32> to vector<24x32xi16> + // ExtFOp: f16 -> f32 + %truncf16 = arith.truncf %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> to vector<24x32xf16> + %extf = arith.extf %truncf16 + {layout = #xegpu.layout} + : vector<24x32xf16> to vector<24x32xf32> + // ExtSIOp: i16 -> i32 + %extsi = arith.extsi %trunci + {layout = #xegpu.layout} + : vector<24x32xi16> to vector<24x32xi32> + // ExtUIOp: i16 -> i32 (unsigned) + %extui = arith.extui %trunci + {layout = #xegpu.layout} + : vector<24x32xi16> to vector<24x32xi32> + // SIToFPOp: i32 -> f32 + %sitofp = arith.sitofp %load_b + {layout = #xegpu.layout} + : vector<24x32xi32> to vector<24x32xf32> + // UIToFPOp: i32 -> f32 (unsigned) + %uitofp = arith.uitofp %load_b + {layout = #xegpu.layout} + : vector<24x32xi32> to vector<24x32xf32> + // FPToSIOp: f32 -> i32 + %fptosi = arith.fptosi %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> to vector<24x32xi32> + // FPToUIOp: f32 -> i32 (unsigned) + %fptoui = arith.fptoui %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> to vector<24x32xi32> + // IndexCastUIOp: i32 -> index + %indexcastui = arith.index_castui %load_b + {layout = #xegpu.layout} + : vector<24x32xi32> to vector<24x32xindex> + // IndexCastOp: index -> i32 + %indexcast = arith.index_cast %indexcastui + {layout = #xegpu.layout} + : vector<24x32xindex> to vector<24x32xi32> + // BitcastOp: i32 -> f32 + %bitcast = arith.bitcast %load_b + {layout = #xegpu.layout} + : vector<24x32xi32> to vector<24x32xf32> + gpu.return + } + + + // CHECK-LABEL: gpu.func @test_all_type_conversion_ops_rr_assignment + gpu.func @test_all_type_conversion_ops_rr_assignment( + %a: memref<24x32xf32>, %b: memref<24x32xi32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) { + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> + -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xi32> + -> !xegpu.tensor_desc<24x32xi32, 
#xegpu.layout> + %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32> + -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32> + -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + + %load_a = xegpu.load_nd %tdesc_a + : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + -> vector<24x32xf32> + %load_b = xegpu.load_nd %tdesc_b + : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + -> vector<24x32xi32> + %load_c = xegpu.load_nd %tdesc_c + : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + -> vector<24x32xi32> + %load_d = xegpu.load_nd %tdesc_d + : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + -> vector<24x32xi32> + + // CHECK-COUNT-12: arith.truncf {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> to vector<2x2xf16> + // CHECK-COUNT-12: arith.trunci {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> to vector<2x2xi16> + // CHECK-COUNT-12: arith.extf {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf16> to vector<2x2xf32> + // CHECK-COUNT-12: arith.extsi {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi16> to vector<2x2xi32> + // CHECK-COUNT-12: arith.extui {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi16> to vector<2x2xi32> + // CHECK-COUNT-12: arith.sitofp {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> to vector<2x2xf32> + // CHECK-COUNT-12: arith.uitofp {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> to vector<2x2xf32> + // CHECK-COUNT-12: arith.fptosi {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> to vector<2x2xi32> + // CHECK-COUNT-12: arith.fptoui {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> to vector<2x2xi32> + // CHECK-COUNT-12: arith.index_castui {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> to vector<2x2xindex> + // CHECK-COUNT-12: arith.index_cast {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xindex> to vector<2x2xi32> + // CHECK-COUNT-12: arith.bitcast {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> to vector<2x2xf32> + // TruncFOp: f32 -> f16 + %truncf = arith.truncf %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> to vector<24x32xf16> + // TruncIOp: i32 -> i16 + %trunci = arith.trunci %load_b + {layout = #xegpu.layout} + : vector<24x32xi32> to vector<24x32xi16> + // ExtFOp: f16 -> f32 + %truncf16 = arith.truncf %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> to vector<24x32xf16> + %extf = arith.extf %truncf16 + {layout = #xegpu.layout} + : vector<24x32xf16> to vector<24x32xf32> + // ExtSIOp: i16 -> i32 + %extsi = arith.extsi %trunci + {layout = #xegpu.layout} + : vector<24x32xi16> to vector<24x32xi32> + // ExtUIOp: i16 -> i32 (unsigned) + %extui = arith.extui %trunci + {layout = #xegpu.layout} + : vector<24x32xi16> to vector<24x32xi32> + // SIToFPOp: i32 -> f32 + %sitofp = arith.sitofp %load_b + {layout = #xegpu.layout} + : vector<24x32xi32> to vector<24x32xf32> + // UIToFPOp: i32 -> f32 (unsigned) + %uitofp = arith.uitofp %load_b + {layout = #xegpu.layout} + : vector<24x32xi32> to vector<24x32xf32> + // FPToSIOp: f32 -> i32 + %fptosi = arith.fptosi %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> to vector<24x32xi32> + // FPToUIOp: f32 -> i32 (unsigned) + %fptoui = arith.fptoui %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> to vector<24x32xi32> + // IndexCastUIOp: i32 -> index + %indexcastui = arith.index_castui %load_b + {layout = #xegpu.layout} + : vector<24x32xi32> to vector<24x32xindex> + // IndexCastOp: index -> i32 + %indexcast = arith.index_cast %indexcastui + {layout 
= #xegpu.layout} + : vector<24x32xindex> to vector<24x32xi32> + // BitcastOp: i32 -> f32 + %bitcast = arith.bitcast %load_b + {layout = #xegpu.layout} + : vector<24x32xi32> to vector<24x32xf32> + gpu.return + } + + // CHECK-LABEL: gpu.func @test_cmp_ops + gpu.func @test_cmp_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) { + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> + -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32> + -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32> + -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32> + -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + + %load_a = xegpu.load_nd %tdesc_a + : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + -> vector<24x32xf32> + %load_b = xegpu.load_nd %tdesc_b + : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + -> vector<24x32xf32> + %load_c = xegpu.load_nd %tdesc_c + : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + -> vector<24x32xi32> + %load_d = xegpu.load_nd %tdesc_d + : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + -> vector<24x32xi32> + + // CHECK: arith.cmpi eq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.cmpi ne, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.cmpi slt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.cmpi sle, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.cmpi sgt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.cmpi sge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.cmpi ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.cmpi ule, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.cmpi ugt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.cmpi uge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.cmpf oeq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: arith.cmpf ogt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: arith.cmpf oge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: arith.cmpf olt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: arith.cmpf ole, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: arith.cmpf one, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: arith.cmpf ord, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: arith.cmpf ueq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: arith.cmpf ugt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: arith.cmpf uge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: arith.cmpf ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: arith.cmpf ule, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: arith.cmpf une, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: arith.cmpf uno, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : 
vector<12x8xf32> + // Integer comparisons + %cmpi_eq = arith.cmpi eq, %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %cmpi_ne = arith.cmpi ne, %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %cmpi_slt = arith.cmpi slt, %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %cmpi_sle = arith.cmpi sle, %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %cmpi_sgt = arith.cmpi sgt, %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %cmpi_sge = arith.cmpi sge, %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %cmpi_ult = arith.cmpi ult, %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %cmpi_ule = arith.cmpi ule, %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %cmpi_ugt = arith.cmpi ugt, %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %cmpi_uge = arith.cmpi uge, %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + + // Floating point comparisons + %cmpf_oeq = arith.cmpf oeq, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_ogt = arith.cmpf ogt, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_oge = arith.cmpf oge, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_olt = arith.cmpf olt, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_ole = arith.cmpf ole, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_one = arith.cmpf one, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_ord = arith.cmpf ord, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_ueq = arith.cmpf ueq, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_ugt = arith.cmpf ugt, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_uge = arith.cmpf uge, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_ult = arith.cmpf ult, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_ule = arith.cmpf ule, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_une = arith.cmpf une, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_uno = arith.cmpf uno, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + gpu.return + } + + // CHECK-LABEL: gpu.func @test_cmp_ops_rr_assignment + gpu.func @test_cmp_ops_rr_assignment(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) { + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> + -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32> + -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32> + -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32> + -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + + %load_a = xegpu.load_nd %tdesc_a + : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + -> vector<24x32xf32> + %load_b = xegpu.load_nd %tdesc_b + : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + -> vector<24x32xf32> + %load_c = xegpu.load_nd %tdesc_c + : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + -> vector<24x32xi32> + %load_d = xegpu.load_nd %tdesc_d + : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + -> vector<24x32xi32> + + // CHECK-COUNT-12: arith.cmpi eq, {{.*}}, {{.*}} {layout_result_0 = 
#xegpu.layout} : vector<2x2xi32> + // CHECK-COUNT-12: arith.cmpi ne, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-COUNT-12: arith.cmpi slt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-COUNT-12: arith.cmpi sle, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-COUNT-12: arith.cmpi sgt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-COUNT-12: arith.cmpi sge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-COUNT-12: arith.cmpi ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-COUNT-12: arith.cmpi ule, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-COUNT-12: arith.cmpi ugt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-COUNT-12: arith.cmpi uge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // Floating point comparisons + // CHECK-COUNT-12: arith.cmpf oeq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: arith.cmpf ogt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: arith.cmpf oge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: arith.cmpf olt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: arith.cmpf ole, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: arith.cmpf one, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: arith.cmpf ord, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: arith.cmpf ueq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: arith.cmpf ugt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: arith.cmpf uge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: arith.cmpf ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: arith.cmpf ule, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: arith.cmpf une, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-COUNT-12: arith.cmpf uno, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + + // Integer comparisons + %cmpi_eq = arith.cmpi eq, %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %cmpi_ne = arith.cmpi ne, %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %cmpi_slt = arith.cmpi slt, %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %cmpi_sle = arith.cmpi sle, %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %cmpi_sgt = arith.cmpi sgt, %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %cmpi_sge = arith.cmpi sge, %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %cmpi_ult = arith.cmpi ult, %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %cmpi_ule = arith.cmpi ule, %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %cmpi_ugt = arith.cmpi ugt, %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %cmpi_uge = arith.cmpi uge, %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + + // Floating point comparisons + %cmpf_oeq = arith.cmpf oeq, %load_a, %load_b + {layout = #xegpu.layout} + : 
vector<24x32xf32> + %cmpf_ogt = arith.cmpf ogt, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_oge = arith.cmpf oge, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_olt = arith.cmpf olt, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_ole = arith.cmpf ole, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_one = arith.cmpf one, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_ord = arith.cmpf ord, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_ueq = arith.cmpf ueq, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_ugt = arith.cmpf ugt, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_uge = arith.cmpf uge, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_ult = arith.cmpf ult, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_ule = arith.cmpf ule, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_une = arith.cmpf une, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf_uno = arith.cmpf uno, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + + gpu.return + } + + gpu.func @test_extra_elemwise_ops( + %a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>, %e: memref<24x32xi1>) { + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> + -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32> + -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32> + -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32> + -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + %tdesc_e = xegpu.create_nd_tdesc %e[0, 0] : memref<24x32xi1> + -> !xegpu.tensor_desc<24x32xi1, #xegpu.layout> + + %load_a = xegpu.load_nd %tdesc_a + : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + -> vector<24x32xf32> + %load_b = xegpu.load_nd %tdesc_b + : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + -> vector<24x32xf32> + %load_c = xegpu.load_nd %tdesc_c + : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + -> vector<24x32xi32> + %load_d = xegpu.load_nd %tdesc_d + : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + -> vector<24x32xi32> + %load_e = xegpu.load_nd %tdesc_e + : !xegpu.tensor_desc<24x32xi1, #xegpu.layout> + -> vector<24x32xi1> + + // CHECK: arith.andi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.ori {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.xori {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.ceildivsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.ceildivui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.floordivsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.maxnumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: arith.maxsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.maxui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.minnumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: arith.minsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : 
vector<12x8xi32> + // CHECK: arith.minui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: arith.remf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: arith.cmpf ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: math.absi {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: math.cbrt {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: math.copysign {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: math.ctpop {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: math.erfc {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: math.exp2 {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: math.expm1 {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: math.fpowi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32>, vector<12x8xi32> + // CHECK: math.ipowi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> + // CHECK: math.log10 {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: math.log1p {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: math.round {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: math.roundeven {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: math.trunc {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // arith ops + %andi = arith.andi %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %ori = arith.ori %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %xori = arith.xori %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %ceildivsi = arith.ceildivsi %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %ceildivui = arith.ceildivui %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %floordivsi = arith.floordivsi %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %maxnumf = arith.maxnumf %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %maxsi = arith.maxsi %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %maxui = arith.maxui %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %minnumf = arith.minnumf %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %minsi = arith.minsi %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %minui = arith.minui %load_c, %load_d + {layout = #xegpu.layout} + : vector<24x32xi32> + %remf = arith.remf %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %cmpf = arith.cmpf ult, %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + + // math ops + %absi = math.absi %load_c + {layout = #xegpu.layout} + : vector<24x32xi32> + %cbrt = math.cbrt %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %copysign = math.copysign %load_a, %load_b + {layout = #xegpu.layout} + : vector<24x32xf32> + %ctpop = math.ctpop %load_c + {layout = #xegpu.layout} + : vector<24x32xi32> + %erfc = math.erfc %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %exp2 = math.exp2 %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %expm1 = math.expm1 %load_a + {layout = #xegpu.layout} + : vector<24x32xf32> + %fpowi = math.fpowi %load_a, %load_c + {layout = #xegpu.layout} + : vector<24x32xf32>, vector<24x32xi32> + %ipowi = math.ipowi %load_c, %load_d + {layout = #xegpu.layout} + : 
vector<24x32xi32>
+    %log10 = math.log10 %load_a
+      {layout = #xegpu.layout}
+      : vector<24x32xf32>
+    %log1p = math.log1p %load_a
+      {layout = #xegpu.layout}
+      : vector<24x32xf32>
+    %round = math.round %load_a
+      {layout = #xegpu.layout}
+      : vector<24x32xf32>
+    %roundeven = math.roundeven %load_a
+      {layout = #xegpu.layout}
+      : vector<24x32xf32>
+    %trunc = math.trunc %load_a
+      {layout = #xegpu.layout}
+      : vector<24x32xf32>
+    gpu.return
+  }
+}

From e215e22faeb6e1e8123ebedf914f17053d3d3851 Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Fri, 6 Jun 2025 15:17:08 +0000
Subject: [PATCH 02/13] Clean up

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      |   6 +-
 .../XeGPU/xegpu-wg-to-sg-elemwise.mlir        | 105 +++++++++++++++---
 2 files changed, 92 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 972394a7b40ad..771642f1a34e9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -317,8 +317,7 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
-// This pattern matches elementwise ops (unary/binary) in math/arith dialects
-// with 1D or 2D vector types
+// This pattern transforms elementwise ops (unary/binary) in math/arith dialects
 template <typename Op>
 struct WgToSgElementwiseOp : public OpConversionPattern<Op> {
   using OpConversionPattern<Op>::OpConversionPattern;
   using OneToNOpAdaptor = typename OpConversionPattern<Op>::OneToNOpAdaptor;
@@ -344,7 +343,6 @@ struct WgToSgElementwiseOp : public OpConversionPattern<Op> {
       }
     }
 
-    // Check for layout attribute with sgLayout
     auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
     if (!layout || !layout.getSgLayout())
       return rewriter.notifyMatchFailure(
           op, "Operation does not have a valid layout attribute for subgroup "
               "distribution");
@@ -364,7 +362,6 @@ struct WgToSgElementwiseOp : public OpConversionPattern<Op> {
       }
     }
 
-    // Each operand is a list of values
     size_t numVariants = adaptor.getOperands().empty()
                              ? 0
                              : adaptor.getOperands().front().size();
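Every op the pattern emits carries layout_result_0 = layout.dropSgLayoutAndData(). That accessor belongs to the XeGPU dialect, not to this patch; as a reading aid for the CHECK lines, its observable behavior on the attributes used here amounts to the following stand-in (hypothetical field names, for illustration only):

```cpp
#include <array>
#include <cstdint>
#include <optional>

// Illustration of LayoutAttr::dropSgLayoutAndData() semantics as the tests
// rely on them: the subgroup-level fields are cleared and everything else
// survives. This struct is a stand-in, not the real attribute storage.
struct LayoutFields {
  std::optional<std::array<int64_t, 2>> sgLayout; // e.g. [2, 4]
  std::optional<std::array<int64_t, 2>> sgData;   // e.g. [12, 8]
  std::optional<std::array<int64_t, 2>> laneLayout;
  std::optional<std::array<int64_t, 2>> laneData;
};

LayoutFields dropSgLayoutAndData(LayoutFields l) {
  l.sgLayout.reset();
  l.sgData.reset();
  return l; // whatever remains becomes the layout_result_0 attribute
}
```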
@@ -586,7 +583,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
           }
         }
 
-        // check layout attribute
         auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(
             op->getAttrOfType<xegpu::LayoutAttr>("layout"));
         return isLegal(layout);
       });
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
index c45312e4c2d74..85767f4f2bd67 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
@@ -1,8 +1,8 @@
 // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
 
-gpu.module @test_elementwise_ops {
-  // CHECK-LABEL: test_elemwise_ops
-  gpu.func @test_elemwise_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
+gpu.module @elementwise_ops {
+  // CHECK-LABEL: elemwise_ops
+  gpu.func @elemwise_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
     %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
       -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout>
     %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
@@ -193,8 +193,8 @@ gpu.module @test_elementwise_ops {
   }
 
   // 1 to N decomposition of elementwise operations
-  // CHECK-LABEL: test_elemwise_ops_sg_rr_assignment
-  gpu.func @test_elemwise_ops_sg_rr_assignment(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
+  // CHECK-LABEL: elemwise_ops_rr_assignment
+  gpu.func @elemwise_ops_rr_assignment(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
     %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
       -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout>
     %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
@@ -219,45 +219,85 @@ gpu.module @test_elementwise_ops {
 
     // Floating point ops
     // CHECK-COUNT-12: arith.addf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32>
+    // CHECK-NOT: arith.addf
     // CHECK-COUNT-12: arith.subf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32>
+    // CHECK-NOT: arith.subf
     // CHECK-COUNT-12: math.exp {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32>
+    // CHECK-NOT: math.exp
     // CHECK-COUNT-12: math.sqrt {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32>
+    // CHECK-NOT: math.sqrt
     // CHECK-COUNT-12: math.absf {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32>
+    // CHECK-NOT: math.absf
     // CHECK-COUNT-12: math.cos {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32>
+    // CHECK-NOT: math.cos
     // CHECK-COUNT-12: math.cosh {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32>
+    // CHECK-NOT: math.cosh
     // CHECK-COUNT-12: math.acos {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32>
+    // CHECK-NOT: math.acos
     // CHECK-COUNT-12: math.acosh {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32>
+    // CHECK-NOT: math.acosh
     // CHECK-COUNT-12: math.sin {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32>
+    // CHECK-NOT: math.sin
     // CHECK-COUNT-12: math.sinh {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32>
+    // CHECK-NOT: math.sinh
     // CHECK-COUNT-12: math.asin {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32>
+    // CHECK-NOT: math.asin
     // CHECK-COUNT-12: math.asinh {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32>
+    // CHECK-NOT: math.asinh
     // CHECK-COUNT-12: math.tan {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32>
+    // CHECK-NOT: math.tan
     // CHECK-COUNT-12: math.tanh {{.*}} {layout_result_0 = 
#xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: math.tanh // CHECK-COUNT-12: math.atan {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: math.atan // CHECK-COUNT-12: math.atan2 {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: math.atan2 // CHECK-COUNT-12: math.atanh {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: math.atanh // CHECK-COUNT-12: math.erf {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: math.erf // CHECK-COUNT-12: math.log {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: math.log // CHECK-COUNT-12: math.log2 {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: math.log2 // CHECK-COUNT-12: math.floor {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: math.floor // CHECK-COUNT-12: math.ceil {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: math.ceil // CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: math.powf // CHECK-COUNT-12: math.rsqrt {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: math.rsqrt // CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: arith.negf // CHECK-COUNT-12: arith.mulf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: arith.mulf // CHECK-COUNT-12: arith.divf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: arith.divf // CHECK-COUNT-12: arith.maximumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: arith.maximumf // CHECK-COUNT-12: arith.minimumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: arith.minimumf // CHECK-COUNT-12: arith.addi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-NOT: arith.addi // CHECK-COUNT-12: arith.subi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-NOT: arith.subi // CHECK-COUNT-12: arith.muli {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-NOT: arith.muli // CHECK-COUNT-12: arith.shli {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-NOT: arith.shli // CHECK-COUNT-12: arith.shrsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-NOT: arith.shrsi // CHECK-COUNT-12: arith.shrui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-NOT: arith.shrui // CHECK-COUNT-12: arith.divsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-NOT: arith.divsi // CHECK-COUNT-12: arith.divui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-NOT: arith.divui // CHECK-COUNT-12: arith.remsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-NOT: arith.remsi // CHECK-COUNT-12: arith.remui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-NOT: arith.remui %addf = arith.addf %load_a, %load_b {layout = #xegpu.layout} : vector<24x32xf32> @@ -384,8 +424,8 @@ gpu.module @test_elementwise_ops { gpu.return } - // CHECK-LABEL: test_all_type_conversion_ops - gpu.func @test_all_type_conversion_ops( + // CHECK-LABEL: type_conversion_ops + gpu.func @type_conversion_ops( %a: memref<24x32xf32>, %b: memref<24x32xi32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) { %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : 
memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> @@ -476,8 +516,8 @@ gpu.module @test_elementwise_ops { } - // CHECK-LABEL: gpu.func @test_all_type_conversion_ops_rr_assignment - gpu.func @test_all_type_conversion_ops_rr_assignment( + // CHECK-LABEL: gpu.func @type_conversion_ops_rr_assignment + gpu.func @type_conversion_ops_rr_assignment( %a: memref<24x32xf32>, %b: memref<24x32xi32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) { %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> @@ -502,17 +542,29 @@ gpu.module @test_elementwise_ops { -> vector<24x32xi32> // CHECK-COUNT-12: arith.truncf {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> to vector<2x2xf16> + // CHECK-NOT: arith.truncf // CHECK-COUNT-12: arith.trunci {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> to vector<2x2xi16> + // CHECK-NOT: arith.trunci // CHECK-COUNT-12: arith.extf {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf16> to vector<2x2xf32> + // CHECK-NOT: arith.extf // CHECK-COUNT-12: arith.extsi {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi16> to vector<2x2xi32> + // CHECK-NOT: arith.extsi // CHECK-COUNT-12: arith.extui {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi16> to vector<2x2xi32> + // CHECK-NOT: arith.extui // CHECK-COUNT-12: arith.sitofp {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> to vector<2x2xf32> + // CHECK-NOT: arith.sitofp // CHECK-COUNT-12: arith.uitofp {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> to vector<2x2xf32> + // CHECK-NOT: arith.uitofp // CHECK-COUNT-12: arith.fptosi {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> to vector<2x2xi32> + // CHECK-NOT: arith.fptosi // CHECK-COUNT-12: arith.fptoui {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> to vector<2x2xi32> + // CHECK-NOT: arith.fptoui // CHECK-COUNT-12: arith.index_castui {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> to vector<2x2xindex> + // CHECK-NOT: arith.index_castui // CHECK-COUNT-12: arith.index_cast {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xindex> to vector<2x2xi32> + // CHECK-NOT: arith.index_cast // CHECK-COUNT-12: arith.bitcast {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> to vector<2x2xf32> + // CHECK-NOT: arith.bitcast // TruncFOp: f32 -> f16 %truncf = arith.truncf %load_a {layout = #xegpu.layout} @@ -567,8 +619,8 @@ gpu.module @test_elementwise_ops { gpu.return } - // CHECK-LABEL: gpu.func @test_cmp_ops - gpu.func @test_cmp_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) { + // CHECK-LABEL: gpu.func @comparison_ops + gpu.func @comparison_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) { %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32> @@ -693,8 +745,8 @@ gpu.module @test_elementwise_ops { gpu.return } - // CHECK-LABEL: gpu.func @test_cmp_ops_rr_assignment - gpu.func @test_cmp_ops_rr_assignment(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) { + // CHECK-LABEL: gpu.func @comparison_ops_rr_assignment + gpu.func @comparison_ops_rr_assignment(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) { %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> %tdesc_b = 
xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32> @@ -718,30 +770,54 @@ gpu.module @test_elementwise_ops { -> vector<24x32xi32> // CHECK-COUNT-12: arith.cmpi eq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-NOT: arith.cmpi eq // CHECK-COUNT-12: arith.cmpi ne, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-NOT: arith.cmpi ne // CHECK-COUNT-12: arith.cmpi slt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-NOT: arith.cmpi slt // CHECK-COUNT-12: arith.cmpi sle, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-NOT: arith.cmpi sle // CHECK-COUNT-12: arith.cmpi sgt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-NOT: arith.cmpi sgt // CHECK-COUNT-12: arith.cmpi sge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-NOT: arith.cmpi sge // CHECK-COUNT-12: arith.cmpi ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-NOT: arith.cmpi ult // CHECK-COUNT-12: arith.cmpi ule, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-NOT: arith.cmpi ule // CHECK-COUNT-12: arith.cmpi ugt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-NOT: arith.cmpi ugt // CHECK-COUNT-12: arith.cmpi uge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> + // CHECK-NOT: arith.cmpi uge // Floating point comparisons // CHECK-COUNT-12: arith.cmpf oeq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: arith.cmpf oeq // CHECK-COUNT-12: arith.cmpf ogt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: arith.cmpf ogt // CHECK-COUNT-12: arith.cmpf oge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: arith.cmpf oge // CHECK-COUNT-12: arith.cmpf olt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: arith.cmpf olt // CHECK-COUNT-12: arith.cmpf ole, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: arith.cmpf ole // CHECK-COUNT-12: arith.cmpf one, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: arith.cmpf one // CHECK-COUNT-12: arith.cmpf ord, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: arith.cmpf ord // CHECK-COUNT-12: arith.cmpf ueq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: arith.cmpf ueq // CHECK-COUNT-12: arith.cmpf ugt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: arith.cmpf ugt // CHECK-COUNT-12: arith.cmpf uge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: arith.cmpf uge // CHECK-COUNT-12: arith.cmpf ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: arith.cmpf ult // CHECK-COUNT-12: arith.cmpf ule, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: arith.cmpf ule // CHECK-COUNT-12: arith.cmpf une, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: arith.cmpf une // CHECK-COUNT-12: arith.cmpf uno, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> + // CHECK-NOT: arith.cmpf uno // Integer comparisons %cmpi_eq = arith.cmpi eq, %load_c, %load_d @@ -822,7 +898,8 @@ gpu.module @test_elementwise_ops { gpu.return } - gpu.func @test_extra_elemwise_ops( + // CHECK-LABEL: gpu.func 
@elementwise_ops
+  gpu.func @elementwise_ops(
     %a: memref<24x32xf32>, %b: memref<24x32xf32>,
     %c: memref<24x32xi32>, %d: memref<24x32xi32>, %e: memref<24x32xi1>) {
     %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
       -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout>

From 6960b5cfead392b8d3b1731498a014f450a58adb Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Fri, 6 Jun 2025 17:00:36 +0000
Subject: [PATCH 03/13] Refine
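Drop the per-op type aliases and register the WgToSgElementwiseOp
instantiations with the pattern set directly; the aliases only renamed the
template instantiations and added roughly ninety lines of indirection.
RewritePatternSet::add is variadic over pattern types, so the whole family
still goes in with a single call. A reduced sketch of the idiom (the op
choice here is illustrative, not the full list):

    RewritePatternSet patterns(ctx);
    // Each template instantiation is itself a pattern type, so it can be
    // passed to the variadic add<> directly, without a named alias.
    patterns.add<WgToSgElementwiseOp<arith::AddFOp>,
                 WgToSgElementwiseOp<arith::SubFOp>,
                 WgToSgElementwiseOp<math::ExpOp>>(patterns.getContext());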
---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 154 ++++++------------
 1 file changed, 49 insertions(+), 105 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 771642f1a34e9..e1687031d259a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -401,92 +401,6 @@ struct WgToSgElementwiseOp : public OpConversionPattern<Op> {
   }
 };
 
-// ---- ARITH ops ----
-using WgToSgAddFOp = WgToSgElementwiseOp<arith::AddFOp>;
-using WgToSgSubFOp = WgToSgElementwiseOp<arith::SubFOp>;
-using WgToSgNegFOp = WgToSgElementwiseOp<arith::NegFOp>;
-using WgToSgAddIOp = WgToSgElementwiseOp<arith::AddIOp>;
-using WgToSgSubIOp = WgToSgElementwiseOp<arith::SubIOp>;
-using WgToSgMulFOp = WgToSgElementwiseOp<arith::MulFOp>;
-using WgToSgMulIOp = WgToSgElementwiseOp<arith::MulIOp>;
-using WgToSgShLIOp = WgToSgElementwiseOp<arith::ShLIOp>;
-using WgToSgShRSIOp = WgToSgElementwiseOp<arith::ShRSIOp>;
-using WgToSgShRUIOp = WgToSgElementwiseOp<arith::ShRUIOp>;
-using WgToSgDivFOp = WgToSgElementwiseOp<arith::DivFOp>;
-using WgToSgDivSIOp = WgToSgElementwiseOp<arith::DivSIOp>;
-using WgToSgDivUIOp = WgToSgElementwiseOp<arith::DivUIOp>;
-using WgToSgMaximumFOp = WgToSgElementwiseOp<arith::MaximumFOp>;
-using WgToSgMinimumFOp = WgToSgElementwiseOp<arith::MinimumFOp>;
-using WgToSgRemSIOp = WgToSgElementwiseOp<arith::RemSIOp>;
-using WgToSgRemUIOp = WgToSgElementwiseOp<arith::RemUIOp>;
-using WgToSgTruncFOp = WgToSgElementwiseOp<arith::TruncFOp>;
-using WgToSgTruncIOp = WgToSgElementwiseOp<arith::TruncIOp>;
-using WgToSgExtFOp = WgToSgElementwiseOp<arith::ExtFOp>;
-using WgToSgExtSIOp = WgToSgElementwiseOp<arith::ExtSIOp>;
-using WgToSgExtUIOp = WgToSgElementwiseOp<arith::ExtUIOp>;
-using WgToSgSIToFPOp = WgToSgElementwiseOp<arith::SIToFPOp>;
-using WgToSgUIToFPOp = WgToSgElementwiseOp<arith::UIToFPOp>;
-using WgToSgFPToSIOp = WgToSgElementwiseOp<arith::FPToSIOp>;
-using WgToSgFPToUIOp = WgToSgElementwiseOp<arith::FPToUIOp>;
-using WgToSgIndexCastUIOp = WgToSgElementwiseOp<arith::IndexCastUIOp>;
-using WgToSgIndexCastOp = WgToSgElementwiseOp<arith::IndexCastOp>;
-using WgToSgBitcastOp = WgToSgElementwiseOp<arith::BitcastOp>;
-using WgToSgCmpIOp = WgToSgElementwiseOp<arith::CmpIOp>;
-using WgToSgCmpFOp = WgToSgElementwiseOp<arith::CmpFOp>;
-using WgToSgAndIOp = WgToSgElementwiseOp<arith::AndIOp>;
-using WgToSgCeilDivSIOp = WgToSgElementwiseOp<arith::CeilDivSIOp>;
-using WgToSgCeilDivUIOp = WgToSgElementwiseOp<arith::CeilDivUIOp>;
-using WgToSgFloorDivSIOp = WgToSgElementwiseOp<arith::FloorDivSIOp>;
-using WgToSgMaxNumFOp = WgToSgElementwiseOp<arith::MaxNumFOp>;
-using WgToSgMaxSIOp = WgToSgElementwiseOp<arith::MaxSIOp>;
-using WgToSgMaxUIOp = WgToSgElementwiseOp<arith::MaxUIOp>;
-using WgToSgMinNumFOp = WgToSgElementwiseOp<arith::MinNumFOp>;
-using WgToSgMinSIOp = WgToSgElementwiseOp<arith::MinSIOp>;
-using WgToSgMinUIOp = WgToSgElementwiseOp<arith::MinUIOp>;
-using WgToSgOrIOp = WgToSgElementwiseOp<arith::OrIOp>;
-using WgToSgRemFOp = WgToSgElementwiseOp<arith::RemFOp>;
-using WgToSgSelectOp = WgToSgElementwiseOp<arith::SelectOp>;
-using WgToSgXOrIOp = WgToSgElementwiseOp<arith::XOrIOp>;
-
-// ---- MATH ops ----
-using WgToSgExpOp = WgToSgElementwiseOp<math::ExpOp>;
-using WgToSgSqrtOp = WgToSgElementwiseOp<math::SqrtOp>;
-using WgToSgAbsFOp = WgToSgElementwiseOp<math::AbsFOp>;
-using WgToSgCosOp = WgToSgElementwiseOp<math::CosOp>;
-using WgToSgCoshOp = WgToSgElementwiseOp<math::CoshOp>;
-using WgToSgAcosOp = WgToSgElementwiseOp<math::AcosOp>;
-using WgToSgAcoshOp = WgToSgElementwiseOp<math::AcoshOp>;
-using WgToSgSinOp = WgToSgElementwiseOp<math::SinOp>;
-using WgToSgSinhOp = WgToSgElementwiseOp<math::SinhOp>;
-using WgToSgAsinOp = WgToSgElementwiseOp<math::AsinOp>;
-using WgToSgAsinhOp = WgToSgElementwiseOp<math::AsinhOp>;
-using WgToSgTanOp = WgToSgElementwiseOp<math::TanOp>;
-using WgToSgTanhOp = WgToSgElementwiseOp<math::TanhOp>;
-using WgToSgAtanOp = WgToSgElementwiseOp<math::AtanOp>;
-using WgToSgAtan2Op = WgToSgElementwiseOp<math::Atan2Op>;
-using WgToSgAtanhOp = WgToSgElementwiseOp<math::AtanhOp>;
-using WgToSgErfOp = WgToSgElementwiseOp<math::ErfOp>;
-using WgToSgLogOp = WgToSgElementwiseOp<math::LogOp>;
-using WgToSgLog2Op = WgToSgElementwiseOp<math::Log2Op>;
-using WgToSgFloorOp = WgToSgElementwiseOp<math::FloorOp>;
-using WgToSgCeilOp = WgToSgElementwiseOp<math::CeilOp>;
-using WgToSgPowFOp = WgToSgElementwiseOp<math::PowFOp>;
-using WgToSgRsqrtOp = WgToSgElementwiseOp<math::RsqrtOp>;
-using WgToSgAbsIOp = WgToSgElementwiseOp<math::AbsIOp>;
-using WgToSgCbrtOp = WgToSgElementwiseOp<math::CbrtOp>;
-using WgToSgCopySignOp = WgToSgElementwiseOp<math::CopySignOp>;
-using WgToSgCtPopOp = WgToSgElementwiseOp<math::CtPopOp>;
-using WgToSgErfcOp = WgToSgElementwiseOp<math::ErfcOp>;
-using WgToSgExp2Op = WgToSgElementwiseOp<math::Exp2Op>;
-using WgToSgExpM1Op = WgToSgElementwiseOp<math::ExpM1Op>;
-using WgToSgFPowIOp = WgToSgElementwiseOp<math::FPowIOp>;
-using WgToSgIPowIOp = WgToSgElementwiseOp<math::IPowIOp>;
-using WgToSgLog10Op = WgToSgElementwiseOp<math::Log10Op>;
-using WgToSgLog1pOp = WgToSgElementwiseOp<math::Log1pOp>;
-using WgToSgRoundOp = WgToSgElementwiseOp<math::RoundOp>;
-using WgToSgRoundEvenOp = WgToSgElementwiseOp<math::RoundEvenOp>;
-using WgToSgTruncOp = WgToSgElementwiseOp<math::TruncOp>;
-
 } // namespace
 
 namespace mlir {
@@ -497,25 +411,55 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
       patterns.getContext());
   // Add elementwise operations that can be distributed to subgroups
   patterns.add<
-      WgToSgAddFOp, WgToSgSubFOp, WgToSgExpOp, WgToSgSqrtOp, WgToSgAbsFOp,
-      WgToSgCosOp, WgToSgCoshOp, WgToSgAcosOp, WgToSgAcoshOp, WgToSgSinOp,
-      WgToSgSinhOp, WgToSgAsinOp, WgToSgAsinhOp, WgToSgTanOp, WgToSgTanhOp,
-      WgToSgAtanOp, WgToSgAtan2Op, WgToSgAtanhOp, WgToSgErfOp, WgToSgLogOp,
-      WgToSgLog2Op, WgToSgFloorOp, WgToSgCeilOp, WgToSgPowFOp, WgToSgRsqrtOp,
-      WgToSgNegFOp, WgToSgAddIOp, WgToSgSubIOp, WgToSgMulFOp, WgToSgMulIOp,
-      WgToSgShLIOp, WgToSgShRSIOp, WgToSgShRUIOp, WgToSgDivFOp, WgToSgDivSIOp,
-      WgToSgDivUIOp, WgToSgMaximumFOp, WgToSgMinimumFOp, WgToSgRemSIOp,
-      WgToSgRemUIOp, WgToSgTruncFOp, WgToSgTruncIOp, WgToSgExtFOp,
-      WgToSgExtSIOp, WgToSgExtUIOp, WgToSgSIToFPOp, WgToSgUIToFPOp,
-      WgToSgFPToSIOp, WgToSgFPToUIOp, WgToSgIndexCastUIOp, WgToSgIndexCastOp,
-      WgToSgBitcastOp, WgToSgCmpIOp, WgToSgCmpFOp, WgToSgAndIOp,
-      WgToSgCeilDivSIOp, WgToSgCeilDivUIOp, WgToSgFloorDivSIOp, WgToSgMaxNumFOp,
-      WgToSgMaxSIOp, WgToSgMaxUIOp, WgToSgMinNumFOp, WgToSgMinSIOp,
-      WgToSgMinUIOp, WgToSgOrIOp, WgToSgRemFOp, WgToSgSelectOp, WgToSgXOrIOp,
-      WgToSgAbsIOp, WgToSgCbrtOp, WgToSgCopySignOp, WgToSgCtPopOp, WgToSgErfcOp,
-      WgToSgExp2Op, WgToSgExpM1Op, WgToSgFPowIOp, WgToSgIPowIOp, WgToSgLog10Op,
-      WgToSgLog1pOp, WgToSgRoundOp, WgToSgRoundEvenOp, WgToSgTruncOp>(
-      patterns.getContext());
+      WgToSgElementwiseOp<arith::AddFOp>, WgToSgElementwiseOp<arith::SubFOp>,
+      WgToSgElementwiseOp<math::ExpOp>, WgToSgElementwiseOp<math::SqrtOp>,
+      WgToSgElementwiseOp<math::AbsFOp>, WgToSgElementwiseOp<math::CosOp>,
+      WgToSgElementwiseOp<math::CoshOp>, WgToSgElementwiseOp<math::AcosOp>,
+      WgToSgElementwiseOp<math::AcoshOp>, WgToSgElementwiseOp<math::SinOp>,
+      WgToSgElementwiseOp<math::SinhOp>, WgToSgElementwiseOp<math::AsinOp>,
+      WgToSgElementwiseOp<math::AsinhOp>, WgToSgElementwiseOp<math::TanOp>,
+      WgToSgElementwiseOp<math::TanhOp>, WgToSgElementwiseOp<math::AtanOp>,
+      WgToSgElementwiseOp<math::Atan2Op>, WgToSgElementwiseOp<math::AtanhOp>,
+      WgToSgElementwiseOp<math::ErfOp>, WgToSgElementwiseOp<math::LogOp>,
+      WgToSgElementwiseOp<math::Log2Op>, WgToSgElementwiseOp<math::FloorOp>,
+      WgToSgElementwiseOp<math::CeilOp>, WgToSgElementwiseOp<math::PowFOp>,
+      WgToSgElementwiseOp<math::RsqrtOp>, WgToSgElementwiseOp<arith::NegFOp>,
+      WgToSgElementwiseOp<arith::AddIOp>, WgToSgElementwiseOp<arith::SubIOp>,
+      WgToSgElementwiseOp<arith::MulFOp>, WgToSgElementwiseOp<arith::MulIOp>,
+      WgToSgElementwiseOp<arith::ShLIOp>, WgToSgElementwiseOp<arith::ShRSIOp>,
+      WgToSgElementwiseOp<arith::ShRUIOp>, WgToSgElementwiseOp<arith::DivFOp>,
+      WgToSgElementwiseOp<arith::DivSIOp>, WgToSgElementwiseOp<arith::DivUIOp>,
+      WgToSgElementwiseOp<arith::MaximumFOp>,
+      WgToSgElementwiseOp<arith::MinimumFOp>,
+      WgToSgElementwiseOp<arith::RemSIOp>, WgToSgElementwiseOp<arith::RemUIOp>,
+      WgToSgElementwiseOp<arith::TruncFOp>,
+      WgToSgElementwiseOp<arith::TruncIOp>, WgToSgElementwiseOp<arith::ExtFOp>,
+      WgToSgElementwiseOp<arith::ExtSIOp>, WgToSgElementwiseOp<arith::ExtUIOp>,
+      WgToSgElementwiseOp<arith::SIToFPOp>,
+      WgToSgElementwiseOp<arith::UIToFPOp>,
+      WgToSgElementwiseOp<arith::FPToSIOp>,
+      WgToSgElementwiseOp<arith::FPToUIOp>,
+      WgToSgElementwiseOp<arith::IndexCastUIOp>,
+      WgToSgElementwiseOp<arith::IndexCastOp>,
+      WgToSgElementwiseOp<arith::BitcastOp>, WgToSgElementwiseOp<arith::CmpIOp>,
+      WgToSgElementwiseOp<arith::CmpFOp>, WgToSgElementwiseOp<arith::AndIOp>,
+      WgToSgElementwiseOp<arith::CeilDivSIOp>,
+      WgToSgElementwiseOp<arith::CeilDivUIOp>,
+      WgToSgElementwiseOp<arith::FloorDivSIOp>,
+      WgToSgElementwiseOp<arith::MaxNumFOp>,
+      WgToSgElementwiseOp<arith::MaxSIOp>, WgToSgElementwiseOp<arith::MaxUIOp>,
+      WgToSgElementwiseOp<arith::MinNumFOp>,
+      WgToSgElementwiseOp<arith::MinSIOp>, WgToSgElementwiseOp<arith::MinUIOp>,
+      WgToSgElementwiseOp<arith::OrIOp>, WgToSgElementwiseOp<arith::RemFOp>,
+      WgToSgElementwiseOp<arith::SelectOp>, WgToSgElementwiseOp<arith::XOrIOp>,
+      WgToSgElementwiseOp<math::AbsIOp>, WgToSgElementwiseOp<math::CbrtOp>,
+      WgToSgElementwiseOp<math::CopySignOp>, WgToSgElementwiseOp<math::CtPopOp>,
+      WgToSgElementwiseOp<math::ErfcOp>, WgToSgElementwiseOp<math::Exp2Op>,
+      WgToSgElementwiseOp<math::ExpM1Op>, WgToSgElementwiseOp<math::FPowIOp>,
+      WgToSgElementwiseOp<math::IPowIOp>, WgToSgElementwiseOp<math::Log10Op>,
+      WgToSgElementwiseOp<math::Log1pOp>, WgToSgElementwiseOp<math::RoundOp>,
+      WgToSgElementwiseOp<math::RoundEvenOp>,
+      WgToSgElementwiseOp<math::TruncOp>>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir

From 37ea14752d02a669bfbff6bcd66e3982512a379e Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Fri, 6 Jun 2025 19:33:50 +0000
Subject: [PATCH 04/13] Clean up

---
 .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index e1687031d259a..80e4cfcc566de 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -451,13 +451,13 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
       WgToSgElementwiseOp, WgToSgElementwiseOp,
       WgToSgElementwiseOp, WgToSgElementwiseOp,
       WgToSgElementwiseOp,
-      WgToSgElementwiseOp, WgToSgElementwiseOp,
-      WgToSgElementwiseOp, WgToSgElementwiseOp,
-      WgToSgElementwiseOp, WgToSgElementwiseOp,
-      WgToSgElementwiseOp, WgToSgElementwiseOp,
-      WgToSgElementwiseOp, WgToSgElementwiseOp,
-      WgToSgElementwiseOp, WgToSgElementwiseOp,
-      WgToSgElementwiseOp, WgToSgElementwiseOp,
+      WgToSgElementwiseOp, WgToSgElementwiseOp,
+      WgToSgElementwiseOp, WgToSgElementwiseOp,
+      WgToSgElementwiseOp, WgToSgElementwiseOp,
+      WgToSgElementwiseOp, WgToSgElementwiseOp,
+      WgToSgElementwiseOp, WgToSgElementwiseOp,
+      WgToSgElementwiseOp, WgToSgElementwiseOp,
+      WgToSgElementwiseOp, WgToSgElementwiseOp,
       WgToSgElementwiseOp>(patterns.getContext());
 }

From adcdec5039007dcb9ad4ced6e4b3302383059854 Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Wed, 11 Jun 2025 17:50:14 +0000
Subject: [PATCH 05/13] Use OpTrait instead of templating
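Templating the pattern over every arith/math op type meant one pattern
instantiation per op and an ever-growing registration list. A single
ConversionPattern constructed with MatchAnyOpTypeTag() matches any operation,
and OpTrait::hasElementwiseMappableTraits(op) narrows that back down to
elementwise ops (it is true exactly for ops declaring the Elementwise,
Scalarizable, Vectorizable, and Tensorizable traits, which the arith/math ops
above all do), so the explicit op list disappears entirely. For illustration,
a hypothetical standalone check (not part of this patch) could look like:

    // True for ops like arith.addf or math.exp on vectors; false for ops
    // such as vector.transpose, which are not elementwise-mappable.
    static bool isDistributableElementwise(Operation *op) {
      return op->getNumResults() == 1 &&
             isa<VectorType>(op->getResult(0).getType()) &&
             OpTrait::hasElementwiseMappableTraits(op);
    }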
---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 117 +++++------------
 1 file changed, 29 insertions(+), 88 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 80e4cfcc566de..8b079bcd23259 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -318,20 +318,19 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
 };
 
 // This pattern transforms elementwise ops (unary/binary) in math/arith dialect
-template <typename Op>
-struct WgToSgElementwiseOp : public OpConversionPattern<Op> {
-  using OpConversionPattern<Op>::OpConversionPattern;
-  using OneToNOpAdaptor = typename OpConversionPattern<Op>::OneToNOpAdaptor;
+struct WgToSgElementwiseOp : public ConversionPattern {
+  WgToSgElementwiseOp(MLIRContext *ctx)
+      : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
 
   LogicalResult
-  matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    // All operands/results must be 1D or 2D vectors
-    auto resultType = dyn_cast<VectorType>(op.getResult().getType());
-    if (!resultType || (resultType.getRank() != 1 && resultType.getRank() != 2))
-      return rewriter.notifyMatchFailure(
-          op, "Result type is not a 1D or 2D vector");
+    // Only match ops with elementwise trait
+    if (!OpTrait::hasElementwiseMappableTraits(op))
+      return rewriter.notifyMatchFailure(op, "Not an elementwise op");
 
+    // All operands/results must be 1D or 2D vectors
+    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
     ArrayRef<int64_t> shape = resultType.getShape();
     for (Value operand : op->getOperands()) {
       auto operandType = dyn_cast<VectorType>(operand.getType());
@@ -362,38 +361,32 @@ struct WgToSgElementwiseOp : public OpConversionPattern<Op> {
     }
   }
 
-    size_t numVariants = adaptor.getOperands().empty()
-                             ? 0
-                             : adaptor.getOperands().front().size();
-    for (auto &operandVec : adaptor.getOperands())
+    size_t numVariants = operands.empty() ? 0 : operands.front().size();
+    for (auto &operandVec : operands)
       if (operandVec.size() != numVariants)
         return rewriter.notifyMatchFailure(
            op, "Operand lists have mismatched sizes");
 
     SmallVector<Value> newResults;
-
-    auto origResultType = dyn_cast<VectorType>(op->getResult(0).getType());
     VectorType newResultType =
-        origResultType
-            ? VectorType::get(sgShape, origResultType.getElementType())
-            : VectorType::get(sgShape, resultType.getElementType());
+        VectorType::get(sgShape, resultType.getElementType());
 
     for (size_t i = 0; i < numVariants; ++i) {
-      SmallVector<Value> operands;
-      for (auto &operandVec : adaptor.getOperands())
-        operands.push_back(operandVec[i]);
-
-      auto newOp = rewriter.create<Op>(op.getLoc(), newResultType, operands);
-
-      // Copy all attributes except "layout", and add "layout_result_0" with
-      // sgLayout/data dropped
+      SmallVector<Value> opOperands;
+      for (auto &operandVec : operands)
+        opOperands.push_back(operandVec[i]);
+
+      OperationState state(op->getLoc(), op->getName());
+      state.addOperands(opOperands);
+      state.addTypes(newResultType);
+      // Copy all attributes except "layout"
      for (auto attr : op->getAttrs()) {
        if (attr.getName() != "layout")
-          newOp->setAttr(attr.getName(), attr.getValue());
+          state.addAttribute(attr.getName(), attr.getValue());
      }
-      newOp->setAttr("layout_result_0", layout.dropSgLayoutAndData());
-
-      newResults.push_back(newOp.getResult());
+      state.addAttribute("layout_result_0", layout.dropSgLayoutAndData());
+      Operation *newOp = rewriter.create(state);
+      newResults.push_back(newOp->getResult(0));
     }
 
     rewriter.replaceOpWithMultiple(op, {newResults});
     return success();
@@ -407,59 +400,8 @@ namespace mlir {
 namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
-               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
-      patterns.getContext());
-  // Add elementwise operations that can be distributed to subgroups
-  patterns.add<
-      WgToSgElementwiseOp<arith::AddFOp>, WgToSgElementwiseOp<arith::SubFOp>,
-      WgToSgElementwiseOp<math::ExpOp>, WgToSgElementwiseOp<math::SqrtOp>,
-      WgToSgElementwiseOp<math::AbsFOp>, WgToSgElementwiseOp<math::CosOp>,
-      WgToSgElementwiseOp<math::CoshOp>, WgToSgElementwiseOp<math::AcosOp>,
-      WgToSgElementwiseOp<math::AcoshOp>, WgToSgElementwiseOp<math::SinOp>,
-      WgToSgElementwiseOp<math::SinhOp>, WgToSgElementwiseOp<math::AsinOp>,
-      WgToSgElementwiseOp<math::AsinhOp>, WgToSgElementwiseOp<math::TanOp>,
-      WgToSgElementwiseOp<math::TanhOp>, WgToSgElementwiseOp<math::AtanOp>,
-      WgToSgElementwiseOp<math::Atan2Op>, WgToSgElementwiseOp<math::AtanhOp>,
-      WgToSgElementwiseOp<math::ErfOp>, WgToSgElementwiseOp<math::LogOp>,
-      WgToSgElementwiseOp<math::Log2Op>, WgToSgElementwiseOp<math::FloorOp>,
-      WgToSgElementwiseOp<math::CeilOp>, WgToSgElementwiseOp<math::PowFOp>,
-      WgToSgElementwiseOp<math::RsqrtOp>, WgToSgElementwiseOp<arith::NegFOp>,
-      WgToSgElementwiseOp<arith::AddIOp>, WgToSgElementwiseOp<arith::SubIOp>,
-      WgToSgElementwiseOp<arith::MulFOp>, WgToSgElementwiseOp<arith::MulIOp>,
-      WgToSgElementwiseOp<arith::ShLIOp>, WgToSgElementwiseOp<arith::ShRSIOp>,
-      WgToSgElementwiseOp<arith::ShRUIOp>, WgToSgElementwiseOp<arith::DivFOp>,
-      WgToSgElementwiseOp<arith::DivSIOp>, WgToSgElementwiseOp<arith::DivUIOp>,
-      WgToSgElementwiseOp<arith::MaximumFOp>,
-      WgToSgElementwiseOp<arith::MinimumFOp>,
-      WgToSgElementwiseOp<arith::RemSIOp>, WgToSgElementwiseOp<arith::RemUIOp>,
-      WgToSgElementwiseOp<arith::TruncFOp>,
-      WgToSgElementwiseOp<arith::TruncIOp>, WgToSgElementwiseOp<arith::ExtFOp>,
-      WgToSgElementwiseOp<arith::ExtSIOp>, WgToSgElementwiseOp<arith::ExtUIOp>,
-      WgToSgElementwiseOp<arith::SIToFPOp>,
-      WgToSgElementwiseOp<arith::UIToFPOp>,
-      WgToSgElementwiseOp<arith::FPToSIOp>,
-      WgToSgElementwiseOp<arith::FPToUIOp>,
-      WgToSgElementwiseOp<arith::IndexCastUIOp>,
-      WgToSgElementwiseOp<arith::IndexCastOp>,
-      WgToSgElementwiseOp<arith::BitcastOp>, WgToSgElementwiseOp<arith::CmpIOp>,
-      WgToSgElementwiseOp<arith::CmpFOp>, WgToSgElementwiseOp<arith::AndIOp>,
-      WgToSgElementwiseOp<arith::CeilDivSIOp>,
-      WgToSgElementwiseOp<arith::CeilDivUIOp>,
-      WgToSgElementwiseOp<arith::FloorDivSIOp>,
-      WgToSgElementwiseOp<arith::MaxNumFOp>,
-      WgToSgElementwiseOp<arith::MaxSIOp>, WgToSgElementwiseOp<arith::MaxUIOp>,
-      WgToSgElementwiseOp<arith::MinNumFOp>,
-      WgToSgElementwiseOp<arith::MinSIOp>, WgToSgElementwiseOp<arith::MinUIOp>,
-      WgToSgElementwiseOp<arith::OrIOp>, WgToSgElementwiseOp<arith::RemFOp>,
-      WgToSgElementwiseOp<arith::SelectOp>, WgToSgElementwiseOp<arith::XOrIOp>,
-      WgToSgElementwiseOp<math::AbsIOp>, WgToSgElementwiseOp<math::CbrtOp>,
-      WgToSgElementwiseOp<math::CopySignOp>, WgToSgElementwiseOp<math::CtPopOp>,
-      WgToSgElementwiseOp<math::ErfcOp>, WgToSgElementwiseOp<math::Exp2Op>,
-      WgToSgElementwiseOp<math::ExpM1Op>, WgToSgElementwiseOp<math::FPowIOp>,
-      WgToSgElementwiseOp<math::IPowIOp>, WgToSgElementwiseOp<math::Log10Op>,
-      WgToSgElementwiseOp<math::Log1pOp>, WgToSgElementwiseOp<math::RoundOp>,
-      WgToSgElementwiseOp<math::RoundEvenOp>,
-      WgToSgElementwiseOp<math::TruncOp>>(patterns.getContext());
+               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
+               WgToSgElementwiseOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -508,17 +450,16 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
   });
 
   target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
       [=](Operation *op) -> std::optional<bool> {
-        // Handle unary and binary operations
-        if (op->getNumOperands() < 1 || op->getNumOperands() > 2)
+        // Only handle elementwise mappable ops
+        if (!OpTrait::hasElementwiseMappableTraits(op))
          return true;
 
-        // check if input and output are vectors
        VectorType resultType =
            dyn_cast<VectorType>(op->getResult(0).getType());
        if (!resultType || resultType.getRank() != 2)
          return true;
 
-        // Check if all operands are vectors
+        // Check if all operands are vectors of the same shape
        for (Value operand : op->getOperands()) {
          VectorType operandType = dyn_cast<VectorType>(operand.getType());
          if (!operandType || operandType.getRank() != 2 ||

From d2556801f69070bbf05a055e45b76edb81e89133 Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Wed, 11 Jun 2025 20:09:41 +0000
Subject: [PATCH 06/13] refactor
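Move layout handling onto the shared xegpu::getLayoutAttr / xegpu::setLayoutAttr
helpers and key everything on the layout_result_0 attribute instead of the
loose "layout" string; the tests are updated to match. The net effect of the
pattern on a single workgroup-level op looks roughly like the following, with
illustrative sg_layout/sg_data/lane values (the concrete numbers used in the
test file may differ):

    // Before: one workgroup-level op carrying the full layout.
    %r = arith.addf %a, %b
      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8],
                                       lane_layout = [2, 8], lane_data = [1, 1]>}
      : vector<24x32xf32>

    // After: one op per subgroup-owned 12x8 tile; sg_layout/sg_data are
    // dropped from the result layout, only the lane-level fields remain.
    %r0 = arith.addf %a0, %b0
      {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
      : vector<12x8xf32>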
---
 .../Transforms/XeGPUWgToSgDistribute.cpp      |  30 +-
 .../XeGPU/xegpu-wg-to-sg-elemwise.mlir        | 370 +++++++++---------
 2 files changed, 198 insertions(+), 202 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 8b079bcd23259..24f7209cfe226 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include 
@@ -317,7 +318,7 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
-// This pattern transforms elementwise ops (unary/binary) in math/arith dialect
+// This pattern transforms elementwise ops in math/arith dialect
 struct WgToSgElementwiseOp : public ConversionPattern {
   WgToSgElementwiseOp(MLIRContext *ctx)
       : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                   ConversionPatternRewriter &rewriter) const override {
     // Only match ops with elementwise trait
     if (!OpTrait::hasElementwiseMappableTraits(op))
       return rewriter.notifyMatchFailure(op, "Not an 
elementwise op"); - // All operands/results must be 1D or 2D vectors auto resultType = dyn_cast(op->getResult(0).getType()); ArrayRef shape = resultType.getShape(); - for (Value operand : op->getOperands()) { - auto operandType = dyn_cast(operand.getType()); - if (!operandType || operandType.getRank() != resultType.getRank() || - operandType.getShape() != shape) { - return rewriter.notifyMatchFailure( - op, "Operand type is not a 1D or 2D vector with the same shape as " - "result type"); - } - } - auto layout = dyn_cast_or_null(op->getAttr("layout")); + xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0)); if (!layout || !layout.getSgLayout()) return rewriter.notifyMatchFailure( op, "Operation does not have a valid layout attribute for subgroup " @@ -379,13 +370,14 @@ struct WgToSgElementwiseOp : public ConversionPattern { OperationState state(op->getLoc(), op->getName()); state.addOperands(opOperands); state.addTypes(newResultType); - // Copy all attributes except "layout" + // Copy all attributes, but update "layout_result_0" to drop + // sgLayout/sgData for (auto attr : op->getAttrs()) { - if (attr.getName() != "layout") + if (attr.getName() != "layout_result_0") state.addAttribute(attr.getName(), attr.getValue()); } - state.addAttribute("layout_result_0", layout.dropSgLayoutAndData()); Operation *newOp = rewriter.create(state); + xegpu::setLayoutAttr(newOp->getResult(0), layout.dropSgLayoutAndData()); newResults.push_back(newOp->getResult(0)); } @@ -448,6 +440,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() { auto layout = dyn_cast_or_null(op->getAttr("layout")); return isLegal(layout); }); + target.addDynamicallyLegalDialect( [=](Operation *op) -> std::optional { // Only handle elementwise mappable ops @@ -456,20 +449,19 @@ void XeGPUWgToSgDistributePass::runOnOperation() { VectorType resultType = dyn_cast(op->getResult(0).getType()); - if (!resultType || resultType.getRank() != 2) + if (!resultType) return true; // Check if all operands are vectors of the same shape for (Value operand : op->getOperands()) { VectorType operandType = dyn_cast(operand.getType()); - if (!operandType || operandType.getRank() != 2 || - operandType.getShape() != resultType.getShape()) { + if (!operandType || operandType.getShape() != resultType.getShape()) { return true; } } auto layout = dyn_cast_or_null( - op->getAttrOfType("layout")); + op->getAttrOfType("layout_result_0")); return isLegal(layout); }); diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir index 85767f4f2bd67..21e128239851e 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir @@ -67,126 +67,126 @@ gpu.module @elementwise_ops { // CHECK: arith.remsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> // CHECK: arith.remui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> %addf = arith.addf %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %subf = arith.subf %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %exp = math.exp %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %sqrt = math.sqrt %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %absf = math.absf %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cos = math.cos %load_a - 
{layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cosh = math.cosh %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %acos = math.acos %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %acosh = math.acosh %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %sin = math.sin %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %sinh = math.sinh %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %asin = math.asin %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %asinh = math.asinh %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %tan = math.tan %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %tanh = math.tanh %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %atan = math.atan %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %atan2 = math.atan2 %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %atanh = math.atanh %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %erf = math.erf %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %log = math.log %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %log2 = math.log2 %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %floor = math.floor %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %ceil = math.ceil %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %powf = math.powf %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %rsqrt = math.rsqrt %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %negf = arith.negf %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %mulf = arith.mulf %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %divf = arith.divf %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %maximumf = arith.maximumf %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %minimumf = arith.minimumf %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> // Integer ops %addi = arith.addi %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %subi = arith.subi %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %muli = arith.muli %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %shli = arith.shli %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %shrsi = arith.shrsi %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %shrui = arith.shrui %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %divsi = arith.divsi %load_c, %load_d - 
{layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %divui = arith.divui %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %remsi = arith.remsi %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %remui = arith.remui %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> gpu.return @@ -299,126 +299,126 @@ gpu.module @elementwise_ops { // CHECK-COUNT-12: arith.remui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> // CHECK-NOT: arith.remui %addf = arith.addf %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %subf = arith.subf %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %exp = math.exp %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %sqrt = math.sqrt %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %absf = math.absf %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cos = math.cos %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cosh = math.cosh %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %acos = math.acos %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %acosh = math.acosh %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %sin = math.sin %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %sinh = math.sinh %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %asin = math.asin %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %asinh = math.asinh %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %tan = math.tan %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %tanh = math.tanh %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %atan = math.atan %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %atan2 = math.atan2 %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %atanh = math.atanh %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %erf = math.erf %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %log = math.log %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %log2 = math.log2 %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %floor = math.floor %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %ceil = math.ceil %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %powf = math.powf %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %rsqrt = math.rsqrt %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %negf = arith.negf %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %mulf = arith.mulf %load_a, %load_b - {layout = 
#xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %divf = arith.divf %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %maximumf = arith.maximumf %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %minimumf = arith.minimumf %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> // Integer ops %addi = arith.addi %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %subi = arith.subi %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %muli = arith.muli %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %shli = arith.shli %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %shrsi = arith.shrsi %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %shrui = arith.shrui %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %divsi = arith.divsi %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %divui = arith.divui %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %remsi = arith.remsi %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %remui = arith.remui %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> gpu.return @@ -463,54 +463,54 @@ gpu.module @elementwise_ops { // CHECK: arith.bitcast {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> to vector<12x8xf32> // TruncFOp: f32 -> f16 %truncf = arith.truncf %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> to vector<24x32xf16> // TruncIOp: i32 -> i16 %trunci = arith.trunci %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> to vector<24x32xi16> // ExtFOp: f16 -> f32 %truncf16 = arith.truncf %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> to vector<24x32xf16> %extf = arith.extf %truncf16 - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf16> to vector<24x32xf32> // ExtSIOp: i16 -> i32 %extsi = arith.extsi %trunci - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi16> to vector<24x32xi32> // ExtUIOp: i16 -> i32 (unsigned) %extui = arith.extui %trunci - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi16> to vector<24x32xi32> // SIToFPOp: i32 -> f32 %sitofp = arith.sitofp %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> to vector<24x32xf32> // UIToFPOp: i32 -> f32 (unsigned) %uitofp = arith.uitofp %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> to vector<24x32xf32> // FPToSIOp: f32 -> i32 %fptosi = arith.fptosi %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> to vector<24x32xi32> // FPToUIOp: f32 -> i32 (unsigned) %fptoui = arith.fptoui %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> to vector<24x32xi32> // IndexCastUIOp: i32 -> index %indexcastui = arith.index_castui %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} 
: vector<24x32xi32> to vector<24x32xindex> // IndexCastOp: index -> i32 %indexcast = arith.index_cast %indexcastui - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xindex> to vector<24x32xi32> // BitcastOp: i32 -> f32 %bitcast = arith.bitcast %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> to vector<24x32xf32> gpu.return } @@ -567,54 +567,54 @@ gpu.module @elementwise_ops { // CHECK-NOT: arith.bitcast // TruncFOp: f32 -> f16 %truncf = arith.truncf %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> to vector<24x32xf16> // TruncIOp: i32 -> i16 %trunci = arith.trunci %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> to vector<24x32xi16> // ExtFOp: f16 -> f32 %truncf16 = arith.truncf %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> to vector<24x32xf16> %extf = arith.extf %truncf16 - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf16> to vector<24x32xf32> // ExtSIOp: i16 -> i32 %extsi = arith.extsi %trunci - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi16> to vector<24x32xi32> // ExtUIOp: i16 -> i32 (unsigned) %extui = arith.extui %trunci - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi16> to vector<24x32xi32> // SIToFPOp: i32 -> f32 %sitofp = arith.sitofp %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> to vector<24x32xf32> // UIToFPOp: i32 -> f32 (unsigned) %uitofp = arith.uitofp %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> to vector<24x32xf32> // FPToSIOp: f32 -> i32 %fptosi = arith.fptosi %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> to vector<24x32xi32> // FPToUIOp: f32 -> i32 (unsigned) %fptoui = arith.fptoui %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> to vector<24x32xi32> // IndexCastUIOp: i32 -> index %indexcastui = arith.index_castui %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> to vector<24x32xindex> // IndexCastOp: index -> i32 %indexcast = arith.index_cast %indexcastui - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xindex> to vector<24x32xi32> // BitcastOp: i32 -> f32 %bitcast = arith.bitcast %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> to vector<24x32xf32> gpu.return } @@ -669,78 +669,78 @@ gpu.module @elementwise_ops { // CHECK: arith.cmpf uno, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> // Integer comparisons %cmpi_eq = arith.cmpi eq, %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %cmpi_ne = arith.cmpi ne, %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %cmpi_slt = arith.cmpi slt, %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %cmpi_sle = arith.cmpi sle, %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %cmpi_sgt = arith.cmpi sgt, %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %cmpi_sge = arith.cmpi sge, %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %cmpi_ult = arith.cmpi 
ult, %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %cmpi_ule = arith.cmpi ule, %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %cmpi_ugt = arith.cmpi ugt, %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %cmpi_uge = arith.cmpi uge, %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> // Floating point comparisons %cmpf_oeq = arith.cmpf oeq, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_ogt = arith.cmpf ogt, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_oge = arith.cmpf oge, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_olt = arith.cmpf olt, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_ole = arith.cmpf ole, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_one = arith.cmpf one, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_ord = arith.cmpf ord, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_ueq = arith.cmpf ueq, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_ugt = arith.cmpf ugt, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_uge = arith.cmpf uge, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_ult = arith.cmpf ult, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_ule = arith.cmpf ule, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_une = arith.cmpf une, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_uno = arith.cmpf uno, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> gpu.return } @@ -821,78 +821,78 @@ gpu.module @elementwise_ops { // Integer comparisons %cmpi_eq = arith.cmpi eq, %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %cmpi_ne = arith.cmpi ne, %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %cmpi_slt = arith.cmpi slt, %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %cmpi_sle = arith.cmpi sle, %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %cmpi_sgt = arith.cmpi sgt, %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %cmpi_sge = arith.cmpi sge, %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %cmpi_ult = arith.cmpi ult, %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %cmpi_ule = arith.cmpi ule, %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %cmpi_ugt = arith.cmpi ugt, %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> 
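    // Note: each op in this function is rewritten 1:N. With 2x2 sg_data tiles
    // of the 24x32 operands there are (24/2) x (32/2) = 192 tiles in total, so
    // the CHECK-COUNT-12 lines above correspond to 12 tiles per subgroup under
    // an assumed 16-subgroup sg_layout, assigned round-robin.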
%cmpi_uge = arith.cmpi uge, %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> // Floating point comparisons %cmpf_oeq = arith.cmpf oeq, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_ogt = arith.cmpf ogt, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_oge = arith.cmpf oge, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_olt = arith.cmpf olt, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_ole = arith.cmpf ole, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_one = arith.cmpf one, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_ord = arith.cmpf ord, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_ueq = arith.cmpf ueq, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_ugt = arith.cmpf ugt, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_uge = arith.cmpf uge, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_ult = arith.cmpf ult, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_ule = arith.cmpf ule, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_une = arith.cmpf une, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf_uno = arith.cmpf uno, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> gpu.return @@ -942,6 +942,7 @@ gpu.module @elementwise_ops { // CHECK: arith.minui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> // CHECK: arith.remf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> // CHECK: arith.cmpf ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + // CHECK: arith.select {{.*}}, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi1>, vector<12x8xf32> // CHECK: math.absi {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> // CHECK: math.cbrt {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> // CHECK: math.copysign {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> @@ -958,91 +959,94 @@ gpu.module @elementwise_ops { // CHECK: math.trunc {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> // arith ops %andi = arith.andi %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %ori = arith.ori %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %xori = arith.xori %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %ceildivsi = arith.ceildivsi %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %ceildivui = arith.ceildivui %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %floordivsi = arith.floordivsi %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %maxnumf = 
arith.maxnumf %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %maxsi = arith.maxsi %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %maxui = arith.maxui %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %minnumf = arith.minnumf %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %minsi = arith.minsi %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %minui = arith.minui %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %remf = arith.remf %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %cmpf = arith.cmpf ult, %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> + %select = arith.select %cmpf, %load_a, %load_b + {layout_result_0 = #xegpu.layout} + : vector<24x32xi1>, vector<24x32xf32> // math ops %absi = math.absi %load_c - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %cbrt = math.cbrt %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %copysign = math.copysign %load_a, %load_b - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %ctpop = math.ctpop %load_c - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %erfc = math.erfc %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %exp2 = math.exp2 %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %expm1 = math.expm1 %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %fpowi = math.fpowi %load_a, %load_c - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32>, vector<24x32xi32> %ipowi = math.ipowi %load_c, %load_d - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xi32> %log10 = math.log10 %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %log1p = math.log1p %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %round = math.round %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %roundeven = math.roundeven %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> %trunc = math.trunc %load_a - {layout = #xegpu.layout} + {layout_result_0 = #xegpu.layout} : vector<24x32xf32> gpu.return } -} +} \ No newline at end of file From 94d0f1b6a934364fe8539c1884ef019ba02f0e2e Mon Sep 17 00:00:00 2001 From: nbpatel Date: Wed, 11 Jun 2025 20:39:02 +0000 Subject: [PATCH 07/13] newline --- mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir index 21e128239851e..8ada50fa89dde 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir @@ -1049,4 +1049,4 @@ gpu.module @elementwise_ops { : vector<24x32xf32> gpu.return } -} \ No newline at end of file +} From 7fd99761295561f4a9eb7037631be18e9eea9f21 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 12 Jun 2025 21:25:27 +0000 Subject: 
[PATCH 08/13] Clean up tests --- .../XeGPU/xegpu-wg-to-sg-elemwise.mlir | 1212 +++-------------- 1 file changed, 162 insertions(+), 1050 deletions(-) diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir index 8ada50fa89dde..3672e4d9912cf 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir @@ -1,1052 +1,164 @@ // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s -gpu.module @elementwise_ops { - // CHECK-LABEL: elemwise_ops - gpu.func @elemwise_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) { - %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> - %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> - %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32> - -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32> - -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - - %load_a = xegpu.load_nd %tdesc_a - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> - -> vector<24x32xf32> - %load_b = xegpu.load_nd %tdesc_b - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> - -> vector<24x32xf32> - %load_c = xegpu.load_nd %tdesc_c - : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - -> vector<24x32xi32> - %load_d = xegpu.load_nd %tdesc_d - : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - -> vector<24x32xi32> - - // Floating point ops - // CHECK: arith.addf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: arith.subf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.exp {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.sqrt {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.absf {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.cos {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.cosh {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.acos {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.acosh {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.sin {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.sinh {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.asin {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.asinh {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.tan {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.tanh {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.atan {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.atan2 {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.atanh {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.erf {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.log {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.log2 {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.floor {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.ceil {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: 
math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.rsqrt {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: arith.negf {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: arith.mulf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: arith.divf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: arith.maximumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: arith.minimumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: arith.addi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.subi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.muli {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.shli {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.shrsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.shrui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.divsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.divui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.remsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.remui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - %addf = arith.addf %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %subf = arith.subf %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %exp = math.exp %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %sqrt = math.sqrt %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %absf = math.absf %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cos = math.cos %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cosh = math.cosh %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %acos = math.acos %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %acosh = math.acosh %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %sin = math.sin %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %sinh = math.sinh %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %asin = math.asin %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %asinh = math.asinh %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %tan = math.tan %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %tanh = math.tanh %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %atan = math.atan %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %atan2 = math.atan2 %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %atanh = math.atanh %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %erf = math.erf %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %log = math.log %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %log2 = math.log2 %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %floor = math.floor %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %ceil = math.ceil %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %powf = 
math.powf %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %rsqrt = math.rsqrt %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %negf = arith.negf %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %mulf = arith.mulf %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %divf = arith.divf %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %maximumf = arith.maximumf %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %minimumf = arith.minimumf %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - - // Integer ops - %addi = arith.addi %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %subi = arith.subi %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %muli = arith.muli %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %shli = arith.shli %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %shrsi = arith.shrsi %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %shrui = arith.shrui %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %divsi = arith.divsi %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %divui = arith.divui %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %remsi = arith.remsi %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %remui = arith.remui %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - - gpu.return - } - - // 1 to N decomposition of elementwise operations - // CHECK-LABEL: elemwise_ops_rr_assignment - gpu.func @elemwise_ops_rr_assignment(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) { - %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> - %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> - %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32> - -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32> - -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - - %load_a = xegpu.load_nd %tdesc_a - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> - -> vector<24x32xf32> - %load_b = xegpu.load_nd %tdesc_b - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> - -> vector<24x32xf32> - %load_c = xegpu.load_nd %tdesc_c - : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - -> vector<24x32xi32> - %load_d = xegpu.load_nd %tdesc_d - : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - -> vector<24x32xi32> - - // Floating point ops - // CHECK-COUNT-12: arith.addf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: arith.addf - // CHECK-COUNT-12: arith.subf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: arith.subf - // CHECK-COUNT-12: math.exp {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: math.exp - // CHECK-COUNT-12: math.sqrt {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: math.sqrt - // CHECK-COUNT-12: math.absf {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: math.absf - // CHECK-COUNT-12: math.cos {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: math.cos - // CHECK-COUNT-12: 
math.cosh {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: math.cosh - // CHECK-COUNT-12: math.acos {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: math.acos - // CHECK-COUNT-12: math.acosh {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: math.acosh - // CHECK-COUNT-12: math.sin {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: math.sin - // CHECK-COUNT-12: math.sinh {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: math.sinh - // CHECK-COUNT-12: math.asin {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: math.asin - // CHECK-COUNT-12: math.asinh {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: math.asinh - // CHECK-COUNT-12: math.tan {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: math.tan - // CHECK-COUNT-12: math.tanh {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: math.tanh - // CHECK-COUNT-12: math.atan {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: math.atan - // CHECK-COUNT-12: math.atan2 {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: math.atan2 - // CHECK-COUNT-12: math.atanh {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: math.atanh - // CHECK-COUNT-12: math.erf {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: math.erf - // CHECK-COUNT-12: math.log {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: math.log - // CHECK-COUNT-12: math.log2 {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: math.log2 - // CHECK-COUNT-12: math.floor {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: math.floor - // CHECK-COUNT-12: math.ceil {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: math.ceil - // CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: math.powf - // CHECK-COUNT-12: math.rsqrt {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: math.rsqrt - // CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: arith.negf - // CHECK-COUNT-12: arith.mulf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: arith.mulf - // CHECK-COUNT-12: arith.divf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: arith.divf - // CHECK-COUNT-12: arith.maximumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: arith.maximumf - // CHECK-COUNT-12: arith.minimumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: arith.minimumf - // CHECK-COUNT-12: arith.addi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> - // CHECK-NOT: arith.addi - // CHECK-COUNT-12: arith.subi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> - // CHECK-NOT: arith.subi - // CHECK-COUNT-12: arith.muli {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> - // CHECK-NOT: arith.muli - // CHECK-COUNT-12: arith.shli {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> - // CHECK-NOT: arith.shli - // CHECK-COUNT-12: arith.shrsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> - // CHECK-NOT: arith.shrsi - // CHECK-COUNT-12: arith.shrui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} 
: vector<2x2xi32> - // CHECK-NOT: arith.shrui - // CHECK-COUNT-12: arith.divsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> - // CHECK-NOT: arith.divsi - // CHECK-COUNT-12: arith.divui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> - // CHECK-NOT: arith.divui - // CHECK-COUNT-12: arith.remsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> - // CHECK-NOT: arith.remsi - // CHECK-COUNT-12: arith.remui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> - // CHECK-NOT: arith.remui - %addf = arith.addf %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %subf = arith.subf %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %exp = math.exp %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %sqrt = math.sqrt %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %absf = math.absf %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cos = math.cos %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cosh = math.cosh %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %acos = math.acos %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %acosh = math.acosh %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %sin = math.sin %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %sinh = math.sinh %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %asin = math.asin %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %asinh = math.asinh %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %tan = math.tan %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %tanh = math.tanh %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %atan = math.atan %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %atan2 = math.atan2 %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %atanh = math.atanh %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %erf = math.erf %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %log = math.log %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %log2 = math.log2 %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %floor = math.floor %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %ceil = math.ceil %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %powf = math.powf %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %rsqrt = math.rsqrt %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %negf = arith.negf %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %mulf = arith.mulf %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %divf = arith.divf %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %maximumf = arith.maximumf %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %minimumf = arith.minimumf %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - - // Integer ops - %addi = arith.addi %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %subi = arith.subi %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %muli = arith.muli %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %shli = 
arith.shli %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %shrsi = arith.shrsi %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %shrui = arith.shrui %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %divsi = arith.divsi %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %divui = arith.divui %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %remsi = arith.remsi %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %remui = arith.remui %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - - gpu.return - } - - // CHECK-LABEL: type_conversion_ops - gpu.func @type_conversion_ops( - %a: memref<24x32xf32>, %b: memref<24x32xi32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) { - %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> - %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xi32> - -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32> - -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32> - -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - - %load_a = xegpu.load_nd %tdesc_a - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> - -> vector<24x32xf32> - %load_b = xegpu.load_nd %tdesc_b - : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - -> vector<24x32xi32> - %load_c = xegpu.load_nd %tdesc_c - : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - -> vector<24x32xi32> - %load_d = xegpu.load_nd %tdesc_d - : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - -> vector<24x32xi32> - - // CHECK: arith.truncf {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> to vector<12x8xf16> - // CHECK: arith.trunci {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> to vector<12x8xi16> - // CHECK: arith.extf {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf16> to vector<12x8xf32> - // CHECK: arith.extsi {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi16> to vector<12x8xi32> - // CHECK: arith.extui {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi16> to vector<12x8xi32> - // CHECK: arith.sitofp {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> to vector<12x8xf32> - // CHECK: arith.uitofp {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> to vector<12x8xf32> - // CHECK: arith.fptosi {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> to vector<12x8xi32> - // CHECK: arith.fptoui {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> to vector<12x8xi32> - // CHECK: arith.index_castui {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> to vector<12x8xindex> - // CHECK: arith.index_cast {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xindex> to vector<12x8xi32> - // CHECK: arith.bitcast {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> to vector<12x8xf32> - // TruncFOp: f32 -> f16 - %truncf = arith.truncf %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> to vector<24x32xf16> - // TruncIOp: i32 -> i16 - %trunci = arith.trunci %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> to vector<24x32xi16> - // ExtFOp: f16 -> f32 - %truncf16 = arith.truncf %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> to vector<24x32xf16> - %extf = arith.extf %truncf16 - {layout_result_0 = #xegpu.layout} - : vector<24x32xf16> to vector<24x32xf32> - // 
ExtSIOp: i16 -> i32 - %extsi = arith.extsi %trunci - {layout_result_0 = #xegpu.layout} - : vector<24x32xi16> to vector<24x32xi32> - // ExtUIOp: i16 -> i32 (unsigned) - %extui = arith.extui %trunci - {layout_result_0 = #xegpu.layout} - : vector<24x32xi16> to vector<24x32xi32> - // SIToFPOp: i32 -> f32 - %sitofp = arith.sitofp %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> to vector<24x32xf32> - // UIToFPOp: i32 -> f32 (unsigned) - %uitofp = arith.uitofp %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> to vector<24x32xf32> - // FPToSIOp: f32 -> i32 - %fptosi = arith.fptosi %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> to vector<24x32xi32> - // FPToUIOp: f32 -> i32 (unsigned) - %fptoui = arith.fptoui %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> to vector<24x32xi32> - // IndexCastUIOp: i32 -> index - %indexcastui = arith.index_castui %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> to vector<24x32xindex> - // IndexCastOp: index -> i32 - %indexcast = arith.index_cast %indexcastui - {layout_result_0 = #xegpu.layout} - : vector<24x32xindex> to vector<24x32xi32> - // BitcastOp: i32 -> f32 - %bitcast = arith.bitcast %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> to vector<24x32xf32> - gpu.return - } - - - // CHECK-LABEL: gpu.func @type_conversion_ops_rr_assignment - gpu.func @type_conversion_ops_rr_assignment( - %a: memref<24x32xf32>, %b: memref<24x32xi32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) { - %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> - %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xi32> - -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32> - -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32> - -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - - %load_a = xegpu.load_nd %tdesc_a - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> - -> vector<24x32xf32> - %load_b = xegpu.load_nd %tdesc_b - : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - -> vector<24x32xi32> - %load_c = xegpu.load_nd %tdesc_c - : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - -> vector<24x32xi32> - %load_d = xegpu.load_nd %tdesc_d - : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - -> vector<24x32xi32> - - // CHECK-COUNT-12: arith.truncf {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> to vector<2x2xf16> - // CHECK-NOT: arith.truncf - // CHECK-COUNT-12: arith.trunci {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> to vector<2x2xi16> - // CHECK-NOT: arith.trunci - // CHECK-COUNT-12: arith.extf {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf16> to vector<2x2xf32> - // CHECK-NOT: arith.extf - // CHECK-COUNT-12: arith.extsi {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi16> to vector<2x2xi32> - // CHECK-NOT: arith.extsi - // CHECK-COUNT-12: arith.extui {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi16> to vector<2x2xi32> - // CHECK-NOT: arith.extui - // CHECK-COUNT-12: arith.sitofp {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> to vector<2x2xf32> - // CHECK-NOT: arith.sitofp - // CHECK-COUNT-12: arith.uitofp {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> to vector<2x2xf32> - // CHECK-NOT: arith.uitofp - // CHECK-COUNT-12: arith.fptosi {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> to vector<2x2xi32> - // CHECK-NOT: arith.fptosi - // 
CHECK-COUNT-12: arith.fptoui {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> to vector<2x2xi32> - // CHECK-NOT: arith.fptoui - // CHECK-COUNT-12: arith.index_castui {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> to vector<2x2xindex> - // CHECK-NOT: arith.index_castui - // CHECK-COUNT-12: arith.index_cast {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xindex> to vector<2x2xi32> - // CHECK-NOT: arith.index_cast - // CHECK-COUNT-12: arith.bitcast {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> to vector<2x2xf32> - // CHECK-NOT: arith.bitcast - // TruncFOp: f32 -> f16 - %truncf = arith.truncf %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> to vector<24x32xf16> - // TruncIOp: i32 -> i16 - %trunci = arith.trunci %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> to vector<24x32xi16> - // ExtFOp: f16 -> f32 - %truncf16 = arith.truncf %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> to vector<24x32xf16> - %extf = arith.extf %truncf16 - {layout_result_0 = #xegpu.layout} - : vector<24x32xf16> to vector<24x32xf32> - // ExtSIOp: i16 -> i32 - %extsi = arith.extsi %trunci - {layout_result_0 = #xegpu.layout} - : vector<24x32xi16> to vector<24x32xi32> - // ExtUIOp: i16 -> i32 (unsigned) - %extui = arith.extui %trunci - {layout_result_0 = #xegpu.layout} - : vector<24x32xi16> to vector<24x32xi32> - // SIToFPOp: i32 -> f32 - %sitofp = arith.sitofp %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> to vector<24x32xf32> - // UIToFPOp: i32 -> f32 (unsigned) - %uitofp = arith.uitofp %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> to vector<24x32xf32> - // FPToSIOp: f32 -> i32 - %fptosi = arith.fptosi %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> to vector<24x32xi32> - // FPToUIOp: f32 -> i32 (unsigned) - %fptoui = arith.fptoui %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> to vector<24x32xi32> - // IndexCastUIOp: i32 -> index - %indexcastui = arith.index_castui %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> to vector<24x32xindex> - // IndexCastOp: index -> i32 - %indexcast = arith.index_cast %indexcastui - {layout_result_0 = #xegpu.layout} - : vector<24x32xindex> to vector<24x32xi32> - // BitcastOp: i32 -> f32 - %bitcast = arith.bitcast %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> to vector<24x32xf32> - gpu.return - } - - // CHECK-LABEL: gpu.func @comparison_ops - gpu.func @comparison_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) { - %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> - %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> - %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32> - -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32> - -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - - %load_a = xegpu.load_nd %tdesc_a - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> - -> vector<24x32xf32> - %load_b = xegpu.load_nd %tdesc_b - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> - -> vector<24x32xf32> - %load_c = xegpu.load_nd %tdesc_c - : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - -> vector<24x32xi32> - %load_d = xegpu.load_nd %tdesc_d - : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - -> vector<24x32xi32> - - // CHECK: arith.cmpi eq, {{.*}}, {{.*}} {layout_result_0 = 
#xegpu.layout} : vector<12x8xi32> - // CHECK: arith.cmpi ne, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.cmpi slt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.cmpi sle, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.cmpi sgt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.cmpi sge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.cmpi ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.cmpi ule, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.cmpi ugt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.cmpi uge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.cmpf oeq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: arith.cmpf ogt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: arith.cmpf oge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: arith.cmpf olt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: arith.cmpf ole, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: arith.cmpf one, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: arith.cmpf ord, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: arith.cmpf ueq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: arith.cmpf ugt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: arith.cmpf uge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: arith.cmpf ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: arith.cmpf ule, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: arith.cmpf une, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: arith.cmpf uno, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // Integer comparisons - %cmpi_eq = arith.cmpi eq, %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %cmpi_ne = arith.cmpi ne, %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %cmpi_slt = arith.cmpi slt, %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %cmpi_sle = arith.cmpi sle, %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %cmpi_sgt = arith.cmpi sgt, %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %cmpi_sge = arith.cmpi sge, %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %cmpi_ult = arith.cmpi ult, %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %cmpi_ule = arith.cmpi ule, %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %cmpi_ugt = arith.cmpi ugt, %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %cmpi_uge = arith.cmpi uge, %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - - // Floating point comparisons - %cmpf_oeq = arith.cmpf oeq, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_ogt = arith.cmpf ogt, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : 
vector<24x32xf32> - %cmpf_oge = arith.cmpf oge, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_olt = arith.cmpf olt, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_ole = arith.cmpf ole, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_one = arith.cmpf one, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_ord = arith.cmpf ord, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_ueq = arith.cmpf ueq, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_ugt = arith.cmpf ugt, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_uge = arith.cmpf uge, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_ult = arith.cmpf ult, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_ule = arith.cmpf ule, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_une = arith.cmpf une, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_uno = arith.cmpf uno, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - gpu.return - } - - // CHECK-LABEL: gpu.func @comparison_ops_rr_assignment - gpu.func @comparison_ops_rr_assignment(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) { - %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> - %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> - %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32> - -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32> - -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - - %load_a = xegpu.load_nd %tdesc_a - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> - -> vector<24x32xf32> - %load_b = xegpu.load_nd %tdesc_b - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> - -> vector<24x32xf32> - %load_c = xegpu.load_nd %tdesc_c - : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - -> vector<24x32xi32> - %load_d = xegpu.load_nd %tdesc_d - : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - -> vector<24x32xi32> - - // CHECK-COUNT-12: arith.cmpi eq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> - // CHECK-NOT: arith.cmpi eq - // CHECK-COUNT-12: arith.cmpi ne, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> - // CHECK-NOT: arith.cmpi ne - // CHECK-COUNT-12: arith.cmpi slt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> - // CHECK-NOT: arith.cmpi slt - // CHECK-COUNT-12: arith.cmpi sle, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> - // CHECK-NOT: arith.cmpi sle - // CHECK-COUNT-12: arith.cmpi sgt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> - // CHECK-NOT: arith.cmpi sgt - // CHECK-COUNT-12: arith.cmpi sge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> - // CHECK-NOT: arith.cmpi sge - // CHECK-COUNT-12: arith.cmpi ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> - // CHECK-NOT: arith.cmpi ult - // CHECK-COUNT-12: arith.cmpi ule, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> - // CHECK-NOT: arith.cmpi ule - // CHECK-COUNT-12: arith.cmpi ugt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : 
vector<2x2xi32> - // CHECK-NOT: arith.cmpi ugt - // CHECK-COUNT-12: arith.cmpi uge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xi32> - // CHECK-NOT: arith.cmpi uge - // Floating point comparisons - // CHECK-COUNT-12: arith.cmpf oeq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: arith.cmpf oeq - // CHECK-COUNT-12: arith.cmpf ogt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: arith.cmpf ogt - // CHECK-COUNT-12: arith.cmpf oge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: arith.cmpf oge - // CHECK-COUNT-12: arith.cmpf olt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: arith.cmpf olt - // CHECK-COUNT-12: arith.cmpf ole, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: arith.cmpf ole - // CHECK-COUNT-12: arith.cmpf one, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: arith.cmpf one - // CHECK-COUNT-12: arith.cmpf ord, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: arith.cmpf ord - // CHECK-COUNT-12: arith.cmpf ueq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: arith.cmpf ueq - // CHECK-COUNT-12: arith.cmpf ugt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: arith.cmpf ugt - // CHECK-COUNT-12: arith.cmpf uge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: arith.cmpf uge - // CHECK-COUNT-12: arith.cmpf ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: arith.cmpf ult - // CHECK-COUNT-12: arith.cmpf ule, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: arith.cmpf ule - // CHECK-COUNT-12: arith.cmpf une, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: arith.cmpf une - // CHECK-COUNT-12: arith.cmpf uno, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<2x2xf32> - // CHECK-NOT: arith.cmpf uno - - // Integer comparisons - %cmpi_eq = arith.cmpi eq, %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %cmpi_ne = arith.cmpi ne, %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %cmpi_slt = arith.cmpi slt, %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %cmpi_sle = arith.cmpi sle, %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %cmpi_sgt = arith.cmpi sgt, %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %cmpi_sge = arith.cmpi sge, %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %cmpi_ult = arith.cmpi ult, %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %cmpi_ule = arith.cmpi ule, %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %cmpi_ugt = arith.cmpi ugt, %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %cmpi_uge = arith.cmpi uge, %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - - // Floating point comparisons - %cmpf_oeq = arith.cmpf oeq, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_ogt = arith.cmpf ogt, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_oge = arith.cmpf oge, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_olt = arith.cmpf 
olt, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_ole = arith.cmpf ole, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_one = arith.cmpf one, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_ord = arith.cmpf ord, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_ueq = arith.cmpf ueq, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_ugt = arith.cmpf ugt, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_uge = arith.cmpf uge, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_ult = arith.cmpf ult, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_ule = arith.cmpf ule, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_une = arith.cmpf une, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf_uno = arith.cmpf uno, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - - gpu.return - } - - // CHECK-LABEL: gpu.func @elementwise_ops - gpu.func @elementwise_ops( - %a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>, %e: memref<24x32xi1>) { - %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> - %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> - %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32> - -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32> - -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - %tdesc_e = xegpu.create_nd_tdesc %e[0, 0] : memref<24x32xi1> - -> !xegpu.tensor_desc<24x32xi1, #xegpu.layout> - - %load_a = xegpu.load_nd %tdesc_a - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> - -> vector<24x32xf32> - %load_b = xegpu.load_nd %tdesc_b - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> - -> vector<24x32xf32> - %load_c = xegpu.load_nd %tdesc_c - : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - -> vector<24x32xi32> - %load_d = xegpu.load_nd %tdesc_d - : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> - -> vector<24x32xi32> - %load_e = xegpu.load_nd %tdesc_e - : !xegpu.tensor_desc<24x32xi1, #xegpu.layout> - -> vector<24x32xi1> - - // CHECK: arith.andi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.ori {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.xori {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.ceildivsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.ceildivui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.floordivsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.maxnumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: arith.maxsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.maxui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.minnumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: arith.minsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: arith.minui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : 
vector<12x8xi32> - // CHECK: arith.remf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: arith.cmpf ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: arith.select {{.*}}, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi1>, vector<12x8xf32> - // CHECK: math.absi {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: math.cbrt {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.copysign {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.ctpop {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: math.erfc {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.exp2 {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.expm1 {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.fpowi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32>, vector<12x8xi32> - // CHECK: math.ipowi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xi32> - // CHECK: math.log10 {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.log1p {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.round {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.roundeven {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // CHECK: math.trunc {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> - // arith ops - %andi = arith.andi %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %ori = arith.ori %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %xori = arith.xori %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %ceildivsi = arith.ceildivsi %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %ceildivui = arith.ceildivui %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %floordivsi = arith.floordivsi %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %maxnumf = arith.maxnumf %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %maxsi = arith.maxsi %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %maxui = arith.maxui %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %minnumf = arith.minnumf %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %minsi = arith.minsi %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %minui = arith.minui %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %remf = arith.remf %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %cmpf = arith.cmpf ult, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %select = arith.select %cmpf, %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xi1>, vector<24x32xf32> - - // math ops - %absi = math.absi %load_c - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %cbrt = math.cbrt %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %copysign = math.copysign %load_a, %load_b - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %ctpop = math.ctpop %load_c - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %erfc = math.erfc %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %exp2 = math.exp2 
%load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %expm1 = math.expm1 %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %fpowi = math.fpowi %load_a, %load_c - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32>, vector<24x32xi32> - %ipowi = math.ipowi %load_c, %load_d - {layout_result_0 = #xegpu.layout} - : vector<24x32xi32> - %log10 = math.log10 %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %log1p = math.log1p %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %round = math.round %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %roundeven = math.roundeven %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - %trunc = math.trunc %load_a - {layout_result_0 = #xegpu.layout} - : vector<24x32xf32> - gpu.return - } -} +gpu.module @test_elementwise_ops { + // CHECK-LABEL: unary_ops + gpu.func @unary_ops(%a: memref<24x32xf32>) { + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> + -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %load_a = xegpu.load_nd %tdesc_a + : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + -> vector<24x32xf32> + // CHECK: math.exp {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + %exp = math.exp %load_a + {layout_result_0 = #xegpu.layout} + : vector<24x32xf32> + // CHECK: arith.negf {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x8xf32> + %negf = arith.negf %load_a + {layout_result_0 = #xegpu.layout} + : vector<24x32xf32> + gpu.return + } + + // CHECK-LABEL: binary_ops + gpu.func @binary_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>) { + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> + -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32> + -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %load_a = xegpu.load_nd %tdesc_a + : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + -> vector<24x32xf32> + %load_b = xegpu.load_nd %tdesc_b + : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + -> vector<24x32xf32> + // CHECK: arith.addf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} + // CHECK-SAME: : vector<12x8xf32> + %addf = arith.addf %load_a, %load_b + {layout_result_0 = #xegpu.layout} + : vector<24x32xf32> + // CHECK: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} + // CHECK-SAME: : vector<12x8xf32> + %powf = math.powf %load_a, %load_b + {layout_result_0 = #xegpu.layout} + : vector<24x32xf32> + gpu.return + } + + // CHECK-LABEL: ternary_ops + gpu.func @ternary_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi1>) { + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> + -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32> + -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi1> + -> !xegpu.tensor_desc<24x32xi1, #xegpu.layout> + %load_a = xegpu.load_nd %tdesc_a + : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + -> vector<24x32xf32> + %load_b = xegpu.load_nd %tdesc_b + : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + -> vector<24x32xf32> + %load_c = xegpu.load_nd %tdesc_c + : !xegpu.tensor_desc<24x32xi1, #xegpu.layout> + -> vector<24x32xi1> + // CHECK: arith.select {{.*}}, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} + // CHECK-SAME: : vector<12x8xi1>, vector<12x8xf32> + %select = arith.select %load_c, %load_a, %load_b + {layout_result_0 = #xegpu.layout} + : vector<24x32xi1>, vector<24x32xf32> + // CHECK: math.fma 
{{.*}}, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} + // CHECK-SAME: : vector<12x8xf32> + %fma = math.fma %load_a, %load_b, %load_a + {layout_result_0 = #xegpu.layout} + : vector<24x32xf32> + gpu.return + } + + // CHECK-LABEL: type_conversion_ops + gpu.func @type_conversion_ops(%a: memref<24x32xf32>, %b: memref<24x32xi32>) { + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> + -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xi32> + -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + %load_a = xegpu.load_nd %tdesc_a + : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + -> vector<24x32xf32> + %load_b = xegpu.load_nd %tdesc_b + : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + -> vector<24x32xi32> + // CHECK: arith.truncf {{.*}} {layout_result_0 = #xegpu.layout} + // CHECK-SAME: : vector<12x8xf32> to vector<12x8xf16> + %truncf = arith.truncf %load_a + {layout_result_0 = #xegpu.layout} + : vector<24x32xf32> to vector<24x32xf16> + // CHECK: arith.bitcast {{.*}} {layout_result_0 = #xegpu.layout} + // CHECK-SAME: : vector<12x8xi32> to vector<12x8xf32> + %bitcast = arith.bitcast %load_b + {layout_result_0 = #xegpu.layout} + : vector<24x32xi32> to vector<24x32xf32> + gpu.return + } + + // CHECK-LABEL: comparison_ops + gpu.func @comparison_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) { + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> + -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32> + -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32> + -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32> + -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + %load_a = xegpu.load_nd %tdesc_a + : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + -> vector<24x32xf32> + %load_b = xegpu.load_nd %tdesc_b + : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + -> vector<24x32xf32> + %load_c = xegpu.load_nd %tdesc_c + : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + -> vector<24x32xi32> + %load_d = xegpu.load_nd %tdesc_d + : !xegpu.tensor_desc<24x32xi32, #xegpu.layout> + -> vector<24x32xi32> + // CHECK: arith.cmpf ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} + // CHECK-SAME: : vector<12x8xf32> + %cmpf = arith.cmpf ult, %load_a, %load_b + {layout_result_0 = #xegpu.layout} + : vector<24x32xf32> + // CHECK: arith.cmpi eq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout} + // CHECK-SAME: : vector<12x8xi32> + %cmpi = arith.cmpi eq, %load_c, %load_d + {layout_result_0 = #xegpu.layout} + : vector<24x32xi32> + gpu.return + } + + // 1 to N decomposition of elementwise operations + // CHECK-LABEL: elementwise_ops_rr_assignment + gpu.func @elementwise_ops_rr_assignment(%a: memref<24x32xf32>, %b: memref<24x32xf32>) { + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> + -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32> + -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %load_a = xegpu.load_nd %tdesc_a + : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + -> vector<24x32xf32> + %load_b = xegpu.load_nd %tdesc_b + : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + -> vector<24x32xf32> + // CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout} + // CHECK-SAME-COUNT-12: : vector<2x2xf32> + // CHECK-NOT: arith.negf + %negf = arith.negf %load_a + 
{layout_result_0 = #xegpu.layout}
+      : vector<24x32xf32>
+    // CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout}
+    // CHECK-SAME-COUNT-12: : vector<2x2xf32>
+    // CHECK-NOT: math.powf
+    %powf = math.powf %load_a, %load_b
+      {layout_result_0 = #xegpu.layout}
+      : vector<24x32xf32>
+    gpu.return
+  }
+}
\ No newline at end of file

From 8a0b3dfc80af78b2627257af801ca6aaee8e3bb1 Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Thu, 12 Jun 2025 21:42:40 +0000
Subject: [PATCH 09/13] Newline

---
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
index 3672e4d9912cf..64f01d61d6e80 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
@@ -161,4 +161,4 @@ gpu.module @test_elementwise_ops {
       : vector<24x32xf32>
     gpu.return
   }
-}
\ No newline at end of file
+}

From 5f7c8f350c5714b113a1b6d24d43c5e8da3d6710 Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Mon, 16 Jun 2025 03:00:09 +0000
Subject: [PATCH 10/13] Clean up

---
 .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 8f991fb1f5aea..8382aee45e1ca 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -341,24 +341,21 @@ struct WgToSgElementwiseOp : public ConversionPattern {
                   ConversionPatternRewriter &rewriter) const override {
     // Only match ops with elementwise trait
     if (!OpTrait::hasElementwiseMappableTraits(op))
-      return rewriter.notifyMatchFailure(op, "Not an elementwise op");
+      return failure();
 
     auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
     ArrayRef<int64_t> wgShape = resultType.getShape();
 
     xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
     if (!layout || !layout.getSgLayout())
-      return rewriter.notifyMatchFailure(
-          op, "Operation does not have a valid layout attribute for subgroup "
-              "distribution");
+      return failure();
 
     SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
 
     size_t numVariants = operands.empty() ? 0 : operands.front().size();
 
     for (auto &operandVec : operands)
       if (operandVec.size() != numVariants)
-        return rewriter.notifyMatchFailure(
-            op, "Operand lists have mismatched sizes");
+        return failure();
 
     SmallVector<Value> newResults;
     VectorType newResultType =
@@ -375,7 +372,7 @@ struct WgToSgElementwiseOp : public ConversionPattern {
       // Copy all attributes, but update "layout_result_0" to drop
       // sgLayout/sgData
       for (auto attr : op->getAttrs()) {
-        if (attr.getName() != "layout_result_0")
+        if (!isa<xegpu::LayoutAttr>(attr.getValue()))
          state.addAttribute(attr.getName(), attr.getValue());
       }
       Operation *newOp = rewriter.create(state);
@@ -598,10 +595,10 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         }
       }
 
-      auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(
-          op->getAttrOfType<xegpu::LayoutAttr>("layout_result_0"));
+      xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
       return isLegal(layout);
     });
+
   target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
       [=](UnrealizedConversionCastOp op) {
         return llvm::is_contained(existingCastOps, op.getOperation());

From 69b57868372047f55a0f44b893cb3545bf4705ae Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Mon, 16 Jun 2025 18:32:27 +0000
Subject: [PATCH 11/13] Feedback

---
 .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 8382aee45e1ca..23abfd3fdb4cb 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -339,8 +339,8 @@ struct WgToSgElementwiseOp : public ConversionPattern {
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    // Only match ops with elementwise trait
-    if (!OpTrait::hasElementwiseMappableTraits(op))
+    // Only match ops with elementwise trait and single result.
+    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
       return failure();
 
     auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
@@ -353,9 +353,12 @@ struct WgToSgElementwiseOp : public ConversionPattern {
     SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
 
     size_t numVariants = operands.empty() ? 0 : operands.front().size();
-    for (auto &operandVec : operands)
-      if (operandVec.size() != numVariants)
-        return failure();
+    // Only VectorType operands are supported here.
+    // TODO: Support other types.
+    if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
+          return operandVec.size() != numVariants;
+        }))
+      return failure();
 
     SmallVector<Value> newResults;
     VectorType newResultType =

From b15c720132842933cf88fd8afbe4bbe8b9dec3da Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Tue, 17 Jun 2025 15:55:40 +0000
Subject: [PATCH 12/13] Address comments

---
 .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 23abfd3fdb4cb..f828f5b52424b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -331,7 +331,7 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
-// This pattern transforms elementwise ops in math/arith dialect
+// This pattern transforms elementwise ops to work at subgroup level.
 struct WgToSgElementwiseOp : public ConversionPattern {
   WgToSgElementwiseOp(MLIRContext *ctx)
       : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
@@ -344,6 +344,8 @@ struct WgToSgElementwiseOp : public ConversionPattern {
       return failure();
 
     auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
+    assert(resultType && "Expected result to be a VectorType");
+
     ArrayRef<int64_t> wgShape = resultType.getShape();
 
     xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
@@ -353,8 +355,7 @@ struct WgToSgElementwiseOp : public ConversionPattern {
     SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
 
     size_t numVariants = operands.empty() ? 0 : operands.front().size();
-    // Only VectorType operands are supported here.
-    // TODO: Support other types.
+
     if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
           return operandVec.size() != numVariants;
         }))
       return failure();
@@ -375,11 +376,12 @@ struct WgToSgElementwiseOp : public ConversionPattern {
       // Copy all attributes, but update "layout_result_0" to drop
       // sgLayout/sgData
       for (auto attr : op->getAttrs()) {
-        if (!isa<xegpu::LayoutAttr>(attr.getValue()))
+        if (isa<xegpu::LayoutAttr>(attr.getValue()))
+          state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
+        else
           state.addAttribute(attr.getName(), attr.getValue());
       }
       Operation *newOp = rewriter.create(state);
-      xegpu::setLayoutAttr(newOp->getResult(0), layout.dropSgLayoutAndData());
       newResults.push_back(newOp->getResult(0));
     }
@@ -591,6 +593,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return true;
 
       // Check if all operands are vectors of the same shape
+      // TODO: Support other types.
      for (Value operand : op->getOperands()) {
        VectorType operandType = dyn_cast<VectorType>(operand.getType());
        if (!operandType || operandType.getShape() != resultType.getShape()) {

From b62979772af4df820e99edc1346ef27f70eaecc4 Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Tue, 17 Jun 2025 16:15:13 +0000
Subject: [PATCH 13/13] fix

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index f828f5b52424b..e3563d10bc6f1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -376,7 +376,7 @@ struct WgToSgElementwiseOp : public ConversionPattern {
       // Copy all attributes, but update "layout_result_0" to drop
       // sgLayout/sgData
       for (auto attr : op->getAttrs()) {
-        if (isa<xegpu::LayoutAttr>(attr.getValue()))
+        if (auto layout = dyn_cast<xegpu::LayoutAttr>(attr.getValue()))
           state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
         else
           state.addAttribute(attr.getName(), attr.getValue());
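
Editor's note: the implementation patches 10 through 13 are easier to follow as one listing, so a consolidated sketch of WgToSgElementwiseOp with all four applied is given below. It is a reconstruction, not verbatim source: the cast template arguments and the OperationState setup inside the loop are inferred from the surrounding hunks, and getSgShapeAndCount, xegpu::getLayoutAttr, and xegpu::LayoutAttr::dropSgLayoutAndData are helpers defined elsewhere in XeGPUWgToSgDistribute.cpp.

// Sketch of the final pattern (patches 10-13 applied), assuming the
// in-tree helpers named above.
struct WgToSgElementwiseOp : public ConversionPattern {
  WgToSgElementwiseOp(MLIRContext *ctx)
      : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only match ops with elementwise trait and single result.
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();

    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
    assert(resultType && "Expected result to be a VectorType");

    ArrayRef<int64_t> wgShape = resultType.getShape();

    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
    if (!layout || !layout.getSgLayout())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

    // Each entry of `operands` holds the list of subgroup-level values that
    // the 1:N conversion produced for one workgroup-level operand.
    size_t numVariants = operands.empty() ? 0 : operands.front().size();

    if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
          return operandVec.size() != numVariants;
        }))
      return failure();

    SmallVector<Value> newResults;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    for (size_t i = 0; i < numVariants; ++i) {
      // Build the i-th subgroup-level clone: same op name, same attributes,
      // but operating on the i-th distributed value of every operand.
      SmallVector<Value> opOperands;
      for (auto &operandVec : operands)
        opOperands.push_back(operandVec[i]);

      OperationState state(op->getLoc(), op->getName());
      state.addOperands(opOperands);
      state.addTypes(newResultType);
      // Copy all attributes, but update "layout_result_0" to drop
      // sgLayout/sgData.
      for (auto attr : op->getAttrs()) {
        if (auto layout = dyn_cast<xegpu::LayoutAttr>(attr.getValue()))
          state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
        else
          state.addAttribute(attr.getName(), attr.getValue());
      }
      Operation *newOp = rewriter.create(state);
      newResults.push_back(newOp->getResult(0));
    }

    // Map the original single result to the N subgroup-level results.
    rewriter.replaceOpWithMultiple(op, {newResults});
    return success();
  }
};

Note that sg_layout/sg_data are dropped from layout_result_0 while the attributes are copied, so the subgroup-level op keeps only the lane-level layout information, matching the layout_result_0 forms the tests above expect.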