@@ -38,47 +38,47 @@ void XeGPUDialect::initialize() {
3838 >();
3939}
4040
41- // / Generates instructions to compute offsets for a subgroup identified by
42- // / its multidimensional indices (sgId), using the specified subgroup layout
43- // / (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
44- // / dimensions (sizePerWg).
41+ // A `srcShape` consists of N distribution units, each being `subShapesLayout` x
42+ // `subShape`. A `delinearizedId` is used to identify a particular `subShape`
43+ // within each distribution unit.
4544static SmallVector<SmallVector<Value>>
46- genOffsetsComputingInsts (OpBuilder &builder, Location loc,
47- SmallVector<Value> sgId, ArrayRef<int64_t > sgLayout,
48- ArrayRef<int64_t > sizePerSg,
49- ArrayRef<int64_t > sizePerWg) {
50-
45+ genOffsets (OpBuilder &builder, Location loc, SmallVector<Value> delinearizedId,
46+ ArrayRef<int64_t > subShapesLayout, ArrayRef<int64_t > subShape,
47+ ArrayRef<int64_t > srcShape) {
5148 SmallVector<SmallVector<Value>> offsets;
5249
53- // nd local offset, localOffset[i] = sgId[i] * sizePerSg[i]
54- SmallVector<Value> localOffsets = llvm::map_to_vector (
55- llvm::zip (sgId, sizePerSg), [&](const auto &t) -> Value {
50+ // A distribution unit must be less than or equal to `srcShape`
51+ SmallVector<int64_t > distUnitShape = llvm::map_to_vector (
52+ llvm::zip_equal (srcShape,
53+ computeElementwiseMul (subShapesLayout, subShape)),
54+ [](const auto &t) { return std::min (std::get<0 >(t), std::get<1 >(t)); });
55+
56+ // Get the offset of `subShape` within a distribution unit.
57+ SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector (
58+ llvm::zip (delinearizedId, subShape), [&](const auto &t) -> Value {
5659 return builder.createOrFold <index::MulOp>(
5760 loc, std::get<0 >(t),
5861 builder.createOrFold <arith::ConstantIndexOp>(loc, std::get<1 >(t)));
5962 });
6063
61- // distUnit[i] is the minimum value between sizePerWg[i] and
62- // sgLayout[i] * sizePerSg[i]
63- SmallVector<int64_t > distUnit = llvm::map_to_vector (
64- llvm::zip_equal (sizePerWg, computeElementwiseMul (sgLayout, sizePerSg)),
65- [](const auto &t) { return std::min (std::get<0 >(t), std::get<1 >(t)); });
66-
64+ // For each dist unit
6765 for (SmallVector<int64_t > unitOffs :
68- StaticTileOffsetRange (sizePerWg, distUnit)) {
66+ StaticTileOffsetRange (srcShape, distUnitShape)) {
67+ // Get dist unit offset within `srcShape`.
6968 SmallVector<Value> base =
7069 llvm::map_to_vector (unitOffs, [&](int64_t d) -> Value {
7170 return arith::ConstantIndexOp::create (builder, loc, d);
7271 });
73-
74- SmallVector<Value> adds = llvm::map_to_vector (
75- llvm::zip_equal (base, localOffsets), [&](const auto &t) -> Value {
76- return builder.createOrFold <arith::AddIOp>(loc, std::get<0 >(t),
77- std::get<1 >(t));
78- });
79-
72+ // Calculate `subShape` offset within `srcShape`.
73+ SmallVector<Value> adds =
74+ llvm::map_to_vector (llvm::zip_equal (base, distUnitLocalOffset),
75+ [&](const auto &t) -> Value {
76+ return builder.createOrFold <arith::AddIOp>(
77+ loc, std::get<0 >(t), std::get<1 >(t));
78+ });
79+ // Do not go beyond `srcShape` bounds.
8080 SmallVector<Value> mods = llvm::map_to_vector (
81- llvm::zip_equal (adds, sizePerWg ), [&](const auto &t) -> Value {
81+ llvm::zip_equal (adds, srcShape ), [&](const auto &t) -> Value {
8282 return builder.createOrFold <index::RemUOp>(
8383 loc, std::get<0 >(t),
8484 arith::ConstantIndexOp::create (builder, loc, std::get<1 >(t)));
@@ -268,12 +268,8 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
268268}
269269
270270FailureOr<SmallVector<Value>>
271- LayoutAttr::delinearizeSubgroupId (OpBuilder &builder, Location loc,
272- Value linearId) {
273- // delinearizeSubgroupId is only available for
274- // workgroup-level layout attribute
275- if (!isForWorkgroup ())
276- return failure ();
271+ LayoutAttr::delinearizeId (OpBuilder &builder, Location loc, Value linearId,
272+ xegpu::DistributionLevel idLevel) {
277273
278274 // TODO: handle order attribute
279275 auto hasDefaultOrder = [&]() {
@@ -283,41 +279,53 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
283279 };
284280 if (!hasDefaultOrder ())
285281 return mlir::emitError (loc, " order attribute is currently not supported." );
286-
287- auto dims =
288- llvm::map_to_vector (getEffectiveSgLayoutAsInt (), [&](int64_t d) -> Value {
289- return builder.createOrFold <arith::ConstantIndexOp>(loc, d);
290- });
282+ SmallVector<int64_t > layout;
283+ if (idLevel == xegpu::DistributionLevel::SG) {
284+ layout = getEffectiveSgLayoutAsInt ();
285+ } else if (idLevel == xegpu::DistributionLevel::WI) {
286+ layout = getEffectiveLaneLayoutAsInt ();
287+ } else {
288+ return failure ();
289+ }
290+ auto dims = llvm::map_to_vector (layout, [&](int64_t d) -> Value {
291+ return builder.createOrFold <arith::ConstantIndexOp>(loc, d);
292+ });
291293
292294 return affine::delinearizeIndex (builder, loc, linearId, dims);
293295}
294296
295- // / Implements DistributeLayoutAttr::getOffsets to generate
297+ // / Implements DistributeLayoutAttr::computeDistributedCoords to generate
296298// / instructions for computing multi-dimensional offsets when distributed by
297299// / LayoutAttr.
298300FailureOr<SmallVector<SmallVector<Value>>>
299- LayoutAttr::getOffsets (OpBuilder &builder, Location loc, Value linearId,
300- ArrayRef<int64_t > shape) {
301- if (!isForWorkgroup ())
301+ LayoutAttr::computeDistributedCoords (OpBuilder &builder, Location loc,
302+ Value linearId, ArrayRef<int64_t > shape,
303+ xegpu::DistributionLevel targetLevel) {
304+ SmallVector<int64_t > layout;
305+ SmallVector<int64_t > subShape;
306+ if (targetLevel == DistributionLevel::SG) {
307+ layout = getEffectiveSgLayoutAsInt ();
308+ subShape = getEffectiveSgDataAsInt ();
309+ } else if (targetLevel == DistributionLevel::WI) {
310+ layout = getEffectiveLaneLayoutAsInt ();
311+ subShape = getEffectiveLaneDataAsInt ();
312+ } else {
302313 return failure ();
303-
304- SmallVector<int64_t > sgLayout = getEffectiveSgLayoutAsInt ();
305- SmallVector<int64_t > sgShape = getEffectiveSgDataAsInt ();
306- if (sgShape.empty ()) {
307- if (auto derivedShape = computeShapeRatio (shape, sgLayout))
308- sgShape = derivedShape.value ();
314+ }
315+ if (subShape.empty ()) {
316+ if (auto derivedShape = computeShapeRatio (shape, layout))
317+ subShape = derivedShape.value ();
309318 else
310319 return failure ();
311320 }
312321
313322 // delinearize Ids
314- auto maybeIds = delinearizeSubgroupId (builder, loc, linearId);
323+ auto maybeIds = delinearizeId (builder, loc, linearId, targetLevel );
315324 if (failed (maybeIds))
316325 return failure ();
317- SmallVector<Value> sgIds = *maybeIds;
326+ SmallVector<Value> ids = *maybeIds;
318327
319- return genOffsetsComputingInsts (builder, loc, sgIds, sgLayout, sgShape,
320- shape);
328+ return genOffsets (builder, loc, ids, layout, subShape, shape);
321329}
322330
323331// ===----------------------------------------------------------------------===//
@@ -371,34 +379,45 @@ SliceAttr SliceAttr::flatten() const {
371379}
372380
373381FailureOr<SmallVector<Value>>
374- SliceAttr::delinearizeSubgroupId (OpBuilder &builder, Location loc,
375- Value linearId ) {
382+ SliceAttr::delinearizeId (OpBuilder &builder, Location loc, Value linearId ,
383+ xegpu::DistributionLevel level ) {
376384 SliceAttr attr = flatten ();
377385 auto parent = dyn_cast<LayoutAttr>(attr.getParent ());
378- return parent.delinearizeSubgroupId (builder, loc, linearId);
386+ return parent.delinearizeId (builder, loc, linearId, level );
379387}
380388
381- // / Implements DistributeLayoutAttr::getOffsets to generate
382- // / instructions for computing multi-dimensional offsets when distributed by
383- // / SliceAttr .
389+ // Implements DistributeLayoutAttr::computeDistributedCoords to generate
390+ // instructions for computing multi-dimensional offsets when distributed by
391+ // LayoutAttr .
384392FailureOr<SmallVector<SmallVector<Value>>>
385- SliceAttr::getOffsets (OpBuilder &builder, Location loc, Value linearId,
386- ArrayRef<int64_t > shape) {
393+ SliceAttr::computeDistributedCoords (OpBuilder &builder, Location loc,
394+ Value linearId, ArrayRef<int64_t > shape,
395+ xegpu::DistributionLevel targetLevel) {
387396 assert (getRank () == static_cast <int64_t >(shape.size ()) && " invalid shape." );
388397 if (!isForWorkgroup ())
389398 return failure ();
390399
391- SmallVector<int64_t > sgLayout = getEffectiveSgLayoutAsInt ();
392- SmallVector<int64_t > sgShape = getEffectiveSgDataAsInt ();
393- if (sgShape.empty ()) {
394- if (auto derivedShape = computeShapeRatio (shape, sgLayout))
395- sgShape = derivedShape.value ();
400+ SmallVector<int64_t > layout;
401+ SmallVector<int64_t > subShape;
402+ if (targetLevel == DistributionLevel::SG) {
403+ layout = getEffectiveSgLayoutAsInt ();
404+ subShape = getEffectiveSgDataAsInt ();
405+ } else if (targetLevel == DistributionLevel::WI) {
406+ layout = getEffectiveLaneLayoutAsInt ();
407+ subShape = getEffectiveLaneDataAsInt ();
408+ } else {
409+ return failure ();
410+ }
411+
412+ if (subShape.empty ()) {
413+ if (auto derivedShape = computeShapeRatio (shape, layout))
414+ subShape = derivedShape.value ();
396415 else
397416 return failure ();
398417 }
399418
400419 // delinearize Ids
401- auto maybeIds = delinearizeSubgroupId (builder, loc, linearId);
420+ auto maybeIds = delinearizeId (builder, loc, linearId, targetLevel );
402421 if (failed (maybeIds))
403422 return failure ();
404423
@@ -408,8 +427,7 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
408427 SmallVector<Value> sgIds =
409428 XeGPUDialect::slice (ArrayRef<Value>(*maybeIds), dims);
410429
411- return genOffsetsComputingInsts (builder, loc, sgIds, sgLayout, sgShape,
412- shape);
430+ return genOffsets (builder, loc, sgIds, layout, subShape, shape);
413431}
414432
415433bool SliceAttr::isSliceOf (const xegpu::DistributeLayoutAttr &other) {
0 commit comments