From 1d18b895bface3094ac4868601ebeadc0ae2758c Mon Sep 17 00:00:00 2001 From: nbpatel Date: Fri, 11 Jul 2025 14:26:08 +0000 Subject: [PATCH 1/9] Add support for subgroup_id_range --- .../Transforms/XeGPUWgToSgDistribute.cpp | 40 ++++++++- mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 83 +++++++++++++++++++ 2 files changed, 122 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index be7b860dd1729..56dc132d8083d 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -174,8 +174,46 @@ struct WgToSgCreateNdOp : public OpConversionPattern { sgDataDim[i] = rewriter.create(loc, sgShape[i]); } + // Check if there is warp specialization. + auto isWarpSpecialized = [](Operation *op, int64_t &startRange, + int64_t &endRange) -> bool { + Operation *parent = op->getParentOp(); + // Find the outermost scf::IfOp with xegpu.sg_id_range. + while (parent) { + if (auto ifOp = dyn_cast(parent)) { + if (Attribute attr = ifOp->getAttr("xegpu.sg_id_range")) { + if (auto denseAttr = dyn_cast(attr)) { + auto values = denseAttr.asArrayRef(); + if (values.size() == 2) { + startRange = values[0]; + endRange = values[1]; + } + } + break; + } + } + parent = parent->getParentOp(); + } + // Return false if startRange is 0 + return (startRange > 0 && endRange > startRange); + }; + + int64_t startRange = -1, endRange = -1; + bool warpSpecialized = isWarpSpecialized(op, startRange, endRange); + + // If warp specialization is detected, adjust the subgroup id accordingly + Value adjustedSgId = linearSgId; + if (warpSpecialized) { + // Subtract startRange from the original subgroup id to get the adjusted + // sg id + Value startRangeVal = + rewriter.create(loc, startRange); + adjustedSgId = + rewriter.createOrFold(loc, linearSgId, startRangeVal); + } + auto deLinearizeSgId = - affine::delinearizeIndex(rewriter, loc, linearSgId, sgLayoutDim); + affine::delinearizeIndex(rewriter, loc, adjustedSgId, sgLayoutDim); if (failed(deLinearizeSgId)) return failure(); SmallVector sgIds = *deLinearizeSgId; diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir index 44b11c304cc80..71eb732ac4953 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir @@ -296,5 +296,88 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { gpu.return } + // CHECK-LABEL: @warp_specialized + gpu.func @warp_specialized(%src: memref<256x128xf32>, %src1: memref<128x256xf32>, %src2: memref<128x64xf32>) { + %sg_id = gpu.subgroup_id : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c31 = arith.constant 31 : index + %c3 = arith.constant 3 : index + %cond1 = arith.cmpi sge, %sg_id, %c0 : index + %cond2 = arith.cmpi slt, %sg_id, %c1 : index + %cond = arith.andi %cond1, %cond2 : i1 + scf.if %cond { + // CHECK-NOT: index.sub + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout> + %load = xegpu.load_nd %tdesc + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout> + -> vector<256x128xf32> + } {xegpu.sg_id_range = array} + %cond3 = arith.cmpi sge, %sg_id, %c1 : index + %cond4 = arith.cmpi slt, %sg_id, %c2 : index + %cond5 = arith.andi %cond3, %cond4 : i1 + scf.if %cond5 { + // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK: %[[C1:.*]] = arith.constant 1 : index + // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C1]] + %tdesc_a = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout> + %load_a = xegpu.load_nd %tdesc_a + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout> + -> vector<256x128xf32> + %tdesc_b = xegpu.create_nd_tdesc %src1[0, 0] : memref<128x256xf32> + -> !xegpu.tensor_desc<128x256xf32, #xegpu.layout> + %load_b = xegpu.load_nd %tdesc_b + : !xegpu.tensor_desc<128x256xf32, #xegpu.layout> + -> vector<128x256xf32> + %dpas = xegpu.dpas %load_a, %load_b {layout_result_0 = #xegpu.layout} : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32> + }{xegpu.sg_id_range = array} + %cond6 = arith.cmpi sge, %sg_id, %c2 : index + %cond7 = arith.cmpi slt, %sg_id, %c31 : index + %cond8 = arith.andi %cond6, %cond7 : i1 + scf.if %cond8 { + // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]] + %tdesc = xegpu.create_nd_tdesc %src2[0, 0] : memref<128x64xf32> + -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout> + %load = xegpu.load_nd %tdesc + : !xegpu.tensor_desc<128x64xf32, #xegpu.layout> + -> vector<128x64xf32> + %exp = math.exp %load {layout_result_0 = #xegpu.layout} : vector<128x64xf32> + }{xegpu.sg_id_range = array} + gpu.return + } + // CHECK-LABEL: @subgroup_id_range_nested_if + gpu.func @subgroup_id_range_nested_if(%src: memref<256x128xf32>, %src1: memref<128x64xf32>) { + %sg_id = gpu.subgroup_id : index + %c1 = arith.constant 1 : i1 + %c3 = arith.constant 3 : index + %c32 = arith.constant 32 : index + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout> + %load = xegpu.load_nd %tdesc + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout> + -> vector<256x128xf32> + %cond1 = arith.cmpi sge, %sg_id, %c3 : index + %cond2 = arith.cmpi slt, %sg_id, %c32 : index + %cond = arith.andi %cond1, %cond2 : i1 + scf.if %c1 { + scf.if %cond { + // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK: %[[C3:.*]] = arith.constant 3 : index + // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C3]] + %td = xegpu.create_nd_tdesc %src1[0, 0] : memref<128x64xf32> + -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout> + %ld = xegpu.load_nd %td + : !xegpu.tensor_desc<128x64xf32, #xegpu.layout> + -> vector<128x64xf32> + %exp = math.exp %ld {layout_result_0 = #xegpu.layout} : vector<128x64xf32> + } + } {xegpu.sg_id_range = array} + gpu.return + } } From b4e3068ca9d3e74d73ae9274834cc952d304a19f Mon Sep 17 00:00:00 2001 From: nbpatel Date: Wed, 16 Jul 2025 17:02:23 +0000 Subject: [PATCH 2/9] Add xegpu.sg_id_range attribute --- .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 27 ++++++++++++++++ .../Transforms/XeGPUWgToSgDistribute.cpp | 32 ++++++++----------- mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 8 ++--- 3 files changed, 45 insertions(+), 22 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 84c1dc1373ee5..306b6ec1eed16 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -315,4 +315,31 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> { let genVerifyDecl = 1; } +def XeGPU_RangeAttr : XeGPUAttr<"Range", "range"> { + let summary = [{Specifies a half-open range}]; + let description = [{ + `RangeAttr` is an attribute that defines a half-open range [start, end). + The range is inclusive of the start value and exclusive of the end value. + One usage of this attribute can be for warp specialization. + For warp specialization, this attribute can be attached to a scf.if op like + ```mlir + scf.if %cond { + // some operations + }{sg_id_range = #xegpu.range<[2, 4]>} + ``` + In this case, the scf.if op will only be executed for subgroup IDs 2 and 3. + }]; + + let parameters = (ins + "IntegerAttr": $start, + "IntegerAttr": $end + ); + + let builders = [ + AttrBuilder<(ins "int":$start, "int":$end)> + ]; + + let assemblyFormat = "`<` `[`$start ```,` $end `]``>`"; +} + #endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 56dc132d8083d..eb89cca0070ac 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -175,41 +175,37 @@ struct WgToSgCreateNdOp : public OpConversionPattern { } // Check if there is warp specialization. - auto isWarpSpecialized = [](Operation *op, int64_t &startRange, - int64_t &endRange) -> bool { + auto isWarpSpecialized = [](Operation *op, int64_t &startOfRange, + int64_t &endOfRange) -> bool { Operation *parent = op->getParentOp(); // Find the outermost scf::IfOp with xegpu.sg_id_range. while (parent) { if (auto ifOp = dyn_cast(parent)) { - if (Attribute attr = ifOp->getAttr("xegpu.sg_id_range")) { - if (auto denseAttr = dyn_cast(attr)) { - auto values = denseAttr.asArrayRef(); - if (values.size() == 2) { - startRange = values[0]; - endRange = values[1]; - } - } + if (auto attr = llvm::dyn_cast_or_null( + ifOp->getAttr("sg_id_range"))) { + startOfRange = attr.getStart().getInt(); + endOfRange = attr.getEnd().getInt(); break; } } parent = parent->getParentOp(); } - // Return false if startRange is 0 - return (startRange > 0 && endRange > startRange); + // Return false if startOfRange is 0 + return (startOfRange > 0 && endOfRange > startOfRange); }; - int64_t startRange = -1, endRange = -1; - bool warpSpecialized = isWarpSpecialized(op, startRange, endRange); + int64_t startOfRange = -1, endOfRange = -1; + bool warpSpecialized = isWarpSpecialized(op, startOfRange, endOfRange); // If warp specialization is detected, adjust the subgroup id accordingly Value adjustedSgId = linearSgId; if (warpSpecialized) { - // Subtract startRange from the original subgroup id to get the adjusted + // Subtract startOfRange from the original subgroup id to get the adjusted // sg id - Value startRangeVal = - rewriter.create(loc, startRange); + Value startOfRangeVal = + rewriter.create(loc, startOfRange); adjustedSgId = - rewriter.createOrFold(loc, linearSgId, startRangeVal); + rewriter.createOrFold(loc, linearSgId, startOfRangeVal); } auto deLinearizeSgId = diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir index 71eb732ac4953..39cd8c6158685 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir @@ -314,7 +314,7 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { %load = xegpu.load_nd %tdesc : !xegpu.tensor_desc<256x128xf32, #xegpu.layout> -> vector<256x128xf32> - } {xegpu.sg_id_range = array} + } {sg_id_range = #xegpu.range<[0, 1]>} %cond3 = arith.cmpi sge, %sg_id, %c1 : index %cond4 = arith.cmpi slt, %sg_id, %c2 : index %cond5 = arith.andi %cond3, %cond4 : i1 @@ -333,7 +333,7 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { : !xegpu.tensor_desc<128x256xf32, #xegpu.layout> -> vector<128x256xf32> %dpas = xegpu.dpas %load_a, %load_b {layout_result_0 = #xegpu.layout} : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32> - }{xegpu.sg_id_range = array} + }{sg_id_range = #xegpu.range<[1, 2]>} %cond6 = arith.cmpi sge, %sg_id, %c2 : index %cond7 = arith.cmpi slt, %sg_id, %c31 : index %cond8 = arith.andi %cond6, %cond7 : i1 @@ -347,7 +347,7 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { : !xegpu.tensor_desc<128x64xf32, #xegpu.layout> -> vector<128x64xf32> %exp = math.exp %load {layout_result_0 = #xegpu.layout} : vector<128x64xf32> - }{xegpu.sg_id_range = array} + }{sg_id_range = #xegpu.range<[2, 32]>} gpu.return } @@ -377,7 +377,7 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { -> vector<128x64xf32> %exp = math.exp %ld {layout_result_0 = #xegpu.layout} : vector<128x64xf32> } - } {xegpu.sg_id_range = array} + } {sg_id_range = #xegpu.range<[3, 8]>} gpu.return } } From 70fe19cfb4811f4c5619a6f47affc0a5f01998eb Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 17 Jul 2025 00:54:11 +0000 Subject: [PATCH 3/9] Update tests --- mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 29 ++++----------------- 1 file changed, 5 insertions(+), 24 deletions(-) diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir index 39cd8c6158685..74c27a87cfb17 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir @@ -314,30 +314,11 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { %load = xegpu.load_nd %tdesc : !xegpu.tensor_desc<256x128xf32, #xegpu.layout> -> vector<256x128xf32> - } {sg_id_range = #xegpu.range<[0, 1]>} - %cond3 = arith.cmpi sge, %sg_id, %c1 : index - %cond4 = arith.cmpi slt, %sg_id, %c2 : index + } {sg_id_range = #xegpu.range<[0, 32]>} + %cond3 = arith.cmpi sge, %sg_id, %c2 : index + %cond4 = arith.cmpi slt, %sg_id, %c31 : index %cond5 = arith.andi %cond3, %cond4 : i1 - scf.if %cond5 { - // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index - // CHECK: %[[C1:.*]] = arith.constant 1 : index - // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C1]] - %tdesc_a = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> - -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout> - %load_a = xegpu.load_nd %tdesc_a - : !xegpu.tensor_desc<256x128xf32, #xegpu.layout> - -> vector<256x128xf32> - %tdesc_b = xegpu.create_nd_tdesc %src1[0, 0] : memref<128x256xf32> - -> !xegpu.tensor_desc<128x256xf32, #xegpu.layout> - %load_b = xegpu.load_nd %tdesc_b - : !xegpu.tensor_desc<128x256xf32, #xegpu.layout> - -> vector<128x256xf32> - %dpas = xegpu.dpas %load_a, %load_b {layout_result_0 = #xegpu.layout} : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32> - }{sg_id_range = #xegpu.range<[1, 2]>} - %cond6 = arith.cmpi sge, %sg_id, %c2 : index - %cond7 = arith.cmpi slt, %sg_id, %c31 : index - %cond8 = arith.andi %cond6, %cond7 : i1 - scf.if %cond8 { + scf.if %cond5 { // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]] @@ -377,7 +358,7 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { -> vector<128x64xf32> %exp = math.exp %ld {layout_result_0 = #xegpu.layout} : vector<128x64xf32> } - } {sg_id_range = #xegpu.range<[3, 8]>} + } {sg_id_range = #xegpu.range<[3, 32]>} gpu.return } } From e6528effd34b26b215c7b226c825a9ffcb2de677 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 17 Jul 2025 04:26:49 +0000 Subject: [PATCH 4/9] Add check for sgLayout size and sg_id_range --- .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 4 +- .../Transforms/XeGPUWgToSgDistribute.cpp | 50 ++++++++++--------- mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 4 +- 3 files changed, 31 insertions(+), 27 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 306b6ec1eed16..64f1830bd286c 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -320,8 +320,8 @@ def XeGPU_RangeAttr : XeGPUAttr<"Range", "range"> { let description = [{ `RangeAttr` is an attribute that defines a half-open range [start, end). The range is inclusive of the start value and exclusive of the end value. - One usage of this attribute can be for warp specialization. - For warp specialization, this attribute can be attached to a scf.if op like + One usage of this attribute can be for subgroup specialization. + For subgroup specialization, this attribute can be attached to a scf.if op like ```mlir scf.if %cond { // some operations diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index eb89cca0070ac..e4212d6cb2020 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -34,6 +34,26 @@ using namespace mlir; namespace { +// Check if there is sg id specialization. +static bool isSgIdSpecialized(Operation *op, int64_t &startOfRange, + int64_t &endOfRange) { + Operation *parent = op->getParentOp(); + // Find the outermost scf::IfOp with xegpu.sg_id_range. + while (parent) { + if (auto ifOp = dyn_cast(parent)) { + if (auto attr = llvm::dyn_cast_or_null( + ifOp->getAttr("sg_id_range"))) { + startOfRange = attr.getStart().getInt(); + endOfRange = attr.getEnd().getInt(); + break; + } + } + parent = parent->getParentOp(); + } + // Return false if startOfRange is 0 + return (startOfRange > 0 && endOfRange > startOfRange); +} + static std::pair, int> getSgShapeAndCount(ArrayRef shape, xegpu::LayoutAttr layout) { int count = 1; @@ -174,32 +194,16 @@ struct WgToSgCreateNdOp : public OpConversionPattern { sgDataDim[i] = rewriter.create(loc, sgShape[i]); } - // Check if there is warp specialization. - auto isWarpSpecialized = [](Operation *op, int64_t &startOfRange, - int64_t &endOfRange) -> bool { - Operation *parent = op->getParentOp(); - // Find the outermost scf::IfOp with xegpu.sg_id_range. - while (parent) { - if (auto ifOp = dyn_cast(parent)) { - if (auto attr = llvm::dyn_cast_or_null( - ifOp->getAttr("sg_id_range"))) { - startOfRange = attr.getStart().getInt(); - endOfRange = attr.getEnd().getInt(); - break; - } - } - parent = parent->getParentOp(); - } - // Return false if startOfRange is 0 - return (startOfRange > 0 && endOfRange > startOfRange); - }; - int64_t startOfRange = -1, endOfRange = -1; - bool warpSpecialized = isWarpSpecialized(op, startOfRange, endOfRange); + bool sgIdSpecialized = isSgIdSpecialized(op, startOfRange, endOfRange); - // If warp specialization is detected, adjust the subgroup id accordingly Value adjustedSgId = linearSgId; - if (warpSpecialized) { + if (sgIdSpecialized) { + int64_t expectedSgLayoutSize = endOfRange - startOfRange; + if (computeProduct(sgLayout) != expectedSgLayoutSize) { + return rewriter.notifyMatchFailure( + op, "sg_layout size must match the sg_id_range"); + } // Subtract startOfRange from the original subgroup id to get the adjusted // sg id Value startOfRangeVal = diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir index 74c27a87cfb17..4014cdff24a6e 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir @@ -328,7 +328,7 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { : !xegpu.tensor_desc<128x64xf32, #xegpu.layout> -> vector<128x64xf32> %exp = math.exp %load {layout_result_0 = #xegpu.layout} : vector<128x64xf32> - }{sg_id_range = #xegpu.range<[2, 32]>} + }{sg_id_range = #xegpu.range<[2, 18]>} gpu.return } @@ -358,7 +358,7 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { -> vector<128x64xf32> %exp = math.exp %ld {layout_result_0 = #xegpu.layout} : vector<128x64xf32> } - } {sg_id_range = #xegpu.range<[3, 32]>} + } {sg_id_range = #xegpu.range<[3, 19]>} gpu.return } } From 07b9eff085d90ac2c812d87388e5edc5f70e7a71 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 17 Jul 2025 04:37:21 +0000 Subject: [PATCH 5/9] Fix variable name --- mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index e4212d6cb2020..2dab657e60bd3 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -199,8 +199,8 @@ struct WgToSgCreateNdOp : public OpConversionPattern { Value adjustedSgId = linearSgId; if (sgIdSpecialized) { - int64_t expectedSgLayoutSize = endOfRange - startOfRange; - if (computeProduct(sgLayout) != expectedSgLayoutSize) { + int64_t sgCount = endOfRange - startOfRange; + if (computeProduct(sgLayout) != sgCount) { return rewriter.notifyMatchFailure( op, "sg_layout size must match the sg_id_range"); } From 09fdbfc355098326e59b5162cf84ee523ce12639 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Fri, 18 Jul 2025 04:01:39 +0000 Subject: [PATCH 6/9] Change variable name --- mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 5 +++-- .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 11 ++++++----- mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 4 ++-- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 64f1830bd286c..7974e1968b112 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -320,8 +320,9 @@ def XeGPU_RangeAttr : XeGPUAttr<"Range", "range"> { let description = [{ `RangeAttr` is an attribute that defines a half-open range [start, end). The range is inclusive of the start value and exclusive of the end value. - One usage of this attribute can be for subgroup specialization. - For subgroup specialization, this attribute can be attached to a scf.if op like + One usage of this attribute can be to specify the subgroup id range. + The subgroup id range can be specified using this attribute, + and it can be attached to a scf.if op like ```mlir scf.if %cond { // some operations diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 2dab657e60bd3..d0455cbaa1834 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -34,9 +34,9 @@ using namespace mlir; namespace { -// Check if there is sg id specialization. -static bool isSgIdSpecialized(Operation *op, int64_t &startOfRange, - int64_t &endOfRange) { +// Check if there is sg id range attached to the scf.if op. +static bool isSgIdRangeSpecified(Operation *op, int64_t &startOfRange, + int64_t &endOfRange) { Operation *parent = op->getParentOp(); // Find the outermost scf::IfOp with xegpu.sg_id_range. while (parent) { @@ -195,10 +195,11 @@ struct WgToSgCreateNdOp : public OpConversionPattern { } int64_t startOfRange = -1, endOfRange = -1; - bool sgIdSpecialized = isSgIdSpecialized(op, startOfRange, endOfRange); + bool sgIdRangeSpecified = + isSgIdRangeSpecified(op, startOfRange, endOfRange); Value adjustedSgId = linearSgId; - if (sgIdSpecialized) { + if (sgIdRangeSpecified) { int64_t sgCount = endOfRange - startOfRange; if (computeProduct(sgLayout) != sgCount) { return rewriter.notifyMatchFailure( diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir index 4014cdff24a6e..fb8a3965fbf71 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir @@ -296,8 +296,8 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { gpu.return } - // CHECK-LABEL: @warp_specialized - gpu.func @warp_specialized(%src: memref<256x128xf32>, %src1: memref<128x256xf32>, %src2: memref<128x64xf32>) { + // CHECK-LABEL: @subgroup_id_range + gpu.func @subgroup_id_range(%src: memref<256x128xf32>, %src1: memref<128x256xf32>, %src2: memref<128x64xf32>) { %sg_id = gpu.subgroup_id : index %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index From 1cecfbe5e0ad8607e7d1a14b614ea743b6299a69 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Mon, 21 Jul 2025 15:32:08 +0000 Subject: [PATCH 7/9] remove braces --- mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index d0455cbaa1834..732b63061e1f4 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -201,10 +201,9 @@ struct WgToSgCreateNdOp : public OpConversionPattern { Value adjustedSgId = linearSgId; if (sgIdRangeSpecified) { int64_t sgCount = endOfRange - startOfRange; - if (computeProduct(sgLayout) != sgCount) { + if (computeProduct(sgLayout) != sgCount) return rewriter.notifyMatchFailure( op, "sg_layout size must match the sg_id_range"); - } // Subtract startOfRange from the original subgroup id to get the adjusted // sg id Value startOfRangeVal = From 3cde9208318c2eb7471b3fe396c5f4b438c3319f Mon Sep 17 00:00:00 2001 From: nbpatel Date: Mon, 21 Jul 2025 20:01:25 +0000 Subject: [PATCH 8/9] add verifier for RangeAttr --- mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 1 + mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 13 +++++++++++++ 2 files changed, 14 insertions(+) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 7974e1968b112..4c125d5017a7b 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -341,6 +341,7 @@ def XeGPU_RangeAttr : XeGPUAttr<"Range", "range"> { ]; let assemblyFormat = "`<` `[`$start ```,` $end `]``>`"; + let genVerifyDecl = 1; } #endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 7ef61de190b4c..69c9b53ada4f4 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -209,6 +209,19 @@ LayoutAttr::verify(llvm::function_ref emitError, return success(); } +//===----------------------------------------------------------------------===// +// XeGPU_RangeAttr +//===----------------------------------------------------------------------===// + +LogicalResult +RangeAttr::verify(llvm::function_ref emitError, + IntegerAttr startOfRange, IntegerAttr endOfRange) { + if (startOfRange.getInt() >= endOfRange.getInt()) + return emitError() << "EndOfRange must be greater than StartOfRange"; + + return success(); +} + //===----------------------------------------------------------------------===// // XeGPU_TensorDescType //===----------------------------------------------------------------------===// From 343d63090ed8e01a297b4f12bbb9577823b0f096 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Wed, 23 Jul 2025 17:13:01 +0000 Subject: [PATCH 9/9] clean up --- mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 4 ++-- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 4c125d5017a7b..f113762ffc723 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -326,7 +326,7 @@ def XeGPU_RangeAttr : XeGPUAttr<"Range", "range"> { ```mlir scf.if %cond { // some operations - }{sg_id_range = #xegpu.range<[2, 4]>} + } {sg_id_range = #xegpu.range<[2, 4]>} ``` In this case, the scf.if op will only be executed for subgroup IDs 2 and 3. }]; @@ -340,7 +340,7 @@ def XeGPU_RangeAttr : XeGPUAttr<"Range", "range"> { AttrBuilder<(ins "int":$start, "int":$end)> ]; - let assemblyFormat = "`<` `[`$start ```,` $end `]``>`"; + let assemblyFormat = "`<` `[`$start `,` $end `]` `>`"; let genVerifyDecl = 1; } diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 69c9b53ada4f4..76e245dfa7b9b 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -217,7 +217,9 @@ LogicalResult RangeAttr::verify(llvm::function_ref emitError, IntegerAttr startOfRange, IntegerAttr endOfRange) { if (startOfRange.getInt() >= endOfRange.getInt()) - return emitError() << "EndOfRange must be greater than StartOfRange"; + return emitError() << "'end' : " << endOfRange.getInt() + << " must be greater than 'start' : " + << startOfRange.getInt(); return success(); }