
Commit e9bc211

[DRAFT][LL] LL Python Interface (#8521)
1 parent 7231594 commit e9bc211

10 files changed: +503 -88 lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -280,6 +280,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
   add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc
               ${PYTHON_SRC_PATH}/ir.cc
               ${PYTHON_SRC_PATH}/gluon_ir.cc
+              ${PYTHON_SRC_PATH}/linear_layout.cc
               ${PYTHON_SRC_PATH}/passes.cc
               ${PYTHON_SRC_PATH}/interpreter.cc
               ${PYTHON_SRC_PATH}/llvm.cc
```
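This is the only build change: the new pybind11 translation unit linear_layout.cc is compiled into the triton Python extension next to ir.cc and gluon_ir.cc. As rough orientation, a hypothetical sketch of the shape such a binding file usually takes, following the init_triton_* pattern of the neighboring sources; the exposed names below are illustrative assumptions, not the commit's actual interface:

```cpp
// Hypothetical sketch only: a pybind11 binding file in the style the build
// now expects from linear_layout.cc. Exposed method names are assumptions.
#include <pybind11/pybind11.h>

#include "triton/Tools/LinearLayout.h"

namespace py = pybind11;
using mlir::triton::LinearLayout;

void init_triton_linear_layout(py::module &&m) {
  py::class_<LinearLayout>(m, "LinearLayout")
      // Render the layout's bases, as LinearLayout::toString() does in C++.
      .def("__repr__", [](const LinearLayout &ll) { return ll.toString(); })
      .def("num_out_dims",
           [](const LinearLayout &ll) { return ll.getNumOutDims(); })
      .def("total_out_dim_size",
           [](const LinearLayout &ll) { return ll.getTotalOutDimSize(); });
}
```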

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 6 additions & 0 deletions
```diff
@@ -266,6 +266,12 @@ void dumpHWLayout(RankedTensorType tensorType);
 // Return a string representation of the layout of the tensor.
 std::string getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView);
 
+// Return a string representation of the shared layout of the tensor.
+std::string getSharedLayoutStr(LinearLayout &ll, bool useHWPointOfView);
+
+// Return a string representation of the distributed layout of the tensor.
+std::string getDistributedLayoutStr(LinearLayout &ll, bool useHWPointOfView);
+
 template <typename T>
 llvm::SmallVector<T> expandMatrixShapeWithBatch(llvm::ArrayRef<T> s);
 
```
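Both new declarations take an already-converted LinearLayout rather than a RankedTensorType, so a caller can materialize the layout once and stringify it as often as needed. A minimal usage sketch, mirroring what getLayoutStr itself now does (see the Dialect.cpp hunks below); the wrapper function name is ours:

```cpp
// Sketch: stringify a distributed layout via the new overload. The
// toLinearLayout call matches the one getLayoutStr performs in this commit.
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Tools/LinearLayout.h"

std::string printDistributedLayout(mlir::RankedTensorType tensorType) {
  namespace ttg = mlir::triton::gpu;
  // Convert the encoding to a LinearLayout once...
  mlir::triton::LinearLayout ll =
      ttg::toLinearLayout(tensorType.getShape(), tensorType.getEncoding());
  // ...then render it from the logical (non-hardware) point of view.
  return ttg::getDistributedLayoutStr(ll, /*useHWPointOfView=*/false);
}
```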

include/triton/Tools/LinearLayout.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -869,6 +869,8 @@ inline std::ostream &operator<<(std::ostream &os, const ColumnAction &action) {
   return os;
 }
 
+std::unique_ptr<uint64_t[]> getMatrix(const LinearLayout &layout);
+
 } // namespace mlir::triton
 
 #endif // TRITON_TOOLS_LINEARLAYOUT_H
```
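getMatrix, until now a file-local helper in LinearLayout.cpp, becomes part of the public header: it packs the layout's GF(2) bases into one uint64_t bitmask per output bit (see the full definition and its worked comment in LinearLayout.cpp below). A small sketch of consuming it; only getMatrix and the two size queries come from the header, the dump helper is ours:

```cpp
// Sketch: print the bases matrix of a layout, one row per output bit and
// one column per input bit; bit c of m[r] holds column c of row r.
#include "triton/Tools/LinearLayout.h"
#include "llvm/Support/raw_ostream.h"

void dumpBasesMatrix(const mlir::triton::LinearLayout &ll) {
  int numRows = ll.getTotalOutDimSizeLog2();
  int numCols = ll.getTotalInDimSizeLog2();
  std::unique_ptr<uint64_t[]> m = mlir::triton::getMatrix(ll);
  for (int r = 0; r < numRows; ++r) {
    for (int c = 0; c < numCols; ++c)
      llvm::outs() << ((m[r] >> c) & 1);
    llvm::outs() << "\n";
  }
}
```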

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 31 additions & 36 deletions
```diff
@@ -3307,20 +3307,17 @@ static std::string paddedString(int value, int max) {
   return str;
 }
 
-std::string getSharedLayoutStr(RankedTensorType type, bool useHWPointOfView) {
-  if (!type)
-    return "";
-
+std::string mlir::triton::gpu::getSharedLayoutStr(LinearLayout &ll,
+                                                  bool useHWPointOfView) {
   // This RankedTensorType is a MemDescType (?!)
-  auto shape = type.getShape();
-  auto layout = type.getEncoding();
-  LinearLayout ll = triton::gpu::toLinearLayout(shape, layout);
+  auto outDimNames = llvm::to_vector(ll.getOutDimNames());
+  auto shape = convertType<int64_t>(llvm::to_vector(ll.getOutDimSizes()));
+  auto *ctx = outDimNames[0].getContext();
 
-  StringAttr kOffset = StringAttr::get(type.getContext(), "offset");
-  StringAttr kBlock = StringAttr::get(type.getContext(), "block");
-  int64_t tensorSize = product(type.getShape());
-  auto enc = type.getEncoding();
-  unsigned numBlocks = getNumCTAs(enc);
+  StringAttr kOffset = StringAttr::get(ctx, "offset");
+  StringAttr kBlock = StringAttr::get(ctx, "block");
+  int64_t tensorSize = product(shape);
+  unsigned numBlocks = ll.getInDimSize(kBlock);
   int32_t blockSize = tensorSize / numBlocks;
 
   // elementMapping is for the non-hw layout, offsetMapping for hw-layout
@@ -3374,7 +3371,7 @@ std::string getSharedLayoutStr(RankedTensorType type, bool useHWPointOfView) {
   std::string layoutStr;
 
   if (!useHWPointOfView) {
-    int rank = type.getRank();
+    int rank = shape.size();
     bool newLine = true;
     for (int i = 0; i < tensorSize; i++) {
       auto indices = delinearizeIndex(i, shape);
@@ -3422,21 +3419,19 @@ std::string getSharedLayoutStr(RankedTensorType type, bool useHWPointOfView) {
   return layoutStr;
 }
 
-std::string getDistributedLayoutStr(RankedTensorType tensorType,
-                                    bool useHWPointOfView) {
-  auto layout = tensorType.getEncoding();
-  if (!layout)
-    return "";
-
-  StringAttr kRegister = StringAttr::get(tensorType.getContext(), "register");
-  StringAttr kLane = StringAttr::get(tensorType.getContext(), "lane");
-  StringAttr kWarp = StringAttr::get(tensorType.getContext(), "warp");
-  StringAttr kBlock = StringAttr::get(tensorType.getContext(), "block");
+std::string mlir::triton::gpu::getDistributedLayoutStr(LinearLayout &ll,
+                                                       bool useHWPointOfView) {
+  auto inDimNames = llvm::to_vector(ll.getInDimNames());
+  auto *ctx = inDimNames[0].getContext();
+  StringAttr kRegister = StringAttr::get(ctx, "register");
+  StringAttr kLane = StringAttr::get(ctx, "lane");
+  StringAttr kWarp = StringAttr::get(ctx, "warp");
+  StringAttr kBlock = StringAttr::get(ctx, "block");
 
-  LinearLayout ll = toLinearLayout(tensorType);
-  int64_t tensorSize = product(tensorType.getShape());
+  int64_t tensorSize = ll.getTotalOutDimSize();
   std::vector<std::string> elementMapping(tensorSize);
   std::vector<std::string> threadMapping;
+  auto shape = convertType<int64_t>(llvm::to_vector(ll.getOutDimSizes()));
   unsigned threadsPerWarp = ll.getInDimSize(kLane);
   unsigned numWarpsPerCTA = ll.getInDimSize(kWarp);
   unsigned numBlocks = ll.getInDimSize(kBlock);
@@ -3456,7 +3451,7 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
     int stride = 1;
     for (int i = outputs.size() - 1; i >= 0; i--) {
       linearizedIdx += outputs[i].second * stride;
-      stride *= tensorType.getDimSize(i);
+      stride *= shape[i];
    }
     std::string &value = elementMapping[linearizedIdx];
     if (!value.empty())
@@ -3476,8 +3471,7 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
     for (int i = 0; i < outputs.size(); i++) {
       if (i > 0)
         threadInfo += ",";
-      threadInfo +=
-          paddedString(outputs[i].second, tensorType.getDimSize(i));
+      threadInfo += paddedString(outputs[i].second, shape[i]);
     }
     threadInfo += ")";
     threadMapping.push_back(threadInfo);
@@ -3488,13 +3482,13 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
   std::string layoutStr;
   if (!useHWPointOfView) {
     // Printing the threads containing each elements of the tensor.
-    int rank = tensorType.getRank();
+    int rank = ll.getNumOutDims();
     bool newLine = true;
     for (int i = 0; i < tensorSize; i++) {
-      auto indices = delinearizeIndex(i, tensorType.getShape());
+      auto indices = delinearizeIndex(i, shape);
       int numOpenBracket = 0;
       for (int j = rank - 1; j >= 0; j--) {
-        if (indices[j] % tensorType.getDimSize(j) != 0)
+        if (indices[j] % shape[j] != 0)
           break;
         layoutStr += "[";
         numOpenBracket++;
@@ -3506,13 +3500,13 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
       }
 
       layoutStr += elementMapping[i];
-      auto nextIndices = delinearizeIndex(i + 1, tensorType.getShape());
+      auto nextIndices = delinearizeIndex(i + 1, shape);
       for (int j = rank - 1; j >= 0; j--) {
-        if (nextIndices[j] % tensorType.getDimSize(j) != 0)
+        if (nextIndices[j] % shape[j] != 0)
           break;
         layoutStr += "]";
       }
-      if (nextIndices.back() % tensorType.getShape().back() == 0) {
+      if (nextIndices.back() % shape.back() == 0) {
         layoutStr += "\n";
         newLine = true;
       } else {
@@ -3578,15 +3572,16 @@ mlir::triton::gpu::expandMatrixOrderWithBatch(llvm::ArrayRef<unsigned> o) {
 std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType,
                                             bool useHWPointOfView) {
   auto layout = tensorType.getEncoding();
+  LinearLayout ll = triton::gpu::toLinearLayout(tensorType.getShape(), layout);
 
   // tensorType is needed later on (e.g., getDimSize(j)), so we still have to
   // pass it as a param
   // TODO: Pass TensorOrMemDesc instead of RankedTensorType in
   // triton-tensor-layout.cpp
   if (mlir::isa<SharedEncodingTrait>(layout)) {
-    return getSharedLayoutStr(tensorType, useHWPointOfView);
+    return getSharedLayoutStr(ll, useHWPointOfView);
   } else if (mlir::isa<DistributedEncodingTrait>(layout)) {
-    return getDistributedLayoutStr(tensorType, useHWPointOfView);
+    return getDistributedLayoutStr(ll, useHWPointOfView);
   }
 
   // else unimplemented, return error
```
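The net effect of these hunks: both printers now derive shape, rank, and MLIRContext from the LinearLayout itself, and getLayoutStr performs the encoding-to-LinearLayout conversion exactly once. The index arithmetic is untouched; for reference, the row-major linearization used in the stride loop above, restated as a self-contained snippet (the function name is ours):

```cpp
// Standalone restatement of the stride loop: linearize a multi-dimensional
// index row-major, with the trailing dimension varying fastest.
#include <cstdint>
#include <vector>

int64_t linearize(const std::vector<int64_t> &indices,
                  const std::vector<int64_t> &shape) {
  int64_t linearizedIdx = 0;
  int64_t stride = 1;
  for (int i = static_cast<int>(indices.size()) - 1; i >= 0; --i) {
    linearizedIdx += indices[i] * stride;
    stride *= shape[i];
  }
  return linearizedIdx;
}

// e.g. linearize({1, 2}, {4, 8}) == 1 * 8 + 2 == 10
```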

lib/Tools/LinearLayout.cpp

Lines changed: 52 additions & 52 deletions
```diff
@@ -65,56 +65,6 @@ void dumpMatrix(uint64_t *m, int numRows, int numCols) {
   }
 }
 
-// Build a matrix of size sum(outDimSizeLog2) x sum(inDimSizeLog2) representing
-// the bases of the given layout. This can then be used by f2reduce.
-//
-// This function is called from the constructor of LinearLayout, so be careful
-// not to use any functions that create LLs in here.
-std::unique_ptr<uint64_t[]> getMatrix(const LinearLayout &layout) {
-  int numRows = layout.getTotalOutDimSizeLog2();
-  int numCols = layout.getTotalInDimSizeLog2();
-
-  // Don't handle giant LLs. This makes some things easier; for example, each
-  // row can be a single uint64_t.
-  assert(numCols <= 64 && "LinearLayout too large");
-  assert(numRows <= 64 && "LinearLayout too large");
-
-  // Suppose we have a layout specified by the following values.
-  //
-  // L(0,1) = (0b01, 0b1)
-  // L(0,2) = (0b10, 0b0)
-  // L(1,0) = (0b10, 0b0)
-  // L(2,0) = (0b11, 0b0)
-  //
-  // We will create one column per entry above. The max bit width of the
-  // codomain is (2,1), so our matrix will have 2+1=3 rows. The final matrix
-  // will be
-  //
-  //  | L(0,1)[0]  L(0,2)[0]  L(1,0)[0]  L(2,0)[0] |   | 0b1001 |
-  //  |     ↓          ↓          ↓          ↓     |   | 0b0111 |
-  //  | L(0,1)[1]  L(0,2)[1]  L(1,0)[1]  L(2,0)[1] | = | 0b1000 |
-  //  |     ↓          ↓          ↓          ↓     |
-  //
-  // Note `new uint64_t[n]()` is zero-initialized, but `new uint64_t[n]` is not.
-  std::unique_ptr<uint64_t[]> m(new uint64_t[numRows]());
-  int r = 0;
-  for (StringAttr outDim : layout.getOutDimNames()) {
-    int c = 0;
-    for (StringAttr inDim : layout.getInDimNames()) {
-      for (int i = 0; i < layout.getInDimSizeLog2(inDim); i++) {
-        uint64_t basis = layout.getBasis(inDim, i, outDim);
-        for (int j = 0; j < layout.getOutDimSizeLog2(outDim); j++) {
-          m[r + j] |= ((basis >> j) & 1) << c;
-        }
-        c++;
-      }
-    }
-    r += layout.getOutDimSizeLog2(outDim);
-  }
-
-  return m;
-}
-
 // Compute the rank of the matrix formed by taking the bases for the given
 // outDim as columns. In other words, finds the number of linearly-independent
 // bases for this output dimension.
@@ -340,7 +290,7 @@ int32_t LinearLayout::getOutDimIndex(StringAttr outDim) const {
 
 int32_t LinearLayout::getInDimSizeLog2(StringAttr inDim) const {
   auto it = bases.find(inDim);
-  assert(it != bases.end());
+  assert(it != bases.end() && "inDim not found in layout");
   return it->second.size();
 }
 
@@ -353,7 +303,7 @@ int32_t LinearLayout::getTotalInDimSizeLog2() const {
 
 int32_t LinearLayout::getOutDimSizeLog2(StringAttr outDim) const {
   auto it = outDims.find(outDim);
-  assert(it != outDims.end());
+  assert(it != outDims.end() && "outDim not found in layout");
   return llvm::Log2_32(it->second);
 }
 
@@ -1370,4 +1320,54 @@ std::string ColumnAction::toString() const {
   return ret;
 }
 
+// Build a matrix of size sum(outDimSizeLog2) x sum(inDimSizeLog2) representing
+// the bases of the given layout. This can then be used by f2reduce.
+//
+// This function is called from the constructor of LinearLayout, so be careful
+// not to use any functions that create LLs in here.
+std::unique_ptr<uint64_t[]> getMatrix(const LinearLayout &layout) {
+  int numRows = layout.getTotalOutDimSizeLog2();
+  int numCols = layout.getTotalInDimSizeLog2();
+
+  // Don't handle giant LLs. This makes some things easier; for example, each
+  // row can be a single uint64_t.
+  assert(numCols <= 64 && "LinearLayout too large");
+  assert(numRows <= 64 && "LinearLayout too large");
+
+  // Suppose we have a layout specified by the following values.
+  //
+  // L(0,1) = (0b01, 0b1)
+  // L(0,2) = (0b10, 0b0)
+  // L(1,0) = (0b10, 0b0)
+  // L(2,0) = (0b11, 0b0)
+  //
+  // We will create one column per entry above. The max bit width of the
+  // codomain is (2,1), so our matrix will have 2+1=3 rows. The final matrix
+  // will be
+  //
+  //  | L(0,1)[0]  L(0,2)[0]  L(1,0)[0]  L(2,0)[0] |   | 0b1001 |
+  //  |     ↓          ↓          ↓          ↓     |   | 0b0111 |
+  //  | L(0,1)[1]  L(0,2)[1]  L(1,0)[1]  L(2,0)[1] | = | 0b1000 |
+  //  |     ↓          ↓          ↓          ↓     |
+  //
+  // Note `new uint64_t[n]()` is zero-initialized, but `new uint64_t[n]` is not.
+  std::unique_ptr<uint64_t[]> m(new uint64_t[numRows]());
+  int r = 0;
+  for (StringAttr outDim : layout.getOutDimNames()) {
+    int c = 0;
+    for (StringAttr inDim : layout.getInDimNames()) {
+      for (int i = 0; i < layout.getInDimSizeLog2(inDim); i++) {
+        uint64_t basis = layout.getBasis(inDim, i, outDim);
+        for (int j = 0; j < layout.getOutDimSizeLog2(outDim); j++) {
+          m[r + j] |= ((basis >> j) & 1) << c;
+        }
+        c++;
+      }
+    }
+    r += layout.getOutDimSizeLog2(outDim);
+  }
+
+  return m;
+}
+
 } // namespace mlir::triton
```
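Moving getMatrix below ColumnAction and exporting it (per the header change above) lets code outside this file reuse it. One plausible consumer, sketched under stated assumptions: computing the GF(2) rank of a layout's bases with the f2reduce routine this file already uses internally. The include path and the exact f2reduce entry point are assumptions to verify against your checkout:

```cpp
// Sketch, with assumptions: row-reduce a layout's bases over GF(2) using the
// f2reduce helper bundled with Triton (include path may differ per checkout).
#include "triton/Tools/LinearLayout.h"

#include "f2reduce.h" // assumption: bundled f2reduce header

int layoutMatrixRank(const mlir::triton::LinearLayout &ll) {
  int numRows = ll.getTotalOutDimSizeLog2();
  int numCols = ll.getTotalInDimSizeLog2();
  std::unique_ptr<uint64_t[]> m = mlir::triton::getMatrix(ll);
  // Reduce to row-echelon form in place over GF(2).
  f2reduce::inplace_rref_strided(m.get(), numRows, numCols, /*stride=*/1);
  // After reduction, the rank is the number of nonzero rows.
  int rank = 0;
  for (int r = 0; r < numRows; ++r)
    rank += (m[r] != 0) ? 1 : 0;
  return rank;
}
```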
