@@ -37,55 +37,61 @@ void XeGPUDialect::initialize() {
3737 >();
3838}
3939
40- // / Generates instructions to compute offsets for a subgroup identified by
41- // / its multidimensional indices (sgId), using the specified subgroup layout
42- // / (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
43- // / dimensions (sizePerWg).
40+ // A `srcShape` consists of N distribution units, each being `subShapesLayout` x
41+ // `subShape`. A `delinearizedId` is used to identify a particular `subShape`
42+ // within each distribution unit.
43+ // Example:
44+ // WG data is 128x256 and SG data is 16x32 in a 4x2 layout. This gives a
45+ // distribution unit of shape 64x64, so there are 2x4 such distribution units.
46+ // `delinearizedId` identifies the 16x32 block that belongs to one subgroup
47+ // within each distribution unit.
4448static SmallVector<SmallVector<Value>>
45- genOffsetsComputingInsts (OpBuilder &builder, Location loc,
46- SmallVector<Value> sgId, ArrayRef<int64_t > sgLayout,
47- ArrayRef<int64_t > sizePerSg,
48- ArrayRef<int64_t > sizePerWg) {
49-
50- SmallVector<SmallVector<Value>> offsets;
49+ genCoordinates (OpBuilder &builder, Location loc,
50+ SmallVector<Value> delinearizedId,
51+ ArrayRef<int64_t > subShapesLayout, ArrayRef<int64_t > subShape,
52+ ArrayRef<int64_t > srcShape) {
53+ SmallVector<SmallVector<Value>> coordinates;
54+
55+ // A distribution unit must be less than or equal to `srcShape`
56+ SmallVector<int64_t > distUnitShape = llvm::map_to_vector (
57+ llvm::zip_equal (srcShape,
58+ computeElementwiseMul (subShapesLayout, subShape)),
59+ [](const auto &t) { return std::min (std::get<0 >(t), std::get<1 >(t)); });
5160
52- // nd local offset, localOffset[i] = sgId[i] * sizePerSg[i]
53- SmallVector<Value> localOffsets = llvm::map_to_vector (
54- llvm::zip (sgId, sizePerSg ), [&](const auto &t) -> Value {
61+ // Get the offset of `subShape` within a distribution unit.
62+ SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector (
63+ llvm::zip (delinearizedId, subShape ), [&](const auto &t) -> Value {
5564 return builder.createOrFold <index::MulOp>(
5665 loc, std::get<0 >(t),
5766 builder.createOrFold <arith::ConstantIndexOp>(loc, std::get<1 >(t)));
5867 });
5968
60- // distUnit[i] is the minimum value between sizePerWg[i] and
61- // sgLayout[i] * sizePerSg[i]
62- SmallVector<int64_t > distUnit = llvm::map_to_vector (
63- llvm::zip_equal (sizePerWg, computeElementwiseMul (sgLayout, sizePerSg)),
64- [](const auto &t) { return std::min (std::get<0 >(t), std::get<1 >(t)); });
65-
69+ // For each dist unit
6670 for (SmallVector<int64_t > unitOffs :
67- StaticTileOffsetRange (sizePerWg, distUnit)) {
71+ StaticTileOffsetRange (srcShape, distUnitShape)) {
72+ // Get dist unit offset within `srcShape`.
6873 SmallVector<Value> base =
6974 llvm::map_to_vector (unitOffs, [&](int64_t d) -> Value {
7075 return arith::ConstantIndexOp::create (builder, loc, d);
7176 });
72-
73- SmallVector<Value> adds = llvm::map_to_vector (
74- llvm::zip_equal (base, localOffsets), [&](const auto &t) -> Value {
75- return builder.createOrFold <arith::AddIOp>(loc, std::get<0 >(t),
76- std::get<1 >(t));
77- });
78-
77+ // Calculate `subShape` offset within `srcShape`.
78+ SmallVector<Value> adds =
79+ llvm::map_to_vector (llvm::zip_equal (base, distUnitLocalOffset),
80+ [&](const auto &t) -> Value {
81+ return builder.createOrFold <arith::AddIOp>(
82+ loc, std::get<0 >(t), std::get<1 >(t));
83+ });
84+ // Do not go beyond `srcShape` bounds.
7985 SmallVector<Value> mods = llvm::map_to_vector (
80- llvm::zip_equal (adds, sizePerWg ), [&](const auto &t) -> Value {
86+ llvm::zip_equal (adds, srcShape ), [&](const auto &t) -> Value {
8187 return builder.createOrFold <index::RemUOp>(
8288 loc, std::get<0 >(t),
8389 arith::ConstantIndexOp::create (builder, loc, std::get<1 >(t)));
8490 });
8591
86- offsets .push_back (mods);
92+ coordinates .push_back (mods);
8793 }
88- return offsets ;
94+ return coordinates ;
8995}
9096
9197// Checks if the given shape can be evenly distributed based on the layout
@@ -272,12 +278,7 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
272278}
273279
274280FailureOr<SmallVector<Value>>
275- LayoutAttr::delinearizeSubgroupId (OpBuilder &builder, Location loc,
276- Value linearId) {
277- // delinearizeSubgroupId is only available for
278- // workgroup-level layout attribute
279- if (!isForWorkgroup ())
280- return failure ();
281+ LayoutAttr::delinearizeId (OpBuilder &builder, Location loc, Value linearId) {
281282
282283 // TODO: handle order attribute
283284 auto hasDefaultOrder = [&]() {
@@ -287,41 +288,52 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
287288 };
288289 if (!hasDefaultOrder ())
289290 return mlir::emitError (loc, " order attribute is currently not supported." );
290-
291- auto dims =
292- llvm::map_to_vector (getEffectiveSgLayoutAsInt (), [&](int64_t d) -> Value {
293- return builder.createOrFold <arith::ConstantIndexOp>(loc, d);
294- });
291+ SmallVector<int64_t > layout;
292+ if (isForWorkgroup ()) {
293+ layout = getEffectiveSgLayoutAsInt ();
294+ } else if (isForSubgroup ()) {
295+ layout = getEffectiveLaneLayoutAsInt ();
296+ } else {
297+ return failure ();
298+ }
299+ auto dims = llvm::map_to_vector (layout, [&](int64_t d) -> Value {
300+ return builder.createOrFold <arith::ConstantIndexOp>(loc, d);
301+ });
295302
296303 return affine::delinearizeIndex (builder, loc, linearId, dims);
297304}
298305
299- // / Implements DistributeLayoutAttr::getOffsets to generate
306+ // / Implements DistributeLayoutAttr::computeDistributedCoords to generate
300307// / instructions for computing multi-dimensional offsets when distributed by
301308// / LayoutAttr.
302309FailureOr<SmallVector<SmallVector<Value>>>
303- LayoutAttr::getOffsets (OpBuilder &builder, Location loc, Value linearId,
304- ArrayRef<int64_t > shape) {
305- if (!isForWorkgroup ())
310+ LayoutAttr::computeDistributedCoords (OpBuilder &builder, Location loc,
311+ Value linearId, ArrayRef<int64_t > shape) {
312+ SmallVector<int64_t > layout;
313+ SmallVector<int64_t > subShape;
314+ if (isForWorkgroup ()) {
315+ layout = getEffectiveSgLayoutAsInt ();
316+ subShape = getEffectiveSgDataAsInt ();
317+ } else if (isForSubgroup ()) {
318+ layout = getEffectiveLaneLayoutAsInt ();
319+ subShape = getEffectiveLaneDataAsInt ();
320+ } else {
306321 return failure ();
307-
308- SmallVector<int64_t > sgLayout = getEffectiveSgLayoutAsInt ();
309- SmallVector<int64_t > sgShape = getEffectiveSgDataAsInt ();
310- if (sgShape.empty ()) {
311- if (auto derivedShape = computeShapeRatio (shape, sgLayout))
312- sgShape = derivedShape.value ();
322+ }
323+ if (subShape.empty ()) {
324+ if (auto derivedShape = computeShapeRatio (shape, layout))
325+ subShape = derivedShape.value ();
313326 else
314327 return failure ();
315328 }
316329
317330 // delinearize Ids
318- auto maybeIds = delinearizeSubgroupId (builder, loc, linearId);
331+ auto maybeIds = delinearizeId (builder, loc, linearId);
319332 if (failed (maybeIds))
320333 return failure ();
321- SmallVector<Value> sgIds = *maybeIds;
334+ SmallVector<Value> ids = *maybeIds;
322335
323- return genOffsetsComputingInsts (builder, loc, sgIds, sgLayout, sgShape,
324- shape);
336+ return genCoordinates (builder, loc, ids, layout, subShape, shape);
325337}
326338
327339// ===----------------------------------------------------------------------===//
@@ -375,34 +387,43 @@ SliceAttr SliceAttr::flatten() const {
375387}
376388
377389FailureOr<SmallVector<Value>>
378- SliceAttr::delinearizeSubgroupId (OpBuilder &builder, Location loc,
379- Value linearId) {
390+ SliceAttr::delinearizeId (OpBuilder &builder, Location loc, Value linearId) {
380391 SliceAttr attr = flatten ();
381392 auto parent = dyn_cast<LayoutAttr>(attr.getParent ());
382- return parent.delinearizeSubgroupId (builder, loc, linearId);
393+ return parent.delinearizeId (builder, loc, linearId);
383394}
384395
385- // / Implements DistributeLayoutAttr::getOffsets to generate
386- // / instructions for computing multi-dimensional offsets when distributed by
387- // / SliceAttr .
396+ // Implements DistributeLayoutAttr::computeDistributedCoords to generate
397+ // instructions for computing multi-dimensional offsets when distributed by
398+ // SliceAttr .
388399FailureOr<SmallVector<SmallVector<Value>>>
389- SliceAttr::getOffsets (OpBuilder &builder, Location loc, Value linearId ,
390- ArrayRef<int64_t > shape) {
400+ SliceAttr::computeDistributedCoords (OpBuilder &builder, Location loc,
401+ Value linearId, ArrayRef<int64_t > shape) {
391402 assert (getRank () == static_cast <int64_t >(shape.size ()) && " invalid shape." );
392403 if (!isForWorkgroup ())
393404 return failure ();
394405
395- SmallVector<int64_t > sgLayout = getEffectiveSgLayoutAsInt ();
396- SmallVector<int64_t > sgShape = getEffectiveSgDataAsInt ();
397- if (sgShape.empty ()) {
398- if (auto derivedShape = computeShapeRatio (shape, sgLayout))
399- sgShape = derivedShape.value ();
406+ SmallVector<int64_t > layout;
407+ SmallVector<int64_t > subShape;
408+ if (isForWorkgroup ()) {
409+ layout = getEffectiveSgLayoutAsInt ();
410+ subShape = getEffectiveSgDataAsInt ();
411+ } else if (isForSubgroup ()) {
412+ layout = getEffectiveLaneLayoutAsInt ();
413+ subShape = getEffectiveLaneDataAsInt ();
414+ } else {
415+ return failure ();
416+ }
417+
418+ if (subShape.empty ()) {
419+ if (auto derivedShape = computeShapeRatio (shape, layout))
420+ subShape = derivedShape.value ();
400421 else
401422 return failure ();
402423 }
403424
404425 // delinearize Ids
405- auto maybeIds = delinearizeSubgroupId (builder, loc, linearId);
426+ auto maybeIds = delinearizeId (builder, loc, linearId);
406427 if (failed (maybeIds))
407428 return failure ();
408429
@@ -412,8 +433,7 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
412433 SmallVector<Value> sgIds =
413434 XeGPUDialect::slice (ArrayRef<Value>(*maybeIds), dims);
414435
415- return genOffsetsComputingInsts (builder, loc, sgIds, sgLayout, sgShape,
416- shape);
436+ return genCoordinates (builder, loc, sgIds, layout, subShape, shape);
417437}
418438
419439bool SliceAttr::isSliceOf (const xegpu::DistributeLayoutAttr &other) {
0 commit comments