Commit cdd7a3d

[AMD] Support register broadcast in slice/concat ops (#7407)
This PR:
- permits arbitrary broadcasted registers in layouts
- fixes a few possible crashes in the verifier in the case of broadcasted layouts

Co-authored-by: Alexander Efimov <[email protected]>
1 parent b80f5dd commit cdd7a3d
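
For background, a minimal sketch of what a broadcasted register means in a linear layout, assuming the usual semantics that the bases selected by the set bits of a hardware index are XOR-ed together (the helper below is illustrative, not Triton's LinearLayout API):

#include <array>
#include <cstdint>
#include <cstdio>
#include <vector>

// A linear layout maps each bit of a hardware index (here, the register id)
// to a basis vector and XORs the selected bases to get an element coordinate.
// A zero basis such as [0, 0] contributes nothing, so register ids differing
// only in that bit alias the same element: the register is broadcasted.
using Coord = std::array<int32_t, 2>;

Coord applyRegisterBases(const std::vector<Coord> &bases, unsigned regId) {
  Coord out = {0, 0};
  for (size_t bit = 0; bit < bases.size(); ++bit)
    if (regId & (1u << bit)) {
      out[0] ^= bases[bit][0];
      out[1] ^= bases[bit][1];
    }
  return out;
}

int main() {
  // register=[[0, 0], [0, 0], [1, 0], [2, 0]] from the tests below:
  // 2^4 = 16 registers, but only 2^2 = 4 distinct elements.
  std::vector<Coord> bases = {{0, 0}, {0, 0}, {1, 0}, {2, 0}};
  for (unsigned r = 0; r < 16; ++r) {
    Coord c = applyRegisterBases(bases, r);
    std::printf("reg %2u -> (%d, %d)\n", r, c[0], c[1]);
  }
}

Registers 0 through 3 all land on element (0, 0): the two zero bases make four register ids aliases of one value, which is the redundancy the slice/concat lowering now has to handle.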

8 files changed: +277 -80 lines changed

test/Conversion/amd/invalid_concat_op.mlir

Lines changed: 8 additions & 7 deletions
@@ -157,18 +157,19 @@ module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32,
 // -----
 
 // Different layouts 2
-#src_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [64, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}>
-#dst_layout = #ttg.linear<{register=[[0, 0], [0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [0, 128], [64, 0], [128, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}>
+// Case when src and dst layouts have same CTA tile shape, but different number of registers
+#src_layout = #ttg.linear<{register=[[1, 0], [2, 0]], lane=[[4, 0], [8, 0], [16, 0], [0, 1], [0, 2], [0, 4]], warp=[[0, 0], [0, 8]], block=[]}>
+#dst_layout = #ttg.linear<{register=[[1, 0]], lane=[[4, 0], [8, 0], [16, 0], [0, 1], [0, 2], [0, 4]], warp=[[2, 0], [0, 8]], block=[]}>
 module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
   tt.func @invalid_concat(
-    %arg0: tensor<128x128xf32, #src_layout>,
-    %arg1: tensor<128x128xf32, #src_layout>,
-    %arg2: tensor<128x128xf32, #src_layout>,
-    %arg3: tensor<128x128xf32, #src_layout>) {
+    %arg0: tensor<32x16xf32, #src_layout>,
+    %arg1: tensor<32x16xf32, #src_layout>,
+    %arg2: tensor<32x16xf32, #src_layout>,
+    %arg3: tensor<32x16xf32, #src_layout>) {
 
   // expected-error @+1 {{Register basis must match on a CTA tile between source and destination.}}
   %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3:
-    tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout> -> tensor<256x256xf32, #dst_layout>
+    tensor<32x16xf32, #src_layout>, tensor<32x16xf32, #src_layout>, tensor<32x16xf32, #src_layout>, tensor<32x16xf32, #src_layout> -> tensor<64x32xf32, #dst_layout>
   tt.return
   }
 }

test/Conversion/amd/invalid_extractslice_to_llvm.mlir

Lines changed: 12 additions & 9 deletions
@@ -83,20 +83,23 @@ tt.func @invalid_non_static_offset(%arg0: tensor<256x128xi32, #blocked1> {tt.div
 
 // Invalid layout 1
 #dst_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [64, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}>
-#src_layout = #ttg.linear<{register=[[0, 0], [0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [0, 128], [64, 0], [128, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}>
-tt.func @invalid_register_base(%arg0: tensor<256x256xi32, #src_layout> {tt.divisibility = 16 : i32}) {
-  // expected-error @+1 {{Register basis must match on a CTA tile between source and destination}}
+#src_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [0, 128], [64, 0], [128, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4], [0, 0]], warp=[[0, 32], [32, 0]], block=[]}>
+tt.func @invalid_lane_warp_basis(%arg0: tensor<256x256xi32, #src_layout> {tt.divisibility = 16 : i32}) {
+  // expected-error @+1 {{Lane and warp dim basis must match between source and destination layout}}
   %2 = amdgpu.extract_slice %arg0 [0, 0] : tensor<256x256xi32, #src_layout> to tensor<128x128xi32, #dst_layout>
   tt.return
 }
 
 // -----
 
 // Invalid layout 2
-#dst_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [64, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}>
-#src_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [0, 128], [64, 0], [128, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4], [0, 0]], warp=[[0, 32], [32, 0]], block=[]}>
-tt.func @invalid_lane_warp_basis(%arg0: tensor<256x256xi32, #src_layout> {tt.divisibility = 16 : i32}) {
-  // expected-error @+1 {{Lane and warp dim basis must match between source and destination layout}}
-  %2 = amdgpu.extract_slice %arg0 [0, 0] : tensor<256x256xi32, #src_layout> to tensor<128x128xi32, #dst_layout>
-  tt.return
+// Case when src and dst layouts have same CTA tile shape, but different number of registers
+#src_layout = #ttg.linear<{register=[[1, 0], [2, 0]], lane=[[4, 0], [8, 0], [16, 0], [0, 1], [0, 2], [0, 4]], warp=[[0, 0], [0, 8]], block=[]}>
+#dst_layout = #ttg.linear<{register=[[1, 0]], lane=[[4, 0], [8, 0], [16, 0], [0, 1], [0, 2], [0, 4]], warp=[[2, 0], [0, 8]], block=[]}>
+module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func @invalid_concat(%arg0: tensor<64x32xi32, #src_layout>) {
+    // expected-error @+1 {{Register basis must match on a CTA tile between source and destination.}}
+    %1 = amdgpu.extract_slice %arg0 [0, 0] : tensor<64x32xi32, #src_layout> to tensor<32x16xi32, #dst_layout>
+    tt.return
+  }
 }

test/TritonGPU/amd/amd-concat-op.mlir

Lines changed: 36 additions & 0 deletions
@@ -103,3 +103,39 @@ module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32,
     tt.return
   }
 }
+
+// -----
+
+// Each input tensor broadcasts 4 registers along dimension 1, resulting in total 16 values per input.
+// Output tensor do not have redundancy in registers and holds 8 values.
+// Check that concat copies only 4 values from each input tensor, 8 in total.
+#src_layout = #ttg.linear<{register=[[0, 0], [0, 0], [1, 0], [2, 0]], lane=[[0, 0], [0, 0], [0, 0], [4, 0], [8, 0], [16, 0]], warp=[[0, 0], [32, 0], [64, 0]], block=[]}>
+#dst_layout = #ttg.linear<{register=[ [1, 0], [2, 0]], lane=[[0, 0], [0, 0], [0, 0], [4, 0], [8, 0], [16, 0]], warp=[[0, 0], [32, 0], [64, 0]], block=[]}>
+module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func @concat_from_broadcasted_tensor(%arg0: tensor<128x1xi32, #src_layout>, %arg1: tensor<128x1xi32, #src_layout> {tt.divisibility = 16 : i32}) {
+    // CHECK-LABEL: llvm.func @concat_from_broadcasted_tensor
+    // CHECK-COUNT-16: %{{.*}} = llvm.extractvalue %arg0[{{.*}}] : !llvm.struct
+    // CHECK-COUNT-16: %{{.*}} = llvm.extractvalue %arg1[{{.*}}] : !llvm.struct
+    // CHECK-COUNT-8: %{{.*}} = llvm.insertvalue %{{.*}} : !llvm.struct
+    %1 = amdgpu.concat %arg0, %arg1: tensor<128x1xi32, #src_layout>, tensor<128x1xi32, #src_layout> -> tensor<256x1xi32, #dst_layout>
+    tt.return
+  }
+}
+
+// -----
+
+// Input tensors do not have redundancy in register and hold 4 values each.
+// Output tensor broadcasts 4 registers along dimension 1, resulting in total 32 values.
+// Check that concat duplicates 4 values from each input 4 times, resulting in total 32 values.
+#src_layout = #ttg.linear<{register=[ [1, 0], [2, 0]], lane=[[0, 0], [0, 0], [0, 0], [4, 0], [8, 0], [16, 0]], warp=[[0, 0], [32, 0], [64, 0]], block=[]}>
+#dst_layout = #ttg.linear<{register=[[0, 0], [0, 0], [1, 0], [2, 0]], lane=[[0, 0], [0, 0], [0, 0], [4, 0], [8, 0], [16, 0]], warp=[[0, 0], [32, 0], [64, 0]], block=[]}>
+module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func @concat_to_broadcasted_tensor(%arg0: tensor<128x1xi32, #src_layout>, %arg1: tensor<128x1xi32, #src_layout> {tt.divisibility = 16 : i32}) {
+    // CHECK-LABEL: llvm.func @concat_to_broadcasted_tensor
+    // CHECK-COUNT-4: %{{.*}} = llvm.extractvalue %arg0[{{.*}}] : !llvm.struct
+    // CHECK-COUNT-4: %{{.*}} = llvm.extractvalue %arg1[{{.*}}] : !llvm.struct
+    // CHECK-COUNT-32: %{{.*}} = llvm.insertvalue %{{.*}} : !llvm.struct
+    %1 = amdgpu.concat %arg0, %arg1: tensor<128x1xi32, #src_layout>, tensor<128x1xi32, #src_layout> -> tensor<256x1xi32, #dst_layout>
+    tt.return
+  }
+}
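
A quick check of the CHECK-COUNT values above, assuming one value per register basis bit: the first test's source layout has 4 register bases (two of them zero), so each 128x1 input carries 2^4 = 16 values and all 16 are extracted; the destination keeps 2^2 = 4 non-broadcast values per 128-row tile, so concatenating two tiles emits the 8 inserts. The second test mirrors this: 2^2 = 4 extracts per input, and a destination with two zero bases broadcasts each tile to 2^4 = 16 values, giving 32 inserts for two tiles. The extract_slice tests in the next file follow the same arithmetic.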

test/TritonGPU/amd/amd-extractslice-op.mlir

Lines changed: 34 additions & 0 deletions
@@ -55,3 +55,37 @@ module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32,
     tt.return
   }
 }
+
+// -----
+
+// Input tensor broadcasts 4 registers along dimension 1, resulting in total 32 values in tensor and 16 values per [128x1] tile.
+// Output tensor do not have redundancy in register and holds 4 values.
+// Test checks that extract slice copies only 4 values from input to output.
+#blocked1 = #ttg.linear<{register=[[0, 0], [0, 0], [1, 0], [2, 0], [128, 0]], lane=[[0, 0], [0, 0], [0, 0], [4, 0], [8, 0], [16, 0]], warp=[[0, 0], [32, 0], [64, 0]], block=[]}>
+#blocked2 = #ttg.linear<{register=[ [1, 0], [2, 0]], lane=[[0, 0], [0, 0], [0, 0], [4, 0], [8, 0], [16, 0]], warp=[[0, 0], [32, 0], [64, 0]], block=[]}>
+module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func @extract_from_broadcasted_tensor(%arg0: tensor<256x1xi32, #blocked1> {tt.divisibility = 16 : i32}) {
+    // CHECK-LABEL: llvm.func @extract_from_broadcasted_tensor
+    // CHECK-COUNT-32: %{{.*}} = llvm.extractvalue %{{.*}} : !llvm.struct
+    // CHECK-COUNT-4: %{{.*}} = llvm.insertvalue %{{.*}} : !llvm.struct
+    %0 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x1xi32, #blocked1> to tensor<128x1xi32, #blocked2>
+    tt.return
+  }
+}
+
+// -----
+
+// Input tensor do not have broadcasted registers, resulting in total 8 values in tensor and 4 values per [128x1] tile.
+// Output tensor broadcasts 4 registers along dimension 1 and total 16 values.
+// Test checks that extract slice duplicates 4 values from input in 16 output values.
+#blocked1 = #ttg.linear<{register=[ [1, 0], [2, 0], [128, 0]], lane=[[0, 0], [0, 0], [0, 0], [4, 0], [8, 0], [16, 0]], warp=[[0, 0], [32, 0], [64, 0]], block=[]}>
+#blocked2 = #ttg.linear<{register=[[0, 0], [0, 0], [1, 0], [2, 0]], lane=[[0, 0], [0, 0], [0, 0], [4, 0], [8, 0], [16, 0]], warp=[[0, 0], [32, 0], [64, 0]], block=[]}>
+module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func @extract_to_broadcasted_tensor(%arg0: tensor<256x1xi32, #blocked1> {tt.divisibility = 16 : i32}) {
+    // CHECK-LABEL: llvm.func @extract_to_broadcasted_tensor
+    // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %{{.*}} : !llvm.struct
+    // CHECK-COUNT-16: %{{.*}} = llvm.insertvalue %{{.*}} : !llvm.struct
+    %72 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x1xi32, #blocked1> to tensor<128x1xi32, #blocked2>
+    tt.return
+  }
+}

third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp

Lines changed: 15 additions & 3 deletions
@@ -72,6 +72,10 @@ bool hasMatchingCTATileLayoutForSliceConcat(
   auto dstLL = triton::gpu::toLinearLayout(dstShape, dstTy.getEncoding());
 
   MLIRContext *ctx = srcTy.getContext();
+  auto kReg = StringAttr::get(ctx, "register");
+  srcLL = srcLL.removeZeroBasesAlongDim(kReg);
+  dstLL = dstLL.removeZeroBasesAlongDim(kReg);
+
   auto getBases = [&](StringRef name) {
     auto key = StringAttr::get(ctx, name);
     return std::pair{srcLL.getBases().lookup(key),
@@ -98,14 +102,22 @@ bool hasMatchingCTATileLayoutForSliceConcat(
     numCTAs *= srcShape[d] / shapeCTASrc[d];
   }
 
-  unsigned elemsPerThreadPerCTA =
-      triton::gpu::getTotalElemsPerThread(srcTy) / numCTAs;
-  unsigned regCompareLen = std::log2(elemsPerThreadPerCTA);
+  assert(llvm::isPowerOf2_32(numCTAs) &&
+         "expect number of CTAs to be power of 2");
+
+  unsigned totalElemsPerThreadNoBroadcastLog = regSrc.size();
+  unsigned elemsPerThreadPerCTALog =
+      totalElemsPerThreadNoBroadcastLog - llvm::Log2_32(numCTAs);
+  unsigned regCompareLen = elemsPerThreadPerCTALog;
 
   auto compareBasis = [&](auto &srcBasis, auto &dstBasis, StringRef message,
                           int limit = -1) {
     int n = (limit < 0 ? srcBasis.size()
                        : std::min<unsigned>(srcBasis.size(), limit));
+    if (dstBasis.size() < n) {
+      emitError(message);
+      return false;
+    }
     for (size_t i = 0; i < n; ++i) {
      if (srcBasis[i] != dstBasis[i]) {
        emitError(message);
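
Reading the diff, a worked example of why the old arithmetic could crash on broadcasted layouts (with numCTAs = 1, as in the new tests): the source layout register=[[0, 0], [0, 0], [1, 0], [2, 0]] gives getTotalElemsPerThread = 16, so the old code computed regCompareLen = log2(16) = 4, even though the destination only has 2 register bases, and compareBasis could then index past the end of the shorter basis list. The new code strips zero bases first and derives the length from regSrc.size(), here 2, and the added dstBasis.size() < n guard turns a residual length mismatch into a verifier error rather than an out-of-bounds read.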

third_party/amd/lib/TritonAMDGPUDialectToLLVM/ConcatOpToLLVM.cpp

Lines changed: 70 additions & 23 deletions
@@ -33,7 +33,11 @@ struct ConcatOpConversion : public ConvertOpToLLVMPattern<amdgpu::ConcatOp> {
 
     MLIRContext *context = resultType.getContext();
     auto linearLayoutSrc = triton::gpu::toLinearLayout(srcShape, srcEncoding);
-    auto linearLayoutDst = triton::gpu::toLinearLayout(dstShape, dstEncoding);
+    auto outDimNames = llvm::to_vector(linearLayoutSrc.getOutDimNames());
+    // Call transposeOuts, to ensure that order of input and output tensor
+    // element coordinates are compatible on stage 8 in algorithm below.
+    auto linearLayoutDst = triton::gpu::toLinearLayout(dstShape, dstEncoding)
+                               .transposeOuts(outDimNames);
     auto srcCTAOrder = LLVM::AMD::getCTATileOrder(context, linearLayoutSrc);
     auto dstCTAOrder = LLVM::AMD::getCTATileOrder(context, linearLayoutSrc);
 
@@ -69,34 +73,77 @@ struct ConcatOpConversion : public ConvertOpToLLVMPattern<amdgpu::ConcatOp> {
       unpackedSources.push_back(unpackLLElements(loc, currSrc, rewriter));
     }
 
-    // Traverse CTA tiles in the result tensor
-    for (int i = 0; i < numCTATiles; ++i) {
-      auto currTileIdx = mlir::LLVM::delinearize(i, dstCTAShape, dstCTAOrder);
+    // Algorithm:
+    // 1. for all registers in src tensor
+    // 2. compute src location in tensor relative to tile beginnig
+    // 3. save mapping from src elem coordinates to register idx
+    // 4. for all elements in dst tensor
+    // 5. get dst value location in tensor
+    // 6. find, which input tile holds the dst value
+    // 7. subtract dst coordinates and start coordinates of the tile
+    // 8. find source register number which holds dst value
+    // 9. copy dst element from computed tile and register
+    auto ctx = rewriter.getContext();
+    StringAttr kReg = StringAttr::get(ctx, "register");
+    auto srcRegBases = linearLayoutSrc.getBases().lookup(kReg);
+    auto dstRegBases = linearLayoutDst.getBases().lookup(kReg);
+
+    using ElemLocationKey = decltype(linearLayoutSrc.apply({}));
+    llvm::MapVector<ElemLocationKey, unsigned> srcElemToReg;
+    int srcRegNum = 1 << srcRegBases.size();
+    // 1. for all registers in src tensor
+    for (int regId = 0; regId < srcRegNum; ++regId) {
+      // 2. compute src location in tensor relative to tile beginnig
+      SmallVector<std::pair<StringAttr, int32_t>> hardwareLocation;
+      for (auto dimName : linearLayoutSrc.getInDimNames()) {
+        if (dimName == kReg)
+          hardwareLocation.push_back({dimName, regId});
+        else
+          hardwareLocation.push_back({dimName, 0});
+      }
+      auto elemCoords = linearLayoutSrc.apply(hardwareLocation);
+      // 3. save mapping from src elem coordinates to register idx
+      srcElemToReg[elemCoords] = regId;
+    }
+    // for every output register get element coords,
+    // find corresponding operand and copy src register
+    int dstRegNum = 1 << dstRegBases.size();
+    // 4. for all elements in dst tensor
+    for (int regId = 0; regId < dstRegNum; ++regId) {
+      SmallVector<std::pair<StringAttr, int32_t>> hardwareLocation;
+      // 5. get dst value location in tensor
+      for (auto dimName : linearLayoutDst.getInDimNames()) {
+        if (dimName == kReg)
+          hardwareLocation.push_back({dimName, regId});
+        else
+          hardwareLocation.push_back({dimName, 0});
+      }
+      auto elemCoords = linearLayoutDst.apply(hardwareLocation);
+      auto elemCoordsArray =
+          llvm::to_vector(llvm::make_second_range(elemCoords));
       // The n-dim destination tensor is built by arranging n-dim source tensors
       // into a destination tensor shape. Determine which source tensor contains
       // the current CTA tile.
-      auto multiDimSrcIdx = LLVM::AMD::multiDimElementwise<unsigned, unsigned>(
-          currTileIdx, srcCTAShape, std::divides<unsigned>());
+      auto multiDimOperandIdx =
+          LLVM::AMD::multiDimElementwise<int32_t, int64_t>(
+              elemCoordsArray, srcShape, std::divides<unsigned>());
       // Compute linear index of the current source tensor.
       // Concat operands are laid out in the destination tensor
       // in fastest slowest varying dimension order.
-      auto linearSrcIdx =
-          mlir::LLVM::linearize(multiDimSrcIdx, srcToDstShape, defaultOrder);
-
-      // After determining which source tensor the current CTA tile belongs to,
-      // compute the index of this CTA tile within that source tensor,
-      // considering the source tensors may include CTA tiles.
-      auto multiDimSrcCTAIdx =
-          LLVM::AMD::multiDimElementwise<unsigned, unsigned>(
-              currTileIdx, srcCTAShape, std::modulus<unsigned>());
-      auto linearSrcCTAIdx =
-          mlir::LLVM::linearize(multiDimSrcCTAIdx, srcCTAShape, srcCTAOrder);
-      auto unpackedElements = unpackedSources[linearSrcIdx];
-
-      auto startIt =
-          unpackedElements.begin() + linearSrcCTAIdx * elemsPerThreadPerCTA;
-      auto endIt = startIt + elemsPerThreadPerCTA;
-      llvm::append_range(resultVals, llvm::make_range(startIt, endIt));
+      // 6. find, which input tile holds the dst value
+      auto linearOperandIdx = mlir::LLVM::linearize(
+          multiDimOperandIdx, srcToDstShape, defaultOrder);
+
+      // 7. subtract dst coordinates and start coordinates of the tile
+      for (int dim = 0; dim < rank; ++dim)
+        elemCoords[dim].second -= multiDimOperandIdx[dim] * srcShape[dim];
+
+      assert(srcElemToReg.contains(elemCoords));
+      // 8. find source register number which holds dst value
+      int srcRegIdx = srcElemToReg.lookup(elemCoords);
+
+      // 9. copy dst element from found tile and register
+      resultVals.push_back(unpackedSources[linearOperandIdx][srcRegIdx]);
     }
 
     Value packedResult = packLLElements(loc, this->getTypeConverter(),
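
To make the nine numbered steps concrete, here is a minimal self-contained sketch of the same idea for a 1-D concat, with the layout reduced to its register bases (lane and warp held at zero) and bases combined by XOR as in Triton's linear layouts; all names are illustrative, not the actual helpers used above:

#include <cstdint>
#include <map>
#include <vector>

using Bases = std::vector<int32_t>; // one coordinate per register bit

// XOR together the bases selected by the set bits of regId.
int32_t applyBases(const Bases &b, unsigned regId) {
  int32_t coord = 0;
  for (size_t bit = 0; bit < b.size(); ++bit)
    if (regId & (1u << bit))
      coord ^= b[bit];
  return coord;
}

// Concatenate `inputs` (each covering srcLen elements of tensor space,
// indexed by source register) into one output laid out by dstBases.
std::vector<int> concatRegisters(const std::vector<std::vector<int>> &inputs,
                                 const Bases &srcBases, const Bases &dstBases,
                                 int32_t srcLen) {
  // Steps 1-3: map source element coordinate -> source register index.
  std::map<int32_t, unsigned> srcElemToReg;
  for (unsigned r = 0; r < (1u << srcBases.size()); ++r)
    srcElemToReg[applyBases(srcBases, r)] = r;

  std::vector<int> result;
  // Steps 4-5: walk destination registers and compute their coordinates.
  for (unsigned r = 0; r < (1u << dstBases.size()); ++r) {
    int32_t coord = applyBases(dstBases, r);
    // Step 6: which operand tile holds this coordinate.
    int32_t operandIdx = coord / srcLen;
    // Step 7: rebase the coordinate to the start of that tile.
    int32_t local = coord - operandIdx * srcLen;
    // Steps 8-9: copy from the source register that holds this element.
    result.push_back(inputs[operandIdx][srcElemToReg.at(local)]);
  }
  return result;
}

With srcBases = {1, 2} (each operand holds 2^2 = 4 elements, srcLen = 4) and dstBases = {1, 2, 4}, the loop visits 8 destination registers and routes coordinates 0-3 to inputs[0] and 4-7 to inputs[1]. A destination basis list with zero bases, e.g. {0, 0, 1, 2}, revisits the same coordinates and duplicates the copied values, which is the broadcast behavior the new tests check.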
