Skip to content

Commit b3d9c5a

Browse files
committed
[RemoveLayoutConversions]: Update index computations
Signed-off-by: Ettore Tiotto <[email protected]>
1 parent 26ca64e commit b3d9c5a

File tree

1 file changed

+42
-12
lines changed

1 file changed

+42
-12
lines changed

third_party/intel/lib/Dialect/Triton/Transforms/FuseReshape.cpp

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -175,43 +175,73 @@ class FuseReshape {
175175

176176
// Create a MakeTensorPtrOp yielding a 2-dim block pointer.
177177
auto ptrType = cast<tt::PointerType>(makeTensorPtrOp.getType());
178-
ArrayRef<int64_t> origShape =
178+
[[maybe_unused]] ArrayRef<int64_t> resShape =
179179
cast<RankedTensorType>(ptrType.getPointeeType()).getShape();
180-
assert(origShape[0] && "First shape extent is not one");
180+
assert(resShape[0] == 1 && "Result shape should have extent equal to 1 in "
181+
"the outermost dimension");
181182

182183
auto tensorType = cast<RankedTensorType>(reshapeOp.getType());
183184
auto newPtrType =
184185
tt::PointerType::get(tensorType, ptrType.getAddressSpace());
185186

186-
unsigned innermostDimIdx = 0;
187+
// Compute the index of the innermost dimension.
187188
ArrayRef<int> order = makeTensorPtrOp.getOrder();
188-
for (int i : order) {
189-
if (i == 0)
189+
assert(order.size() == 3 && order[0] == 2 && "Invalid order");
190+
191+
unsigned innermostDimIdx = 0;
192+
for (int elem : makeTensorPtrOp.getOrder()) {
193+
if (elem == 0)
190194
break;
191195
++innermostDimIdx;
192196
}
193197

194198
OpBuilder builder(makeTensorPtrOp);
195199
Location loc = makeTensorPtrOp.getLoc();
196-
Value firstShape = makeTensorPtrOp.getShape().front();
197-
Value firstStride = makeTensorPtrOp.getStrides().front();
198-
Value firstOffset = makeTensorPtrOp.getOffsets().front();
200+
OperandRange shapes = makeTensorPtrOp.getShape();
201+
OperandRange strides = makeTensorPtrOp.getStrides();
202+
OperandRange offsets = makeTensorPtrOp.getOffsets();
203+
204+
#if 0
205+
// order=2,1,0 --> idx = 2 (row major) --> idx we want = 1
206+
// order=2,0,1 --> idx = 1 (column major) --> idx we want == 0
199207

200208
SmallVector<Value> newShape(makeTensorPtrOp.getShape().drop_front());
201209
newShape[innermostDimIdx - 1] = builder.create<arith::AddIOp>(
202-
loc, builder.create<arith::MulIOp>(loc, firstStride, firstShape),
210+
loc, builder.create<arith::MulIOp>(loc, strides[0], shapes[0]),
203211
newShape[innermostDimIdx - 1]);
204212
SmallVector<Value> newStrides(makeTensorPtrOp.getStrides().drop_front());
205213
SmallVector<Value> newOffsets(makeTensorPtrOp.getOffsets().drop_front());
206214
newOffsets[innermostDimIdx - 1] = builder.create<arith::AddIOp>(
207215
loc,
208216
builder.create<arith::MulIOp>(
209217
loc,
210-
builder.create<arith::TruncIOp>(loc, firstOffset.getType(),
211-
firstStride),
212-
firstOffset),
218+
builder.create<arith::TruncIOp>(loc, offsets[0].getType(),
219+
strides[0]),
220+
offsets[0]),
213221
newOffsets[innermostDimIdx - 1]);
222+
#else
223+
// order=2,1,0 --> idx = 2 (row major) --> idx we want = 0
224+
// order=2,0,1 --> idx = 1 (column major) --> idx we want == 1
225+
226+
unsigned newInnermostDimIdx = (innermostDimIdx - 1);
227+
unsigned newOutermostDimIdx = !newInnermostDimIdx;
214228

229+
SmallVector<Value> newShape(makeTensorPtrOp.getShape().drop_front());
230+
SmallVector<Value> newStrides(makeTensorPtrOp.getStrides().drop_front());
231+
SmallVector<Value> newOffsets(makeTensorPtrOp.getOffsets().drop_front());
232+
233+
auto div = builder.create<arith::DivUIOp>(loc, strides[0],
234+
newStrides[newOutermostDimIdx]);
235+
newShape[newOutermostDimIdx] = builder.create<arith::AddIOp>(
236+
loc, builder.create<arith::MulIOp>(loc, shapes[0], div),
237+
newShape[newOutermostDimIdx]);
238+
newOffsets[newOutermostDimIdx] = builder.create<arith::AddIOp>(
239+
loc,
240+
builder.create<arith::MulIOp>(
241+
loc, offsets[0],
242+
builder.create<arith::TruncIOp>(loc, offsets[0].getType(), div)),
243+
newOffsets[newOutermostDimIdx]);
244+
#endif
215245
Value ptr = builder.create<tt::MakeTensorPtrOp>(
216246
loc, newPtrType, makeTensorPtrOp.getBase(), newShape, newStrides,
217247
newOffsets,

0 commit comments

Comments
 (0)