Skip to content

Commit b1c3c72

Browse files
committed
Enable conversion of the nested layout #slice->#dot->#mma to a linear layout for third-party extensions.
1 parent daa8c28 commit b1c3c72

File tree

2 files changed

+18
-9
lines changed

2 files changed

+18
-9
lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,15 @@ unsigned getTotalElemsPerThread(Type type) {
8686
if (type.isIntOrIndexOrFloat() || isa<triton::PointerType>(type))
8787
return 1;
8888
auto tensorType = cast<RankedTensorType>(type);
89+
90+
std::optional<LinearLayout> ll = triton::gpu::toLinearLayout(
91+
tensorType.getShape(), tensorType.getEncoding());
92+
if (ll.has_value()) {
93+
MLIRContext *ctx = tensorType.getContext();
94+
auto kRegister = StringAttr::get(ctx, "register");
95+
return ll->getInDimSize(kRegister);
96+
}
97+
// Fall back to the legacy layout interface.
8998
return getTotalElemsPerThread(tensorType.getEncoding(), tensorType.getShape(),
9099
tensorType.getElementType());
91100
}

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -736,15 +736,15 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
736736
LinearLayout ret =
737737
LinearLayout(std::move(bases), llvm::to_vector(sliceLL.getOutDimNames()));
738738

739-
// Match a hack in the legacy code that ensures that the number of registers
740-
// matches getTotalElemsPerThread. Yup: We just removed all the zeros, now
741-
// we're (maybe) adding some back. :)
742-
//
743-
// TODO(jlebar): Once getTotalElemsPerThread uses LLs instead of the existing
744-
// legacy code, I think we can remove this.
745-
int expectedNumRegisters =
746-
triton::gpu::getTotalElemsPerThread(RankedTensorType::get(
747-
shape, IntegerType::get(ctx, 32) /*dummy type*/, *this));
739+
// Semantics of the slice layout:
740+
// The threads of the parent layout that are distributed along the
741+
// sliced dim are squeezed so that they redundantly hold the same tensor value.
742+
// Along that dim, only the parent's sizePerThreads[dim] per-thread values are
743+
// reduced to one. We need to fix up the number of registers in case we
744+
// removed all-zero bases too aggressively.
745+
auto sizePerThreads = triton::gpu::getSizePerThread(getParent());
746+
unsigned expectedNumRegisters =
747+
parentLL->getInDimSize(S("register")) / sizePerThreads[getDim()];
748748
if (ret.getInDimSize(S("register")) != expectedNumRegisters) {
749749
int extraZeros = expectedNumRegisters / ret.getInDimSize(S("register"));
750750
// Our use of "dim0" here is arbitrary; because we're adding zeros, any

0 commit comments

Comments
 (0)