Skip to content

Commit 4d5dc49

Browse files
committed
Reenable rewrite tensor ptr
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 2546665 commit 4d5dc49

File tree

2 files changed

+16
-20
lines changed

2 files changed

+16
-20
lines changed

third_party/intel/backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def make_ttgir(mod, metadata, opt, properties):
235235
intel.passes.ttgpuir.add_accelerate_matmul(pm)
236236
intel.passes.ttgpuir.add_remove_layout_conversions(pm)
237237
intel.passes.ttgpuir.add_materialize_block_pointer(pm)
238-
# intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm)
238+
intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm)
239239
intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, False)
240240

241241
intel.passes.ttgpuir.add_coalesce(pm)

third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@ namespace ttgi = mlir::triton::gpu::intel;
2828

2929
namespace {
3030

31+
RankedTensorType getRankedTensorType(Type ptrTy) {
32+
return tt::isTensorPointerType(ptrTy)
33+
? cast<RankedTensorType>(
34+
cast<tt::PointerType>(ptrTy).getPointeeType())
35+
: dyn_cast<RankedTensorType>(ptrTy);
36+
}
37+
3138
struct CoalescePass
3239
: public ttgi::impl::TritonIntelGPUCoalesceBase<CoalescePass> {
3340
private:
@@ -49,12 +56,7 @@ struct CoalescePass
4956
SmallVector<unsigned> order = argSort(contiguity);
5057
LDBG("order=[" << triton::join(order, ", ") << "]");
5158

52-
RankedTensorType refTensorType =
53-
tt::isTensorPointerType(ptr.getType())
54-
? cast<RankedTensorType>(
55-
cast<tt::PointerType>(ptr.getType()).getPointeeType())
56-
: cast<RankedTensorType>(ptr.getType());
57-
59+
RankedTensorType refTensorType = getRankedTensorType(ptr.getType());
5860
auto matchesShape = [&refTensorType](const Value &val) {
5961
auto rttType = dyn_cast<RankedTensorType>(val.getType());
6062
return rttType && rttType.getShape() == refTensorType.getShape();
@@ -197,7 +199,7 @@ struct CoalescePass
197199
"Unexpected layout");
198200

199201
auto resType = cast<tt::PointerType>(res.getType());
200-
auto tensorType = cast<RankedTensorType>(resType.getPointeeType());
202+
RankedTensorType tensorType = getRankedTensorType(resType);
201203
res.setType(tt::PointerType::get(getNewType(tensorType, layout),
202204
resType.getAddressSpace()));
203205
}
@@ -243,7 +245,7 @@ struct CoalescePass
243245
"Unexpected layout");
244246

245247
auto resType = cast<tt::PointerType>(res.getType());
246-
auto tensorType = cast<RankedTensorType>(resType.getPointeeType());
248+
RankedTensorType tensorType = getRankedTensorType(resType);
247249
res.setType(tt::PointerType::get(getNewType(tensorType, layout),
248250
resType.getAddressSpace()));
249251
}
@@ -281,8 +283,10 @@ struct CoalescePass
281283
}
282284
}
283285

284-
// Change the \p layout of the \p op result(s) and propagate the new result
285-
// type to its users.
286+
// TODO: change the implementation to handle only operation yielding one
287+
// result?
288+
// Change the \p layout of the \p op result(s) and propagate the new
289+
// result type to its users.
286290
static void changeAndPropagateLayout(Operation *op, Attribute layout,
287291
IRRewriter &rewriter) {
288292
assert(op && op->getNumResults() != 0 &&
@@ -293,10 +297,6 @@ struct CoalescePass
293297
if (!tt::isTensorPointerType(res.getType()))
294298
continue;
295299

296-
// Problem: if the operation is a for loop we cannot modify the layout
297-
// of all the tensor ptr results, we need to modify only the one used by
298-
// the yield operation.
299-
300300
auto ptrType = cast<tt::PointerType>(res.getType());
301301
auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
302302
res.setType(tt::PointerType::get(getNewType(tensorType, layout),
@@ -382,11 +382,7 @@ struct CoalescePass
382382
if (!ptr)
383383
return;
384384

385-
RankedTensorType refTensorType =
386-
tt::isTensorPointerType(ptr.getType())
387-
? cast<RankedTensorType>(
388-
cast<tt::PointerType>(ptr.getType()).getPointeeType())
389-
: dyn_cast<RankedTensorType>(ptr.getType());
385+
RankedTensorType refTensorType = getRankedTensorType(ptr.getType());
390386
if (!refTensorType || !refTensorType.getEncoding())
391387
return;
392388

0 commit comments

Comments
 (0)