Skip to content

Commit 60e20a0

Browse files
committed
add impl of getOffsets for LayoutAttr
1 parent 223fab9 commit 60e20a0

File tree

3 files changed

+113
-34
lines changed

3 files changed

+113
-34
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -183,15 +183,20 @@ def LayoutTrait: AttrInterface<"LayoutTrait"> {
183183

184184
let methods = [
185185
InterfaceMethod<"Get the effective sg layout",
186-
"std::optional<llvm::SmallVector<int>>",
186+
"std::optional<SmallVector<int64_t>>",
187187
"getEffectiveSgLayout">,
188188
InterfaceMethod<"Get the effective sg data",
189-
"std::optional<llvm::SmallVector<int>>",
189+
"std::optional<SmallVector<int64_t>>",
190190
"getEffectiveSgData">,
191191
InterfaceMethod<"Delinearize the Subgroup Id",
192192
"FailureOr<SmallVector<Value>>",
193193
"delinearizeSubgroupId",
194-
(ins "Value":$linearId, "Location":$loc, "OpBuilder &": $builder)>
194+
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
195+
196+
InterfaceMethod<"Get the local offset to be accessed by the given subgroup Id",
197+
"FailureOr<SmallVector<SmallVector<Value>>>",
198+
"getOffsets",
199+
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>
195200
];
196201
}
197202

@@ -351,20 +356,23 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
351356
getLaneLayout(), getLaneData(), getOrder());
352357
}
353358

354-
std::optional<llvm::SmallVector<int32_t>> getEffectiveSgLayout() const {
359+
std::optional<SmallVector<int64_t>> getEffectiveSgLayout() const {
355360
if (DenseI32ArrayAttr layout = getSgLayout())
356-
return llvm::to_vector(layout.asArrayRef());
361+
return llvm::to_vector_of<int64_t>(layout.asArrayRef());
357362
return std::nullopt;
358363
}
359364

360-
std::optional<llvm::SmallVector<int32_t>> getEffectiveSgData() const {
365+
std::optional<SmallVector<int64_t>> getEffectiveSgData() const {
361366
if (DenseI32ArrayAttr data = getSgData())
362-
return llvm::to_vector(data.asArrayRef());
367+
return llvm::to_vector_of<int64_t>(data.asArrayRef());
363368
return std::nullopt;
364369
}
365370

366371
FailureOr<SmallVector<Value>>
367-
delinearizeSubgroupId(Value linearId, Location loc, OpBuilder &builder);
372+
delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
373+
374+
FailureOr<SmallVector<SmallVector<Value>>>
375+
getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
368376

369377
}];
370378

@@ -401,24 +409,6 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
401409
);
402410

403411
let extraClassDeclaration = [{
404-
std::optional<llvm::SmallVector<int32_t>> getEffectiveSgLayout() const {
405-
if (DenseI32ArrayAttr layout = getParent().getSgLayout()) {
406-
llvm::ArrayRef<int64_t> dims = getDims().asArrayRef();
407-
return XeGPUDialect::dropDims(layout.asArrayRef(), dims);
408-
}
409-
return std::nullopt;
410-
}
411-
412-
std::optional<llvm::SmallVector<int32_t>> getEffectiveSgData() const {
413-
if (DenseI32ArrayAttr data = getParent().getSgData()) {
414-
llvm::ArrayRef<int64_t> dims = getDims().asArrayRef();
415-
return XeGPUDialect::dropDims(data.asArrayRef(), dims);
416-
}
417-
return std::nullopt;
418-
}
419-
420-
FailureOr<llvm::SmallVector<Value>>
421-
delinearizeSubgroupId(Value linearId, Location loc, OpBuilder &builder);
422412

423413
DenseI32ArrayAttr getOrder() const {
424414
return getParent().getOrder();
@@ -431,6 +421,29 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
431421
bool isSgLayout() const {
432422
return getParent().isSgLayout();
433423
}
424+
425+
std::optional<SmallVector<int64_t>> getEffectiveSgLayout() const {
426+
if (auto layout = getParent().getEffectiveSgLayout()) {
427+
ArrayRef<int64_t> dims = getDims().asArrayRef();
428+
return XeGPUDialect::dropDims(llvm::ArrayRef<int64_t>(*layout), dims);
429+
}
430+
return std::nullopt;
431+
}
432+
433+
std::optional<SmallVector<int64_t>> getEffectiveSgData() const {
434+
if (auto data = getParent().getEffectiveSgData()) {
435+
ArrayRef<int64_t> dims = getDims().asArrayRef();
436+
return XeGPUDialect::dropDims(llvm::ArrayRef<int64_t>(*data), dims);
437+
}
438+
return std::nullopt;
439+
}
440+
441+
FailureOr<SmallVector<Value>>
442+
delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
443+
444+
FailureOr<SmallVector<SmallVector<Value>>>
445+
getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
446+
434447
}];
435448

436449
let assemblyFormat = "`<` $parent `,` `dims` `=` $dims `>`";

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 73 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Affine/Utils.h"
10+
#include "mlir/Dialect/Arith/Utils/Utils.h"
11+
#include "mlir/Dialect/Index/IR/IndexOps.h"
1012
#include "mlir/Dialect/Utils/IndexingUtils.h"
1113
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
1214
#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
@@ -213,17 +215,75 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
213215
}
214216

215217
FailureOr<SmallVector<Value>>
216-
LayoutAttr::delinearizeSubgroupId(Value linearId, Location loc,
217-
OpBuilder &builder) {
218-
assert(isWgLayout() && "delinearizeSubgroupId is only available for "
219-
"workgroup-level layout attribute.");
218+
LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
219+
Value linearId) {
220+
// delinearizeSubgroupId is only available for workgroup-level layout
221+
// attribute
222+
if (!isWgLayout())
223+
return failure();
224+
220225
auto dims =
221226
llvm::map_to_vector(getSgLayout().asArrayRef(), [&](int32_t d) -> Value {
222227
return arith::ConstantIndexOp::create(builder, loc, d);
223228
});
229+
224230
return affine::delinearizeIndex(builder, loc, linearId, dims);
225231
}
226232

233+
/// Builds the offsets, inside a workgroup-level tensor of shape `shape`, of
/// every tile that is assigned to the subgroup identified by `linearId`.
///
/// The per-subgroup tile shape is the effective sg_data when present;
/// otherwise it is derived as `shape / sg_layout`. The tensor is covered in
/// distribution units of size `min(shape[i], sgLayout[i] * sgShape[i])`; for
/// each unit one offset vector `unitBase[i] + sgId[i] * sgShape[i]` is
/// produced. All constants and arithmetic ops are created with `builder` at
/// `loc`.
///
/// \param builder  Builder used to create the index computations.
/// \param loc      Location attached to every created op.
/// \param linearId Linearized subgroup id to delinearize over sg_layout.
/// \param shape    Workgroup-level shape being distributed.
/// \returns One offset vector per distribution unit, or failure() when this is
///          not a workgroup-level layout, when the rank of `shape` does not
///          match the layout rank, when no subgroup shape can be derived, or
///          when the subgroup id cannot be delinearized.
FailureOr<SmallVector<SmallVector<Value>>>
LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
                       ArrayRef<int64_t> shape) {
  if (!isWgLayout())
    return failure();

  // Reject rank mismatches up front: the element-wise helpers and zip_equal
  // below assert on ranges of different length instead of failing gracefully.
  std::optional<SmallVector<int64_t>> maybeSgLayout = getEffectiveSgLayout();
  if (!maybeSgLayout || maybeSgLayout->size() != shape.size())
    return failure();
  const SmallVector<int64_t> &sgLayout = *maybeSgLayout;

  SmallVector<int64_t> sgShape;
  if (auto maybeSgShape = getEffectiveSgData())
    sgShape = std::move(*maybeSgShape);
  else if (auto ratio = computeShapeRatio(shape, sgLayout))
    sgShape = std::move(*ratio);
  else
    return failure();
  if (sgShape.size() != shape.size())
    return failure();

  // distUnit[i] is the minimum value between shape[i] and
  // sgLayout[i] * sgShape[i]
  SmallVector<int64_t> distUnit = llvm::map_to_vector(
      llvm::zip_equal(shape, computeElementwiseMul(sgLayout, sgShape)),
      [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });

  // delinearize Ids
  auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
  if (failed(maybeIds))
    return failure();
  const SmallVector<Value> &sgIds = *maybeIds;

  // nd local offset, localOffset[i] = sgId[i] * sgShape[i]
  SmallVector<Value> localOffsets = llvm::map_to_vector(
      llvm::zip_equal(sgIds, sgShape), [&](const auto &t) -> Value {
        auto &[id, s] = t;
        Value d = arith::ConstantIndexOp::create(builder, loc, s);
        return index::MulOp::create(builder, loc, id, d);
      });

  // NOTE(review): the sums below are not reduced modulo `shape`; when
  // sgLayout[i] * sgShape[i] exceeds shape[i] an offset may land past the
  // tensor bound. Confirm whether callers are expected to wrap them.
  SmallVector<SmallVector<Value>> offsets;
  for (SmallVector<int64_t> unitOffs : StaticTileOffsetRange(shape, distUnit)) {
    SmallVector<Value> base =
        llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
          return arith::ConstantIndexOp::create(builder, loc, d);
        });

    SmallVector<Value> adds = llvm::map_to_vector(
        llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value {
          return arith::AddIOp::create(builder, loc, std::get<0>(t),
                                       std::get<1>(t));
        });

    offsets.push_back(std::move(adds));
  }

  return offsets;
}
286+
227287
//===----------------------------------------------------------------------===//
228288
// XeGPU_SliceAttr
229289
//===----------------------------------------------------------------------===//
@@ -246,9 +306,15 @@ SliceAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
246306
}
247307

248308
FailureOr<SmallVector<Value>>
249-
SliceAttr::delinearizeSubgroupId(Value linearId, Location loc,
250-
OpBuilder &builder) {
251-
return getParent().delinearizeSubgroupId(linearId, loc, builder);
309+
SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
310+
Value linearId) {
311+
return getParent().delinearizeSubgroupId(builder, loc, linearId);
312+
}
313+
314+
/// Offset computation for a sliced layout.
///
/// Currently a placeholder: none of the arguments are inspected and the call
/// unconditionally reports failure().
FailureOr<SmallVector<SmallVector<Value>>>
SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
                      ArrayRef<int64_t> shape) {
  return failure();
}
253319

254320
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
213213
}
214214

215215
auto deLinearizeSgId =
216-
layout.delinearizeSubgroupId(adjustedSgId, loc, rewriter);
216+
layout.delinearizeSubgroupId(rewriter, loc, adjustedSgId);
217217
if (failed(deLinearizeSgId))
218218
return failure();
219219
SmallVector<Value> sgIds = *deLinearizeSgId;

0 commit comments

Comments
 (0)