Skip to content

Commit 6872e6d

Browse files
committed
add delinearizeSubgroupId interface
1 parent 36e2c3a commit 6872e6d

File tree

5 files changed

+35
-2
lines changed

5 files changed

+35
-2
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/IR/BuiltinTypes.h"
1616
#include "mlir/IR/Dialect.h"
1717
#include "mlir/IR/TypeUtilities.h"
18+
#include "mlir/IR/Value.h"
1819
#include "mlir/Interfaces/ShapedOpInterfaces.h"
1920
#include "mlir/Interfaces/SideEffectInterfaces.h"
2021
#include "mlir/Interfaces/ViewLikeInterface.h"

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,11 @@ def LayoutTrait: AttrInterface<"LayoutTrait"> {
187187
"getEffectiveSgLayout">,
188188
InterfaceMethod<"Get the effective sg data",
189189
"std::optional<llvm::SmallVector<int>>",
190-
"getEffectiveSgData">
190+
"getEffectiveSgData">,
191+
InterfaceMethod<"Delinearize the Subgroup Id",
192+
"FailureOr<SmallVector<Value>>",
193+
"delinearizeSubgroupId",
194+
(ins "Value":$linearId, "Location":$loc, "OpBuilder &": $builder)>
191195
];
192196
}
193197

@@ -358,6 +362,10 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
358362
return llvm::to_vector(data.asArrayRef());
359363
return std::nullopt;
360364
}
365+
366+
FailureOr<SmallVector<Value>>
367+
delinearizeSubgroupId(Value linearId, Location loc, OpBuilder &builder);
368+
361369
}];
362370

363371
let assemblyFormat = "`<` struct(params) `>`";
@@ -409,6 +417,9 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
409417
return std::nullopt;
410418
}
411419

420+
FailureOr<llvm::SmallVector<Value>>
421+
delinearizeSubgroupId(Value linearId, Location loc, OpBuilder &builder);
422+
412423
DenseI32ArrayAttr getOrder() const {
413424
return getParent().getOrder();
414425
}

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "mlir/Dialect/Affine/Utils.h"
910
#include "mlir/Dialect/Utils/IndexingUtils.h"
1011
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
1112
#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
@@ -211,6 +212,18 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
211212
return success();
212213
}
213214

215+
FailureOr<SmallVector<Value>>
216+
LayoutAttr::delinearizeSubgroupId(Value linearId, Location loc,
217+
OpBuilder &builder) {
218+
assert(isWgLayout() && "delinearizeSubgroupId is only available for "
219+
"workgroup-level layout attribute.");
220+
auto dims =
221+
llvm::map_to_vector(getSgLayout().asArrayRef(), [&](int32_t d) -> Value {
222+
return arith::ConstantIndexOp::create(builder, loc, d);
223+
});
224+
return affine::delinearizeIndex(builder, loc, linearId, dims);
225+
}
226+
214227
//===----------------------------------------------------------------------===//
215228
// XeGPU_SliceAttr
216229
//===----------------------------------------------------------------------===//
@@ -232,6 +245,12 @@ SliceAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
232245
return success();
233246
}
234247

248+
FailureOr<SmallVector<Value>>
249+
SliceAttr::delinearizeSubgroupId(Value linearId, Location loc,
250+
OpBuilder &builder) {
251+
return getParent().delinearizeSubgroupId(linearId, loc, builder);
252+
}
253+
235254
//===----------------------------------------------------------------------===//
236255
// XeGPU_TensorDescType
237256
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -838,7 +838,9 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
838838
} // namespace xegpu
839839
} // namespace mlir
840840

841+
namespace mlir {
841842
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
843+
} // namespace mlir
842844
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
843845
#define GET_OP_CLASSES
844846
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
175175
}
176176

177177
auto deLinearizeSgId =
178-
affine::delinearizeIndex(rewriter, loc, linearSgId, sgLayoutDim);
178+
layout.delinearizeSubgroupId(linearSgId, loc, rewriter);
179179
if (failed(deLinearizeSgId))
180180
return failure();
181181
SmallVector<Value> sgIds = *deLinearizeSgId;

0 commit comments

Comments
 (0)