Skip to content

Commit 0372175

Browse files
authored
Merge branch 'main' into sub-group-shuffle-broadcast
2 parents 2512ab6 + 0f002cd commit 0372175

File tree

10 files changed

+42
-27
lines changed

10 files changed

+42
-27
lines changed

.github/workflows/llvm-build.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,8 @@ jobs:
157157
cp -r /usr/aarch64-linux-gnu/lib ./arm-sysroot
158158
cp -r /usr/aarch64-linux-gnu/include ./arm-sysroot
159159
LINKER=$(pwd)/arm-sysroot/lib/ld-linux-aarch64.so.1
160-
wget http://ftp.de.debian.org/debian/pool/main/g/gcc-defaults/gcc-aarch64-linux-gnu_14.1.0-2_amd64.deb
161-
dpkg-deb -x gcc-aarch64-linux-gnu_14.1.0-2_amd64.deb ./arm-sysroot
160+
wget http://ftp.de.debian.org/debian/pool/main/g/gcc-defaults/gcc-aarch64-linux-gnu_14.2.0-1_amd64.deb
161+
dpkg-deb -x gcc-aarch64-linux-gnu_14.2.0-1_amd64.deb ./arm-sysroot
162162
export LD_LIBRARY_PATH=$(pwd)/arm-sysroot/lib:$LD_LIBRARY_PATH
163163
sudo ln -s $LINKER /lib/ld-linux-aarch64.so.1
164164
SYSROOT="$(pwd)/arm-sysroot"

cmake/llvm-hash.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
61f8a7f618901797ee8663389a29722f29216a96
1+
b5cc222d7429fe6f18c787f633d5262fac2e676f

lib/Analysis/AxisInfo.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,8 +1084,9 @@ LogicalResult AxisInfoAnalysis::visitOperation(
10841084

10851085
void AxisInfoAnalysis::visitForOpInductionVar(
10861086
scf::ForOp op, ArrayRef<dataflow::Lattice<AxisInfo> *> argLattices) {
1087-
auto lb = getLatticeElementFor(op, op.getLowerBound())->getValue();
1088-
auto step = getLatticeElementFor(op, op.getStep())->getValue();
1087+
ProgramPoint programPoint(op);
1088+
auto lb = getLatticeElementFor(&programPoint, op.getLowerBound())->getValue();
1089+
auto step = getLatticeElementFor(&programPoint, op.getStep())->getValue();
10891090

10901091
AxisInfo::DimVectorT knownContiguity(1, 1);
10911092
AxisInfo::DimVectorT knownDivisibility(1, 1);

lib/Analysis/Utility.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -904,15 +904,16 @@ class ConstantAnalysis : public DataFlowAnalysis {
904904

905905
LogicalResult initialize(Operation *top) override {
906906
WalkResult result = top->walk([&](Operation *op) {
907-
if (failed(visit(op)))
907+
ProgramPoint programPoint(op);
908+
if (failed(visit(&programPoint)))
908909
return WalkResult::interrupt();
909910
return WalkResult::advance();
910911
});
911912
return success(!result.wasInterrupted());
912913
}
913914

914-
LogicalResult visit(ProgramPoint point) override {
915-
Operation *op = point.get<Operation *>();
915+
LogicalResult visit(ProgramPoint *point) override {
916+
Operation *op = point->getOperation();
916917
Attribute value;
917918
if (matchPattern(op, m_Constant(&value))) {
918919
auto *constant = getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(

lib/Target/LLVMIR/LLVMDIScope.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ struct LLVMDIScopePass : public LLVMDIScopeBase<LLVMDIScopePass> {
104104
auto subprogramAttr = LLVM::DISubprogramAttr::get(
105105
context, distinctId, compileUnitAttr, fileAttr, funcNameAttr,
106106
funcNameAttr, fileAttr, /*line=*/line, /*scopeline=*/line,
107-
subprogramFlags, subroutineTypeAttr, /*retainNodes=*/{});
107+
subprogramFlags, subroutineTypeAttr, /*retainNodes=*/{},
108+
/*annotations=*/{});
108109
funcOp->setLoc(FusedLoc::get(context, {loc}, subprogramAttr));
109110
}
110111

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
a12e20d3cb095489aab5a3ff188998bc4104d693
1+
fd80cf38d6fde8b5fc28cd3231cc611e91dd8b56

third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ class ConversionPatternRewriter;
1818

1919
namespace mlir::triton::gpu::intel {
2020

21+
// If the given type is a pointer of tensors, return the pointee type.
22+
// Otherwise, attempt to cast the given type to a ranked tensor and return the
23+
// dynamic cast result.
24+
RankedTensorType getRankedTensorType(Type type);
25+
2126
// Check if given value is divisible by the divisor.
2227
bool isDivisible(Value value, unsigned divisor);
2328

third_party/intel/lib/Analysis/AxisInfo.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@
44
#include "llvm/Support/raw_ostream.h"
55

66
#include "intel/include/Analysis/AxisInfo.h"
7+
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
78
#include "triton/Dialect/Triton/IR/Dialect.h"
89

910
#define DEBUG_TYPE "intel-axis-info"
1011
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
1112
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
1213

1314
namespace mlir::triton::intel {
15+
16+
namespace ttgi = mlir::triton::gpu::intel;
1417
namespace {
1518

1619
int64_t gcdImpl(int64_t a, int64_t b, int64_t *x, int64_t *y) {
@@ -50,12 +53,6 @@ int64_t multiplyDivisor(int64_t lhs, int64_t rhs) {
5053
return lhs * rhs;
5154
}
5255

53-
RankedTensorType getRankedTensorType(Type ptrTy) {
54-
return isTensorPointerType(ptrTy)
55-
? cast<RankedTensorType>(cast<PointerType>(ptrTy).getPointeeType())
56-
: dyn_cast<RankedTensorType>(ptrTy);
57-
}
58-
5956
class AxisInfoVisitor {
6057
public:
6158
AxisInfoVisitor() = default;
@@ -415,7 +412,7 @@ class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
415412

416413
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
417414
int dim) override {
418-
auto resTy = getRankedTensorType(op.getType());
415+
auto resTy = ttgi::getRankedTensorType(op.getType());
419416
if (!resTy)
420417
return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
421418
auto shape = resTy.getShape();
@@ -470,7 +467,7 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
470467
private:
471468
int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
472469
int dim) override {
473-
auto resTy = getRankedTensorType(op.getType());
470+
auto resTy = ttgi::getRankedTensorType(op.getType());
474471
if (!resTy)
475472
return BinaryOpVisitorImpl<OpTy>::getContiguity(op, lhs, rhs, dim);
476473
auto shape = resTy.getShape();
@@ -504,7 +501,7 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
504501

505502
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
506503
int dim) override {
507-
auto resTy = getRankedTensorType(op.getType());
504+
auto resTy = ttgi::getRankedTensorType(op.getType());
508505
if (!resTy)
509506
return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
510507
auto shape = resTy.getShape();
@@ -653,7 +650,7 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
653650
AxisInfo
654651
getAxisInfo(OpTy op,
655652
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
656-
auto resTy = getRankedTensorType(op.getType());
653+
auto resTy = ttgi::getRankedTensorType(op.getType());
657654
if (!resTy)
658655
return AxisInfo();
659656
auto shape = resTy.getShape();
@@ -1144,8 +1141,11 @@ LogicalResult AxisInfoAnalysis::visitOperation(
11441141

11451142
void AxisInfoAnalysis::visitForOpInductionVar(
11461143
scf::ForOp op, ArrayRef<dataflow::Lattice<AxisInfo> *> argLattices) {
1147-
const auto &lb = getLatticeElementFor(op, op.getLowerBound())->getValue();
1148-
const auto &step = getLatticeElementFor(op, op.getStep())->getValue();
1144+
ProgramPoint programPoint(op);
1145+
const auto lb =
1146+
getLatticeElementFor(&programPoint, op.getLowerBound())->getValue();
1147+
const auto step =
1148+
getLatticeElementFor(&programPoint, op.getStep())->getValue();
11491149

11501150
AxisInfo::DimVectorT knownContiguity(1, 1);
11511151
AxisInfo::DimVectorT knownDivisibility(1, 1);
@@ -1265,7 +1265,7 @@ void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp,
12651265
}
12661266

12671267
unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
1268-
auto tensorTy = getRankedTensorType(ptr.getType());
1268+
auto tensorTy = ttgi::getRankedTensorType(ptr.getType());
12691269
if (!tensorTy)
12701270
return 1;
12711271
auto layout = tensorTy.getEncoding();
@@ -1287,7 +1287,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
12871287
}
12881288

12891289
unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
1290-
auto tensorTy = getRankedTensorType(ptr.getType());
1290+
auto tensorTy = ttgi::getRankedTensorType(ptr.getType());
12911291
if (!tensorTy)
12921292
return 1;
12931293
auto *axisInfo = getAxisInfo(ptr);
@@ -1315,7 +1315,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
13151315
}
13161316

13171317
unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
1318-
auto tensorTy = getRankedTensorType(mask.getType());
1318+
auto tensorTy = ttgi::getRankedTensorType(mask.getType());
13191319
if (!tensorTy)
13201320
return 1;
13211321
auto *axisInfo = getAxisInfo(mask);

third_party/intel/lib/Target/LLVMIR/SLPVectorizer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13236,7 +13236,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1323613236
false /*HasGlobalPred*/);
1323713237
CF = VFDatabase(*CI).getVectorizedFunction(Shape);
1323813238
} else {
13239-
CF = Intrinsic::getDeclaration(F->getParent(), ID, TysForDecl);
13239+
CF = Intrinsic::getOrInsertDeclaration(F->getParent(), ID, TysForDecl);
1324013240
}
1324113241

1324213242
SmallVector<OperandBundleDef, 1> OpBundles;

third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,16 @@ namespace ttgi = mlir::triton::gpu::intel;
2525

2626
namespace mlir::triton::gpu::intel {
2727

28+
RankedTensorType getRankedTensorType(Type ptrTy) {
29+
return tt::isTensorPointerType(ptrTy)
30+
? cast<RankedTensorType>(
31+
cast<tt::PointerType>(ptrTy).getPointeeType())
32+
: dyn_cast<RankedTensorType>(ptrTy);
33+
}
34+
2835
static bool isSingleValue(Value value) {
2936
// Don't consider load as expensive if it is loading a scalar.
30-
if (auto tensorTy = dyn_cast<RankedTensorType>(value.getType()))
37+
if (auto tensorTy = getRankedTensorType(value.getType()))
3138
return tensorTy.getNumElements() == 1;
3239
// TODO: Handle other cases.
3340
// For example, when ptr is a tensor of single value.

0 commit comments

Comments
 (0)