@@ -39,30 +39,32 @@ void XeGPUDialect::initialize() {
3939
4040// / Generates instructions to compute offsets for a subgroup identified by
4141// / its multidimensional indices (sgId), using the specified subgroup layout
42- // / (sgLayout), subgroup data dimensions (sgShape ), and the overall data
43- // / dimensions (shape ).
42+ // / (sgLayout), subgroup data dimensions (sizePerSg ), and the overall data
43+ // / dimensions (sizePerWg ).
4444static SmallVector<SmallVector<Value>>
45- genOffsetsComputations (OpBuilder &builder, Location loc,
46- SmallVector<Value> sgId, ArrayRef<int64_t > sgLayout,
47- ArrayRef<int64_t > sgShape, ArrayRef<int64_t > shape) {
45+ genOffsetsComputingInsts (OpBuilder &builder, Location loc,
46+ SmallVector<Value> sgId, ArrayRef<int64_t > sgLayout,
47+ ArrayRef<int64_t > sizePerSg,
48+ ArrayRef<int64_t > sizePerWg) {
4849
4950 SmallVector<SmallVector<Value>> offsets;
5051
51- // nd local offset, localOffset[i] = sgId[i] * sgShape [i]
52+ // nd local offset, localOffset[i] = sgId[i] * sizePerSg [i]
5253 SmallVector<Value> localOffsets = llvm::map_to_vector (
53- llvm::zip (sgId, sgShape ), [&](const auto &t) -> Value {
54+ llvm::zip (sgId, sizePerSg ), [&](const auto &t) -> Value {
5455 return builder.createOrFold <index::MulOp>(
5556 loc, std::get<0 >(t),
5657 builder.createOrFold <arith::ConstantIndexOp>(loc, std::get<1 >(t)));
5758 });
5859
59- // distUnit[i] is the minimum value between shape [i] and
60- // sgLayout[i] * sgShape [i]
60+ // distUnit[i] is the minimum value between sizePerWg [i] and
61+ // sgLayout[i] * sizePerSg [i]
6162 SmallVector<int64_t > distUnit = llvm::map_to_vector (
62- llvm::zip_equal (shape , computeElementwiseMul (sgLayout, sgShape )),
63+ llvm::zip_equal (sizePerWg , computeElementwiseMul (sgLayout, sizePerSg )),
6364 [](const auto &t) { return std::min (std::get<0 >(t), std::get<1 >(t)); });
6465
65- for (SmallVector<int64_t > unitOffs : StaticTileOffsetRange (shape, distUnit)) {
66+ for (SmallVector<int64_t > unitOffs :
67+ StaticTileOffsetRange (sizePerWg, distUnit)) {
6668 SmallVector<Value> base =
6769 llvm::map_to_vector (unitOffs, [&](int64_t d) -> Value {
6870 return builder.create <arith::ConstantIndexOp>(loc, d);
@@ -75,7 +77,7 @@ genOffsetsComputations(OpBuilder &builder, Location loc,
7577 });
7678
7779 SmallVector<Value> mods = llvm::map_to_vector (
78- llvm::zip_equal (adds, shape ), [&](const auto &t) -> Value {
80+ llvm::zip_equal (adds, sizePerWg ), [&](const auto &t) -> Value {
7981 return builder.createOrFold <index::RemUOp>(
8082 loc, std::get<0 >(t),
8183 builder.create <arith::ConstantIndexOp>(loc, std::get<1 >(t)));
@@ -300,8 +302,8 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
300302 SmallVector<int64_t > sgShape;
301303 if (auto maybeSgShape = getSgDataAsInt ())
302304 sgShape = maybeSgShape.value ();
303- else if (auto ratio = computeShapeRatio (shape, sgLayout))
304- sgShape = ratio .value ();
305+ else if (auto derivedShape = computeShapeRatio (shape, sgLayout))
306+ sgShape = derivedShape .value ();
305307 else
306308 return failure ();
307309
@@ -311,7 +313,8 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
311313 return failure ();
312314 SmallVector<Value> sgIds = *maybeIds;
313315
314- return genOffsetsComputations (builder, loc, sgIds, sgLayout, sgShape, shape);
316+ return genOffsetsComputingInsts (builder, loc, sgIds, sgLayout, sgShape,
317+ shape);
315318}
316319
317320// ===----------------------------------------------------------------------===//
@@ -401,7 +404,8 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
401404 SmallVector<Value> sgIds =
402405 XeGPUDialect::slice (ArrayRef<Value>(*maybeIds), dims);
403406
404- return genOffsetsComputations (builder, loc, sgIds, sgLayout, sgShape, shape);
407+ return genOffsetsComputingInsts (builder, loc, sgIds, sgLayout, sgShape,
408+ shape);
405409}
406410
407411// ===----------------------------------------------------------------------===//
0 commit comments