diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index a26c6b52f0ddc..e3563d10bc6f1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -8,10 +8,12 @@
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 
 #include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -19,6 +21,7 @@
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include <optional>
 
 namespace mlir {
 namespace xegpu {
@@ -328,6 +331,65 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+// This pattern transforms elementwise ops to work at subgroup level.
+struct WgToSgElementwiseOp : public ConversionPattern {
+  WgToSgElementwiseOp(MLIRContext *ctx)
+      : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only match ops with elementwise trait and single result.
+    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
+      return failure();
+
+    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
+    assert(resultType && "Expected result to be a VectorType");
+
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
+    if (!layout || !layout.getSgLayout())
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+
+    size_t numVariants = operands.empty() ? 0 : operands.front().size();
+
+    if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
+          return operandVec.size() != numVariants;
+        }))
+      return failure();
+
+    SmallVector<Value> newResults;
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+
+    for (size_t i = 0; i < numVariants; ++i) {
+      SmallVector<Value> opOperands;
+      for (auto &operandVec : operands)
+        opOperands.push_back(operandVec[i]);
+
+      OperationState state(op->getLoc(), op->getName());
+      state.addOperands(opOperands);
+      state.addTypes(newResultType);
+      // Copy all attributes, but update "layout_result_0" to drop
+      // sgLayout/sgData
+      for (auto attr : op->getAttrs()) {
+        if (auto layout = dyn_cast<xegpu::LayoutAttr>(attr.getValue()))
+          state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
+        else
+          state.addAttribute(attr.getName(), attr.getValue());
+      }
+      Operation *newOp = rewriter.create(state);
+      newResults.push_back(newOp->getResult(0));
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newResults});
+    return success();
+  }
+};
+
 // Handles UnrealizedConversionCastOp generated during
 // SCFStructuralTypeConversions (step 1). This op may appear as either a
 // target or source materialization for Vector values, e.g.:
@@ -411,7 +473,8 @@ namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
                WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
-               UnrealizedConversionCastOpPattern>(patterns.getContext());
+               UnrealizedConversionCastOpPattern, WgToSgElementwiseOp>(
+      patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -518,6 +581,30 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
     return isLegal(layout);
   });
 
+  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
+      [=](Operation *op) -> std::optional<bool> {
+        // Only handle elementwise mappable ops
+        if (!OpTrait::hasElementwiseMappableTraits(op))
+          return true;
+
+        VectorType resultType =
+            dyn_cast<VectorType>(op->getResult(0).getType());
+        if (!resultType)
+          return true;
+
+        // Check if all operands are vectors of the same shape
+        // TODO: Support other types.
+        for (Value operand : op->getOperands()) {
+          VectorType operandType = dyn_cast<VectorType>(operand.getType());
+          if (!operandType || operandType.getShape() != resultType.getShape()) {
+            return true;
+          }
+        }
+
+        xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
+        return isLegal(layout);
+      });
+
   target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
       [=](UnrealizedConversionCastOp op) {
         return llvm::is_contained(existingCastOps, op.getOperation());
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
new file mode 100644
index 0000000000000..64f01d61d6e80
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
@@ -0,0 +1,164 @@
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+
+gpu.module @test_elementwise_ops {
+  // CHECK-LABEL: unary_ops
+  gpu.func @unary_ops(%a: memref<24x32xf32>) {
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load_a = xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    // CHECK: math.exp {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+    %exp = math.exp %load_a
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xf32>
+    // CHECK: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+    %negf = arith.negf %load_a
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: binary_ops
+  gpu.func @binary_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>) {
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load_a = xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %load_b = xegpu.load_nd %tdesc_b
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    // CHECK: arith.addf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xf32>
+    %addf = arith.addf %load_a, %load_b
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xf32>
+    // CHECK: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xf32>
+    %powf = math.powf %load_a, %load_b
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: ternary_ops
+  gpu.func @ternary_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi1>) {
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi1>
+      -> !xegpu.tensor_desc<24x32xi1, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load_a = xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %load_b = xegpu.load_nd %tdesc_b
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %load_c = xegpu.load_nd %tdesc_c
+      : !xegpu.tensor_desc<24x32xi1, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xi1>
+    // CHECK: arith.select {{.*}}, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xi1>, vector<12x8xf32>
+    %select = arith.select %load_c, %load_a, %load_b
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xi1>, vector<24x32xf32>
+    // CHECK: math.fma {{.*}}, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xf32>
+    %fma = math.fma %load_a, %load_b, %load_a
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: type_conversion_ops
+  gpu.func @type_conversion_ops(%a: memref<24x32xf32>, %b: memref<24x32xi32>) {
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xi32>
+      -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load_a = xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %load_b = xegpu.load_nd %tdesc_b
+      : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xi32>
+    // CHECK: arith.truncf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xf32> to vector<12x8xf16>
+    %truncf = arith.truncf %load_a
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xf32> to vector<24x32xf16>
+    // CHECK: arith.bitcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xi32> to vector<12x8xf32>
+    %bitcast = arith.bitcast %load_b
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xi32> to vector<24x32xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: comparison_ops
+  gpu.func @comparison_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32>
+      -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32>
+      -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load_a = xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %load_b = xegpu.load_nd %tdesc_b
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %load_c = xegpu.load_nd %tdesc_c
+      : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xi32>
+    %load_d = xegpu.load_nd %tdesc_d
+      : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xi32>
+    // CHECK: arith.cmpf ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xf32>
+    %cmpf = arith.cmpf ult, %load_a, %load_b
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xf32>
+    // CHECK: arith.cmpi eq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xi32>
+    %cmpi = arith.cmpi eq, %load_c, %load_d
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xi32>
+    gpu.return
+  }
+
+  // 1 to N decomposition of elementwise operations
+  // CHECK-LABEL: elementwise_ops_rr_assignment
+  gpu.func @elementwise_ops_rr_assignment(%a: memref<24x32xf32>, %b: memref<24x32xf32>) {
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %load_a = xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %load_b = xegpu.load_nd %tdesc_b
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    // CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
+    // CHECK-SAME-COUNT-12: : vector<2x2xf32>
+    // CHECK-NOT: arith.negf
+    %negf = arith.negf %load_a
+      {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+      : vector<24x32xf32>
+    // CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
+    // CHECK-SAME-COUNT-12: : vector<2x2xf32>
+    // CHECK-NOT: math.powf
+    %powf = math.powf %load_a, %load_b
+      {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+      : vector<24x32xf32>
+    gpu.return
+  }
+}