diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index d022361d1e376..64eb21cbc3c4c 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -336,4 +336,33 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> { let genVerifyDecl = 1; } +def XeGPU_RangeAttr : XeGPUAttr<"Range", "range"> { + let summary = [{Specifies a half-open range}]; + let description = [{ + `RangeAttr` is an attribute that defines a half-open range [start, end). + The range is inclusive of the start value and exclusive of the end value. + One usage of this attribute can be to specify the subgroup id range. + The subgroup id range can be specified using this attribute, + and it can be attached to a scf.if op like + ```mlir + scf.if %cond { + // some operations + } {sg_id_range = #xegpu.range<[2, 4]>} + ``` + In this case, the scf.if op will only be executed for subgroup IDs 2 and 3. + }]; + + let parameters = (ins + "IntegerAttr": $start, + "IntegerAttr": $end + ); + + let builders = [ + AttrBuilder<(ins "int":$start, "int":$end)> + ]; + + let assemblyFormat = "`<` `[`$start `,` $end `]` `>`"; + let genVerifyDecl = 1; +} + #endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 8ab404d52eab4..3c0ca114a62d4 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -211,6 +211,21 @@ LayoutAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError, return success(); } +//===----------------------------------------------------------------------===// +// XeGPU_RangeAttr +//===----------------------------------------------------------------------===// + +LogicalResult +RangeAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError, + IntegerAttr startOfRange, IntegerAttr endOfRange) { + if (startOfRange.getInt() >= endOfRange.getInt()) + return emitError() <<
"'end' : " << endOfRange.getInt() + << " must be greater than 'start' : " + << startOfRange.getInt(); + + return success(); +} + //===----------------------------------------------------------------------===// // XeGPU_TensorDescType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index ef52323a9f46b..229a289838c60 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -34,6 +34,26 @@ using namespace mlir; namespace { +// Check if there is sg id range attached to the scf.if op. +static bool isSgIdRangeSpecified(Operation *op, int64_t &startOfRange, + int64_t &endOfRange) { + Operation *parent = op->getParentOp(); + // Find the outermost scf::IfOp with xegpu.sg_id_range. + while (parent) { + if (auto ifOp = dyn_cast<scf::IfOp>(parent)) { + if (auto attr = llvm::dyn_cast_or_null<xegpu::RangeAttr>( + ifOp->getAttr("sg_id_range"))) { + startOfRange = attr.getStart().getInt(); + endOfRange = attr.getEnd().getInt(); + break; + } + } + parent = parent->getParentOp(); + } + // Return false if startOfRange is 0 + return (startOfRange > 0 && endOfRange > startOfRange); +} + static std::pair<SmallVector<int64_t>, int> getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) { int count = 1; @@ -174,8 +194,26 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> { sgDataDim[i] = arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]); } + int64_t startOfRange = -1, endOfRange = -1; + bool sgIdRangeSpecified = + isSgIdRangeSpecified(op, startOfRange, endOfRange); + + Value adjustedSgId = linearSgId; + if (sgIdRangeSpecified) { + int64_t sgCount = endOfRange - startOfRange; + if (computeProduct(sgLayout) != sgCount) + return rewriter.notifyMatchFailure( + op, "sg_layout size must match the sg_id_range"); + // Subtract startOfRange from the original subgroup id to get
the adjusted + // sg id + Value startOfRangeVal = + rewriter.create<arith::ConstantIndexOp>(loc, startOfRange); + adjustedSgId = + rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal); + } + auto deLinearizeSgId = - affine::delinearizeIndex(rewriter, loc, linearSgId, sgLayoutDim); + affine::delinearizeIndex(rewriter, loc, adjustedSgId, sgLayoutDim); if (failed(deLinearizeSgId)) return failure(); SmallVector<Value> sgIds = *deLinearizeSgId; diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir index 8a81a286da23a..d51122417fb61 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir @@ -327,5 +327,70 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout> gpu.return } -} + // CHECK-LABEL: @subgroup_id_range + gpu.func @subgroup_id_range(%src: memref<256x128xf32>, %src1: memref<128x256xf32>, %src2: memref<128x64xf32>) { + %sg_id = gpu.subgroup_id : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c31 = arith.constant 31 : index + %c3 = arith.constant 3 : index + %cond1 = arith.cmpi sge, %sg_id, %c0 : index + %cond2 = arith.cmpi slt, %sg_id, %c1 : index + %cond = arith.andi %cond1, %cond2 : i1 + scf.if %cond { + // CHECK-NOT: index.sub + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout> + %load = xegpu.load_nd %tdesc + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout> + -> vector<256x128xf32> + } {sg_id_range = #xegpu.range<[0, 32]>} + %cond3 = arith.cmpi sge, %sg_id, %c2 : index + %cond4 = arith.cmpi slt, %sg_id, %c31 : index + %cond5 = arith.andi %cond3, %cond4 : i1 + scf.if %cond5 { + // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]] + %tdesc = xegpu.create_nd_tdesc
%src2[0, 0] : memref<128x64xf32> + -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout> + %load = xegpu.load_nd %tdesc + : !xegpu.tensor_desc<128x64xf32, #xegpu.layout> + -> vector<128x64xf32> + %exp = math.exp %load {layout_result_0 = #xegpu.layout} : vector<128x64xf32> + }{sg_id_range = #xegpu.range<[2, 18]>} + gpu.return + } + + // CHECK-LABEL: @subgroup_id_range_nested_if + gpu.func @subgroup_id_range_nested_if(%src: memref<256x128xf32>, %src1: memref<128x64xf32>) { + %sg_id = gpu.subgroup_id : index + %c1 = arith.constant 1 : i1 + %c3 = arith.constant 3 : index + %c32 = arith.constant 32 : index + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout> + %load = xegpu.load_nd %tdesc + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout> + -> vector<256x128xf32> + %cond1 = arith.cmpi sge, %sg_id, %c3 : index + %cond2 = arith.cmpi slt, %sg_id, %c32 : index + %cond = arith.andi %cond1, %cond2 : i1 + scf.if %c1 { + scf.if %cond { + // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK: %[[C3:.*]] = arith.constant 3 : index + // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C3]] + %td = xegpu.create_nd_tdesc %src1[0, 0] : memref<128x64xf32> + -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout> + %ld = xegpu.load_nd %td + : !xegpu.tensor_desc<128x64xf32, #xegpu.layout> + -> vector<128x64xf32> + %exp = math.exp %ld {layout_result_0 = #xegpu.layout} : vector<128x64xf32> + } + } {sg_id_range = #xegpu.range<[3, 19]>} + gpu.return + } +}