Skip to content

Commit adb2f15

Browse files
committed
[Matrix] Optimize shuffle extracts with ShapeInfo
When a shuffle extracts a vector that we have as part of the ShapeInfo for a Matrix (i.e. one column of a column-major matrix, or one row of a row-major matrix), replace the shuffle with that vector during lowering.
1 parent 0e90a84 commit adb2f15

File tree

2 files changed

+97
-5
lines changed

2 files changed

+97
-5
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
2121
#include "llvm/ADT/PostOrderIterator.h"
22+
#include "llvm/ADT/STLExtras.h"
2223
#include "llvm/ADT/ScopeExit.h"
2324
#include "llvm/ADT/SmallSet.h"
2425
#include "llvm/ADT/SmallVector.h"
@@ -32,6 +33,7 @@
3233
#include "llvm/IR/CFG.h"
3334
#include "llvm/IR/DataLayout.h"
3435
#include "llvm/IR/DebugInfoMetadata.h"
36+
#include "llvm/IR/DerivedTypes.h"
3537
#include "llvm/IR/Function.h"
3638
#include "llvm/IR/IRBuilder.h"
3739
#include "llvm/IR/Instructions.h"
@@ -1337,6 +1339,57 @@ class LowerMatrixIntrinsics {
13371339
return Builder.CreateAdd(Sum, Mul);
13381340
}
13391341

1342+
/// Try to replace the shuffle \p Inst with one of the vectors already cached
/// for its source matrix.
///
/// The replacement only happens when the mask is a poison-free, contiguous
/// run of indices that reads from a single operand, that operand has an entry
/// in Inst2ColumnMatrix, and the run covers exactly one whole vector of that
/// matrix (its length equals the stride and it starts on a stride boundary).
///
/// \returns true if \p Inst was replaced and erased, false if it was left
/// untouched.
bool tryLowerShuffleVector(ShuffleVectorInst *Inst) {
  Value *Op0 = Inst->getOperand(0), *Op1 = Inst->getOperand(1);

  // The caller walks the use-list of one of the operands with an
  // early-increment iterator while calling this function. If both operands
  // are the same value, erasing Inst below would destroy a use that the
  // caller's iterator may already have advanced to, so leave such shuffles
  // alone.
  if (Op0 == Op1)
    return false;

  SmallVector<int> Mask;
  Inst->getShuffleMask(Mask);
  if (Mask.empty())
    return false;

  auto *Ty = cast<FixedVectorType>(Op0->getType());
  const int NumSrcElts = Ty->getNumElements();

  // Check if the Mask implies a contiguous extraction, i.e. one column of a
  // column-major matrix (or row of a row-major one). Poison lanes disqualify
  // the shuffle.
  for (unsigned Idx = 0, E = Mask.size(); Idx != E; ++Idx) {
    if (Mask[Idx] == PoisonMaskElem)
      return false;
    if (Idx != 0 && Mask[Idx - 1] + 1 != Mask[Idx])
      return false;
  }

  // Check if the Mask extracts from a single source operand. Mask indices
  // address the concatenation of Op0 and Op1, so indices at or above
  // NumSrcElts select from Op1.
  const int First = Mask.front();
  const bool FromOp1 = First >= NumSrcElts;
  if (FromOp1 != (Mask.back() >= NumSrcElts))
    return false;
  Value *Op = FromOp1 ? Op1 : Op0;

  auto *I = Inst2ColumnMatrix.find(Op);
  if (I == Inst2ColumnMatrix.end())
    return false;

  const MatrixTy &M = I->second;
  const unsigned Stride = M.getStride();

  // Check if the Mask extracts one entire vector from the matrix.
  if (Mask.size() != Stride)
    return false;

  // Rebase the first index so that it is relative to the selected operand;
  // without this, an extraction from Op1 would index past the vectors of M.
  const unsigned Start =
      static_cast<unsigned>(FromOp1 ? First - NumSrcElts : First);

  // Check if the result would span two of the vectors in the matrix.
  // TODO: we could handle this case by creating a new shuffle, if we see that
  // happening in the wild.
  if (Start % Stride != 0)
    return false;

  Value *Result = M.getVector(Start / Stride);
  Inst->replaceAllUsesWith(Result);
  Result->takeName(Inst);
  Inst->eraseFromParent();
  return true;
}
1392+
13401393
/// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
13411394
/// users with shape information, there's nothing to do: they will use the
13421395
/// cached value when they are lowered. For other users, \p Matrix is
@@ -1351,11 +1404,16 @@ class LowerMatrixIntrinsics {
13511404
ToRemove.push_back(Inst);
13521405
Value *Flattened = nullptr;
13531406
for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
1354-
if (!ShapeMap.contains(U.getUser())) {
1355-
if (!Flattened)
1356-
Flattened = Matrix.embedInVector(Builder);
1357-
U.set(Flattened);
1358-
}
1407+
if (ShapeMap.contains(U.getUser()))
1408+
continue;
1409+
1410+
if (auto *Intr = dyn_cast<ShuffleVectorInst>(U.getUser()))
1411+
if (tryLowerShuffleVector(Intr))
1412+
continue;
1413+
1414+
if (!Flattened)
1415+
Flattened = Matrix.embedInVector(Builder);
1416+
U.set(Flattened);
13591417
}
13601418
}
13611419

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
3+
4+
; The mask <6, 7, 8> selects exactly the third column of the 3x3 column-major matrix, so the CHECK lines expect the shuffle to be replaced by the third column load.
define <3 x double> @extract_column(ptr %in, ptr %out) {
5+
; CHECK-LABEL: @extract_column(
6+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <3 x double>, ptr [[IN:%.*]], align 8
7+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN]], i64 3
8+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <3 x double>, ptr [[VEC_GEP]], align 8
9+
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN]], i64 6
10+
; CHECK-NEXT: [[COL_LOAD3:%.*]] = load volatile <3 x double>, ptr [[VEC_GEP2]], align 8
11+
; CHECK-NEXT: ret <3 x double> [[COL_LOAD3]]
12+
;
13+
%inv = call <9 x double> @llvm.matrix.column.major.load(ptr %in, i64 3, i1 1, i32 3, i32 3)
14+
%col = shufflevector <9 x double> %inv, <9 x double> poison, <3 x i32> <i32 6, i32 7, i32 8>
15+
ret <3 x double> %col
16+
}
17+
18+
; The mask <0, 3, 6> is strided rather than contiguous, so the CHECK lines expect the matrix to be flattened back into one vector and the original shuffle to be kept.
define <3 x double> @extract_row(ptr %in, ptr %out) {
19+
; CHECK-LABEL: @extract_row(
20+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <3 x double>, ptr [[IN:%.*]], align 8
21+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN]], i64 3
22+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <3 x double>, ptr [[VEC_GEP]], align 8
23+
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN]], i64 6
24+
; CHECK-NEXT: [[COL_LOAD3:%.*]] = load volatile <3 x double>, ptr [[VEC_GEP2]], align 8
25+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <3 x double> [[COL_LOAD]], <3 x double> [[COL_LOAD1]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
26+
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <3 x double> [[COL_LOAD3]], <3 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 poison, i32 poison, i32 poison>
27+
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <6 x double> [[TMP1]], <6 x double> [[TMP2]], <9 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>
28+
; CHECK-NEXT: [[ROW:%.*]] = shufflevector <9 x double> [[TMP3]], <9 x double> poison, <3 x i32> <i32 0, i32 3, i32 6>
29+
; CHECK-NEXT: ret <3 x double> [[ROW]]
30+
;
31+
%inv = call <9 x double> @llvm.matrix.column.major.load(ptr %in, i64 3, i1 1, i32 3, i32 3)
32+
%row = shufflevector <9 x double> %inv, <9 x double> poison, <3 x i32> <i32 0, i32 3, i32 6>
33+
ret <3 x double> %row
34+
}

0 commit comments

Comments
 (0)