
Conversation

nbpatel (Contributor) commented Jul 14, 2025

This PR adds a new attribute to the xegpu dialect called xegpu.range. One use case of this attribute is to attach a subgroup_id_range to scf.if ops to drive execution.

@nbpatel nbpatel marked this pull request as ready for review July 15, 2025 20:16
llvmbot (Member) commented Jul 15, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/148661.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+39-1)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir (+83)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index be7b860dd1729..56dc132d8083d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -174,8 +174,46 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
       sgDataDim[i] = rewriter.create<arith::ConstantIndexOp>(loc, sgShape[i]);
     }
 
+    // Check if there is warp specialization.
+    auto isWarpSpecialized = [](Operation *op, int64_t &startRange,
+                                int64_t &endRange) -> bool {
+      Operation *parent = op->getParentOp();
+      // Find the outermost scf::IfOp with xegpu.sg_id_range.
+      while (parent) {
+        if (auto ifOp = dyn_cast<scf::IfOp>(parent)) {
+          if (Attribute attr = ifOp->getAttr("xegpu.sg_id_range")) {
+            if (auto denseAttr = dyn_cast<DenseI32ArrayAttr>(attr)) {
+              auto values = denseAttr.asArrayRef();
+              if (values.size() == 2) {
+                startRange = values[0];
+                endRange = values[1];
+              }
+            }
+            break;
+          }
+        }
+        parent = parent->getParentOp();
+      }
+      // Return false if startRange is 0
+      return (startRange > 0 && endRange > startRange);
+    };
+
+    int64_t startRange = -1, endRange = -1;
+    bool warpSpecialized = isWarpSpecialized(op, startRange, endRange);
+
+    // If warp specialization is detected, adjust the subgroup id accordingly
+    Value adjustedSgId = linearSgId;
+    if (warpSpecialized) {
+      // Subtract startRange from the original subgroup id to get the adjusted
+      // sg id
+      Value startRangeVal =
+          rewriter.create<arith::ConstantIndexOp>(loc, startRange);
+      adjustedSgId =
+          rewriter.createOrFold<index::SubOp>(loc, linearSgId, startRangeVal);
+    }
+
     auto deLinearizeSgId =
-        affine::delinearizeIndex(rewriter, loc, linearSgId, sgLayoutDim);
+        affine::delinearizeIndex(rewriter, loc, adjustedSgId, sgLayoutDim);
     if (failed(deLinearizeSgId))
       return failure();
     SmallVector<Value> sgIds = *deLinearizeSgId;
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 44b11c304cc80..71eb732ac4953 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -296,5 +296,88 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     gpu.return
   }
 
+  // CHECK-LABEL: @warp_specialized
+  gpu.func @warp_specialized(%src: memref<256x128xf32>, %src1: memref<128x256xf32>, %src2: memref<128x64xf32>) {
+    %sg_id = gpu.subgroup_id : index
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c31 = arith.constant 31 : index
+    %c3 = arith.constant 3 : index
+    %cond1 = arith.cmpi sge, %sg_id, %c0 : index
+    %cond2 = arith.cmpi slt, %sg_id, %c1 : index
+    %cond = arith.andi %cond1, %cond2 : i1
+    scf.if %cond {
+        // CHECK-NOT: index.sub
+        %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+          -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+        %load =  xegpu.load_nd %tdesc
+          : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+          -> vector<256x128xf32>
+    } {xegpu.sg_id_range = array<i32: 0, 1>}
+    %cond3 = arith.cmpi sge, %sg_id, %c1 : index
+    %cond4 = arith.cmpi slt, %sg_id, %c2 : index
+    %cond5 = arith.andi %cond3, %cond4 : i1
+     scf.if %cond5 {
+        // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+        // CHECK: %[[C1:.*]] = arith.constant 1 : index
+        // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C1]]
+        %tdesc_a = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+          -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+        %load_a =  xegpu.load_nd %tdesc_a
+          : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+          -> vector<256x128xf32>
+        %tdesc_b = xegpu.create_nd_tdesc %src1[0, 0] : memref<128x256xf32>
+          -> !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>>
+        %load_b =  xegpu.load_nd %tdesc_b
+          : !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>>
+          -> vector<128x256xf32>
+        %dpas = xegpu.dpas %load_a, %load_b {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>} : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32>
+     }{xegpu.sg_id_range = array<i32: 1, 2>}
+    %cond6 = arith.cmpi sge, %sg_id, %c2 : index
+    %cond7 = arith.cmpi slt, %sg_id, %c31 : index
+    %cond8 = arith.andi %cond6, %cond7 : i1
+    scf.if %cond8 {
+      // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+      // CHECK: %[[C2:.*]] = arith.constant 2 : index
+      // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]]
+      %tdesc = xegpu.create_nd_tdesc %src2[0, 0] : memref<128x64xf32>
+        -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+      %load =  xegpu.load_nd %tdesc
+        : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+        -> vector<128x64xf32>
+      %exp = math.exp %load {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
+    }{xegpu.sg_id_range = array<i32: 2, 32>}
+    gpu.return
+  }
 
+  // CHECK-LABEL: @subgroup_id_range_nested_if
+  gpu.func @subgroup_id_range_nested_if(%src: memref<256x128xf32>, %src1: memref<128x64xf32>) {
+    %sg_id = gpu.subgroup_id : index
+    %c1 = arith.constant 1 : i1 
+    %c3 = arith.constant 3 : index
+    %c32 = arith.constant 32 : index
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    %cond1 = arith.cmpi sge, %sg_id, %c3 : index
+    %cond2 = arith.cmpi slt, %sg_id, %c32 : index
+    %cond = arith.andi %cond1, %cond2 : i1
+    scf.if %c1 {
+      scf.if %cond {
+        // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+        // CHECK: %[[C3:.*]] = arith.constant 3 : index
+        // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C3]]
+        %td = xegpu.create_nd_tdesc %src1[0, 0] : memref<128x64xf32>
+          -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+        %ld =  xegpu.load_nd %td
+          : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+          -> vector<128x64xf32>
+        %exp = math.exp %ld {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
+    }
+  } {xegpu.sg_id_range = array<i32: 3, 8>}
+  gpu.return
+  }
 }

@nbpatel nbpatel requested review from Jianhui-Li and chencha3 July 16, 2025 19:23
}

// Check if there is warp specialization.
auto isWarpSpecialized = [](Operation *op, int64_t &startOfRange,

Contributor

consider taking this out as a separate utility function for wg-to-sg distribution.
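
For illustration, a minimal sketch of what such an extracted utility might look like (hypothetical name getSgIdRange; assuming the DenseI32ArrayAttr encoding used at this point in the PR, before the later switch to #xegpu.range):

```cpp
// Sketch only: walk up from `op` to the nearest enclosing scf.if carrying
// an "xegpu.sg_id_range" attribute, extract [start, end), and report
// whether the subgroup id needs adjusting (i.e. the range starts above 0).
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"

using namespace mlir;

static bool getSgIdRange(Operation *op, int64_t &startRange,
                         int64_t &endRange) {
  for (Operation *parent = op->getParentOp(); parent;
       parent = parent->getParentOp()) {
    auto ifOp = dyn_cast<scf::IfOp>(parent);
    if (!ifOp)
      continue;
    auto attr = ifOp->getAttrOfType<DenseI32ArrayAttr>("xegpu.sg_id_range");
    if (!attr)
      continue;
    ArrayRef<int32_t> values = attr.asArrayRef();
    if (values.size() == 2) {
      startRange = values[0];
      endRange = values[1];
    }
    break; // Stop at the nearest enclosing scf.if that has the attribute.
  }
  // No adjustment is needed when no range was found or it starts at 0.
  return startRange > 0 && endRange > startRange;
}
```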

Contributor Author
ok will change it


// If warp specialization is detected, adjust the subgroup id accordingly
Value adjustedSgId = linearSgId;
if (warpSpecialized) {

Contributor

You also need to verify that the sg_id_range matches the xegpu.sg_layout.

nbpatel (Contributor Author) Jul 17, 2025

Verify that the number of subgroups in the sg_layout is equal to the number of subgroups specified by the sg_id_range?

}
parent = parent->getParentOp();
}
// Return false if startOfRange is 0

Contributor

Why can't startOfRange be 0?

nbpatel (Contributor Author) Jul 17, 2025

It can be 0, but if the starting subgroup id is 0 we don't need to adjust the ids, so the check returns false. For example, with sg_id_range = array<i32: 0, 8> the linear subgroup id already starts at 0 and can be delinearized as-is.

%load_b = xegpu.load_nd %tdesc_b
: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>>
-> vector<128x256xf32>
%dpas = xegpu.dpas %load_a, %load_b {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>} : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32>

Contributor

The sg_layout size should match the range size.
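
A sketch of that check under stated assumptions (illustrative names; not part of this PR revision): the range should cover exactly as many subgroups as sg_layout describes.

```cpp
// Illustrative only: sg_layout = [8, 4] implies 32 subgroups, so a range
// such as [2, 34) would be consistent while [2, 10) would not.
#include <functional>
#include <numeric>

#include "llvm/ADT/ArrayRef.h"
#include "mlir/Support/LLVM.h"

static mlir::LogicalResult
verifyRangeMatchesLayout(llvm::ArrayRef<int64_t> sgLayout, int64_t startRange,
                         int64_t endRange) {
  // Total subgroup count is the product of the sg_layout dimensions.
  int64_t numSubgroups =
      std::accumulate(sgLayout.begin(), sgLayout.end(), int64_t(1),
                      std::multiplies<int64_t>());
  if (endRange - startRange != numSubgroups)
    return mlir::failure();
  return mlir::success();
}
```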

@nbpatel nbpatel requested a review from adam-smnk July 18, 2025 15:29
nbpatel (Contributor Author) commented Jul 22, 2025

Hi @adam-smnk, do you have any comments on this?

adam-smnk (Contributor) left a comment

Overall looks neat - minor comments

Comment on lines +46 to +47
startOfRange = attr.getStart().getInt();
endOfRange = attr.getEnd().getInt();

Contributor

General suggestion, non-blocker here: getting the int value directly would make for a nice attribute helper method.
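
For example, hypothetical helpers along those lines (names are illustrative, not in the PR; RangeAttr's generated getStart()/getEnd() accessors return IntegerAttr):

```cpp
// Illustrative helpers: unwrap RangeAttr's IntegerAttr parameters into
// plain integers so call sites can skip the getStart().getInt() two-step.
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"

static int64_t getRangeStart(mlir::xegpu::RangeAttr range) {
  return range.getStart().getInt();
}
static int64_t getRangeEnd(mlir::xegpu::RangeAttr range) {
  return range.getEnd().getInt();
}
```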

RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
                  IntegerAttr startOfRange, IntegerAttr endOfRange) {
  if (startOfRange.getInt() >= endOfRange.getInt())
    return emitError() << "EndOfRange must be greater than StartOfRange";

Contributor

nit: in the error message, I'd refer to the values by their attribute names, start and end; it should improve error readability.
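
A sketch of how the verifier might read with that nit applied (not the merged wording):

```cpp
LogicalResult
RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
                  IntegerAttr startOfRange, IntegerAttr endOfRange) {
  // Refer to the parameters by their declared attribute names.
  if (startOfRange.getInt() >= endOfRange.getInt())
    return emitError() << "'end' must be greater than 'start'";
  return success();
}
```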

AttrBuilder<(ins "int":$start, "int":$end)>
];

let assemblyFormat = "`<` `[`$start ```,` $end `]``>`";

Contributor

Suggested change
let assemblyFormat = "`<` `[`$start ```,` $end `]``>`";
let assemblyFormat = "`<` `[`$start `,` $end `]` `>`";

nit: minor cleanup

//===----------------------------------------------------------------------===//

LogicalResult
RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,

Contributor

Could you add one invalid test case?

Contributor Author

Not sure if it's possible to add a negative test case with this pass, because it will always give a legalization error for the create_nd_desc op if the pattern returns a failure in this case.

```mlir
scf.if %cond {
  // some operations
}{sg_id_range = #xegpu.range<[2, 4]>}
```

Contributor

Suggested change
}{sg_id_range = #xegpu.range<[2, 4]>}
} {sg_id_range = #xegpu.range<[2, 4]>}

@nbpatel nbpatel merged commit 65dec99 into llvm:main Jul 23, 2025
9 checks passed
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Jul 28, 2025
This PR adds a new attribute to the xegpu dialect called xegpu.range. One use case of this attribute is to attach a subgroup_id_range to scf.if ops to drive execution.
@nbpatel nbpatel deleted the xegpu-warp-specialized branch September 25, 2025 20:34
