Skip to content

Commit 322cd5b

Browse files
authored
[BACKEND] Reshape the allocShape within MemDescReshapeOp (#7495)
This follows the same pattern as `MemdescTransOp`. To do so, we align more the op with `MemDescTransOp` by inferring the output type automatically.
1 parent 7d3bf12 commit 322cd5b

File tree

10 files changed

+151
-70
lines changed

10 files changed

+151
-70
lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,27 @@ def TTG_MemDescReshapeOp : TTG_Op<"memdesc_reshape", [Pure,
273273
}];
274274

275275
let arguments = (ins TTG_MemDescType:$src);
276+
277+
let builders = [
278+
OpBuilder<(ins "Value":$src, "ArrayRef<int64_t>":$shape),
279+
[{
280+
MemDescType dstTy;
281+
auto srcTy = cast<MemDescType>(src.getType());
282+
auto result = inferReturnTypes($_builder.getContext(),
283+
$_builder.getUnknownLoc(),
284+
srcTy, shape, dstTy);
285+
assert(succeeded(result) && "failed to infer return types");
286+
build($_builder, $_state, dstTy, src);
287+
}]>
288+
];
289+
let extraClassDeclaration = [{
290+
static LogicalResult inferReturnTypes(MLIRContext *context,
291+
std::optional<Location> loc,
292+
MemDescType srcTy,
293+
ArrayRef<int64_t> dstShape,
294+
MemDescType &inferredReturnType);
295+
}];
296+
276297
let results = (outs TTG_MemDescType:$result);
277298

278299
let assemblyFormat = "$src attr-dict `:` qualified(type($src)) `->` qualified(type($result))";

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2466,6 +2466,38 @@ struct TritonGPUInferLayoutInterface
24662466
Attribute srcEnc,
24672467
ArrayRef<int64_t> dstShape,
24682468
Attribute &dstEnc) const {
2469+
if (auto mmaEncoding = dyn_cast<NVMMASharedEncodingAttr>(srcEnc)) {
2470+
// TODO: supporting reshape of CTA layouts is non-trivial.
2471+
if (getNumCTAs(mmaEncoding) > 1)
2472+
return failure();
2473+
int innerDimDst =
2474+
mmaEncoding.getTransposed() ? dstShape.front() : dstShape.back();
2475+
int innerDimSrc =
2476+
mmaEncoding.getTransposed() ? srcShape.front() : srcShape.back();
2477+
// For now disallow reshape of the inner dimension.
2478+
if (innerDimDst != innerDimSrc)
2479+
return failure();
2480+
auto *ctx = srcEnc.getContext();
2481+
2482+
// CTALayout can be all 1's because we bailed on multi-CTA layouts above.
2483+
auto CTALayout = CTALayoutAttr::get(
2484+
ctx,
2485+
/*CTAsPerCGA=*/SmallVector<unsigned>(dstShape.size(), 1),
2486+
/*CTASplitNum=*/SmallVector<unsigned>(dstShape.size(), 1),
2487+
/*CTAOrder=*/llvm::to_vector(llvm::seq<unsigned>(dstShape.size())));
2488+
dstEnc = NVMMASharedEncodingAttr::get(
2489+
ctx, mmaEncoding.getSwizzlingByteWidth(), mmaEncoding.getTransposed(),
2490+
mmaEncoding.getElementBitWidth(), mmaEncoding.getFp4Padded(),
2491+
CTALayout);
2492+
// Big guns, check linear layouts are equivalent
2493+
// We disallow reshaping memdesc_subviews in the verifier
2494+
auto srcLL = toLinearLayout(srcShape, srcEnc);
2495+
auto dstLL = toLinearLayout(dstShape, dstEnc);
2496+
if (reshapeLayout(ctx, srcLL, dstShape) != dstLL) {
2497+
return failure();
2498+
}
2499+
return success();
2500+
}
24692501
auto src = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc);
24702502
if (!src) {
24712503
return failure();
@@ -2713,6 +2745,10 @@ struct TritonGPUInferLayoutInterface
27132745
if (succeeded(result)) {
27142746
return result;
27152747
}
2748+
if (!isa<DistributedEncodingTrait>(srcEnc)) {
2749+
return emitOptionalError(loc,
2750+
"Failed MemDescReshapeOp encoding inference");
2751+
}
27162752
// If the legacy encoding failed use LinearLayouts.
27172753
// Once LinearLayouts are more widely used, we can remove
27182754
// inferReshapeOpLegacyEncoding and simply use LLs.

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -474,19 +474,48 @@ LogicalResult MemDescReshapeOp::verify() {
474474
return emitError("result element type must match src element type");
475475
}
476476

477-
// Infer the dst layout from the source and verify that it is equivalent.
478-
auto srcEncoding = srcType.getEncoding();
479-
Attribute inferedDstEncoding;
480-
481-
LinearLayout ll = inferReshapeLinearLayout(cast<TensorOrMemDesc>(srcType),
482-
dstType.getShape());
483-
LinearLayout llDst = triton::gpu::toLinearLayout(dstType);
484-
if (ll != llDst) {
477+
MemDescType expectedTy;
478+
if (failed(inferReturnTypes(getContext(), getLoc(), srcType,
479+
dstType.getShape(), expectedTy)))
480+
return failure();
481+
// Check that the alloc shape separately to give a cleaner error, given that
482+
// it's the most likely source of the error.
483+
if (expectedTy.getAllocShape() != dstType.getAllocShape()) {
484+
return emitError(
485+
"The result alloc shape does not match the expected alloc shape.");
486+
}
487+
if (expectedTy != dstType) {
485488
return emitError("source and destination layout are incompatible.");
486489
}
487490
return success();
488491
}
489492

493+
LogicalResult MemDescReshapeOp::inferReturnTypes(
494+
MLIRContext *context, std::optional<Location> loc, MemDescType srcTy,
495+
ArrayRef<int64_t> dstShape, MemDescType &inferredReturnType) {
496+
if (product<int64_t>(dstShape) != product<int64_t>(srcTy.getShape()))
497+
return emitOptionalError(
498+
loc, "dst shape has different number of elements than src");
499+
500+
Attribute dstEncoding;
501+
if (Attribute srcEnc = srcTy.getEncoding()) {
502+
auto *inferLayout = cast<DialectInferLayoutInterface>(&srcEnc.getDialect());
503+
if (failed(inferLayout->inferReshapeOpEncoding(srcTy.getShape(), srcEnc,
504+
dstShape, dstEncoding, loc)))
505+
return failure();
506+
}
507+
508+
SmallVector<int64_t> dstAllocShape =
509+
to_vector(srcTy.getAllocShape().take_front(srcTy.getAllocShape().size() -
510+
srcTy.getShape().size()));
511+
dstAllocShape.append(dstShape.begin(), dstShape.end());
512+
513+
inferredReturnType = MemDescType::get(
514+
dstShape, srcTy.getElementType(), dstEncoding, srcTy.getMemorySpace(),
515+
srcTy.getMutableMemory(), dstAllocShape);
516+
return success();
517+
}
518+
490519
// MemDescReinterpretOp
491520
LogicalResult MemDescReinterpretOp::verify() {
492521
if (getSrc().getType().getMemorySpace() != getType().getMemorySpace())

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp

Lines changed: 11 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -143,44 +143,6 @@ class FuseTransMMAV3Plus : public OpRewritePattern<LocalAllocOp> {
143143
}
144144
};
145145

146-
static Attribute inferSrcEncodingMemDescReshape(ArrayRef<int64_t> srcShape,
147-
MemDescType dstType) {
148-
auto dstEncoding = dstType.getEncoding();
149-
auto dstShape = dstType.getShape();
150-
auto mmaEncoding = dyn_cast<NVMMASharedEncodingAttr>(dstEncoding);
151-
if (!mmaEncoding)
152-
return {};
153-
// TODO: supporting reshape of CTA layouts is non-trivial.
154-
if (getNumCTAs(mmaEncoding) > 1)
155-
return {};
156-
int innerDimDst =
157-
mmaEncoding.getTransposed() ? dstShape.front() : dstShape.back();
158-
int innerDimSrc =
159-
mmaEncoding.getTransposed() ? srcShape.front() : srcShape.back();
160-
// For now disallow reshape of the inner dimension.
161-
if (innerDimDst != innerDimSrc)
162-
return {};
163-
164-
// CTALayout can be all 1's because we bailed on multi-CTA layouts above.
165-
auto CTALayout = CTALayoutAttr::get(
166-
dstEncoding.getContext(),
167-
/*CTAsPerCGA=*/SmallVector<unsigned>(srcShape.size(), 1),
168-
/*CTASplitNum=*/SmallVector<unsigned>(srcShape.size(), 1),
169-
/*CTAOrder=*/llvm::to_vector(llvm::seq<unsigned>(srcShape.size())));
170-
auto srcEncoding = NVMMASharedEncodingAttr::get(
171-
dstEncoding.getContext(), mmaEncoding.getSwizzlingByteWidth(),
172-
mmaEncoding.getTransposed(), mmaEncoding.getElementBitWidth(),
173-
mmaEncoding.getFp4Padded(), CTALayout);
174-
// Big guns, check linear layouts are equivalent
175-
auto srcLL = toLinearLayout(srcShape, srcEncoding);
176-
auto dstLL = toLinearLayout(dstShape, dstEncoding);
177-
auto ctx = dstEncoding.getContext();
178-
if (reshapeLayout(ctx, srcLL, dstShape) != dstLL) {
179-
return {};
180-
}
181-
return srcEncoding;
182-
}
183-
184146
// Rewrite
185147
//
186148
// alloc(reshape(), #shared1) ->
@@ -204,18 +166,21 @@ class ReshapeMemDesc : public OpRewritePattern<LocalAllocOp> {
204166
auto allocEncoding = allocType.getEncoding();
205167

206168
RankedTensorType srcTy = reshapeOp.getSrc().getType();
207-
auto newAllocEncoding =
208-
inferSrcEncodingMemDescReshape(srcTy.getShape(), allocType);
209-
if (!newAllocEncoding)
169+
auto srcShape = srcTy.getShape();
170+
auto dstShape = allocType.getShape();
171+
172+
// We use the fact that forward and backward inference are the same for
173+
// MemDescReshapeOp to infer the source MemDescType that would produce
174+
// `allocType` after a reshape.
175+
MemDescType innerTy;
176+
if (failed(MemDescReshapeOp::inferReturnTypes(
177+
getContext(), allocOp.getLoc(), allocType, srcShape, innerTy)))
210178
return failure();
211179

212-
MemDescType innerTy =
213-
MemDescType::get(srcTy.getShape(), srcTy.getElementType(),
214-
newAllocEncoding, allocType.getMemorySpace());
215180
auto newAlloc = rewriter.create<LocalAllocOp>(allocOp.getLoc(), innerTy,
216181
reshapeOp.getSrc());
217-
rewriter.replaceOpWithNewOp<MemDescReshapeOp>(allocOp, allocOp.getType(),
218-
newAlloc);
182+
rewriter.replaceOpWithNewOp<MemDescReshapeOp>(allocOp, newAlloc,
183+
allocOp.getType().getShape());
219184
return success();
220185
}
221186
};

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1500,8 +1500,9 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
15001500
newVal = builder.create<ttg::MemDescTransOp>(trans.getLoc(), val,
15011501
trans.getOrder());
15021502
} else if (auto reshape = dyn_cast<ttg::MemDescReshapeOp>(user)) {
1503-
newVal = builder.create<ttg::MemDescReshapeOp>(reshape.getLoc(),
1504-
reshape.getType(), val);
1503+
auto shape = reshape.getType().getShape();
1504+
newVal =
1505+
builder.create<ttg::MemDescReshapeOp>(reshape.getLoc(), val, shape);
15051506
}
15061507
assert(newVal && "unhandled memdesc view");
15071508
newVal.getDefiningOp()->setAttrs(user->getAttrs());

python/src/gluon_ir.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,8 +396,9 @@ void init_gluon_ir(py::module &&m) {
396396
return self.create<ttg::MemDescTransOp>(src, order);
397397
})
398398
.def("create_memdesc_reshape",
399-
[](GluonOpBuilder &self, Type resultType, Value src) -> Value {
400-
return self.create<ttg::MemDescReshapeOp>(resultType, src);
399+
[](GluonOpBuilder &self, Value src,
400+
std::vector<int64_t> &shape) -> Value {
401+
return self.create<ttg::MemDescReshapeOp>(src, shape);
401402
})
402403
.def("create_memdesc_reinterpret",
403404
[](GluonOpBuilder &self, Type resultType, Value src) -> Value {

python/test/gluon/test_frontend.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -309,10 +309,8 @@ def shared_memory_cast_kernel():
309309

310310
layout_b: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=64, transposed=False, element_bitwidth=16,
311311
rank=4, cta_order=[3, 2, 1, 0])
312-
layout_reshape: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=64, transposed=False,
313-
element_bitwidth=16, rank=2)
314312
smem = ttgl.allocate_shared_memory(ttgl.float16, [32, 1, 4, 64], layout_b)
315-
smem.reshape((128, 64), layout_reshape)
313+
smem.reshape((128, 64))
316314

317315
smem._reinterpret(ttgl.int8, [1024], ttgl.SwizzledSharedLayout(1, 1, 1, [0, 1]))
318316

@@ -336,7 +334,7 @@ def test_shared_memory_cast(fresh_knobs):
336334
%2 = ttg.memdesc_trans %1 {order = array<i32: 1, 0>} : !ttg.memdesc<256x128xi8, #shared, #smem, mutable, 2x256x128> -> !ttg.memdesc<128x256xi8, #shared1, #smem, mutable, 2x128x256>
337335
tt.call @"test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False_NVMMALAS[2, 128, 256]ASMD__"(%2) : (!ttg.memdesc<128x256xi8, #shared1, #smem, mutable, 2x128x256>) -> ()
338336
%3 = ttg.local_alloc : () -> !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable>
339-
%4 = ttg.memdesc_reshape %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared3, #smem, mutable, 32x1x4x64>
337+
%4 = ttg.memdesc_reshape %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared3, #smem, mutable>
340338
%5 = ttg.memdesc_reinterpret %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<1024xi8, #shared4, #smem, mutable>
341339
tt.return
342340
}

python/triton/experimental/gluon/language/_core.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,21 +293,19 @@ def permute(self, order, _semantic: GluonSemantic) -> shared_memory_descriptor:
293293
return _semantic.memdesc_trans(self, order)
294294

295295
@builtin
296-
def reshape(self, shape, layout, _semantic: GluonSemantic) -> shared_memory_descriptor:
296+
def reshape(self, shape, _semantic: GluonSemantic) -> shared_memory_descriptor:
297297
"""
298298
Reshape the shared memory descriptor to a new shape and layout.
299299
300300
Args:
301301
shape (List[int]): The target shape.
302-
layout (SharedLayout): The new layout for the descriptor.
303302
304303
Returns:
305304
shared_memory_descriptor: Descriptor with the new shape and layout.
306305
"""
307306
shape = [_unwrap_if_constexpr(s) for s in shape]
308-
layout = _unwrap_if_constexpr(layout)
309307

310-
return _semantic.memdesc_reshape(self, shape, layout)
308+
return _semantic.memdesc_reshape(self, shape)
311309

312310
@builtin
313311
def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> shared_memory_descriptor:

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Sequence, List, TypeVar, Tuple, Callable
2+
import math
23
from triton.language.semantic import TritonSemantic
34
from . import _core as ttgl
45
from ._layouts import SliceLayout, AutoLayout
@@ -213,10 +214,26 @@ def memdesc_trans(self, mem_desc, order):
213214
return ttgl.shared_memory_descriptor(handle, element_ty=mem_desc.dtype, shape=shape,
214215
alloc_shape=new_alloc_shape, layout=layout)
215216

216-
def memdesc_reshape(self, mem_desc, shape, layout):
217-
ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape)
218-
handle = self.builder.create_memdesc_reshape(ty.to_ir(self.builder), mem_desc.handle)
219-
return ttgl.shared_memory_descriptor(handle, **ty.__dict__)
217+
def memdesc_reshape(self, mem_desc, shape):
218+
_check(
219+
math.prod(shape) == math.prod(mem_desc.shape),
220+
lambda: (f"memdesc_reshape total elements mismatch: "
221+
f"{mem_desc.shape} -> {shape}"),
222+
)
223+
224+
handle = self.builder.create_memdesc_reshape(mem_desc.handle, shape)
225+
layout = self.builder.get_gluon_layout_from_memdesc(handle)
226+
alloc_shape = mem_desc.type.alloc_shape
227+
prefix_len = len(alloc_shape) - mem_desc.rank
228+
new_alloc_shape = alloc_shape[:prefix_len] + list(shape)
229+
230+
return ttgl.shared_memory_descriptor(
231+
handle,
232+
element_ty=mem_desc.dtype,
233+
shape=shape,
234+
alloc_shape=new_alloc_shape,
235+
layout=layout,
236+
)
220237

221238
def memdesc_reinterpret(self, mem_desc, dtype, shape, layout):
222239
ty = ttgl.shared_memory_descriptor_type(dtype, shape, layout, shape)

test/TritonGPU/ops.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,21 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w
6666
}
6767
}
6868

69+
// -----
70+
71+
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16, CTAsPerCGA = [1,1,1,1], CTASplitNum = [1,1,1,1], CTAOrder = [3, 2, 1, 0]}>
72+
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16}>
73+
#smem = #ttg.shared_memory
74+
module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
75+
// CHECK-LABEL: memdesc_reshape
76+
// CHECK: !ttg.memdesc<128x64xf16, #{{.+}}, mutable>
77+
tt.func @memdesc_reshape(%d : !ttg.memdesc<32x1x4x64xf16, #shared, #smem, mutable>){
78+
%1 = ttg.memdesc_reshape %d : !ttg.memdesc<32x1x4x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared1, #smem, mutable>
79+
tt.return
80+
}
81+
}
82+
83+
6984
// -----
7085

7186
// CHECK-LABEL: @warp_specialize_nothing

0 commit comments

Comments
 (0)