
Commit 0f002cd
Fix the implementation of isSingleValue to handle blocked pointers (#2534)
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 7e59486

3 files changed: +23 −14 lines changed

third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h

Lines changed: 5 additions & 0 deletions
@@ -18,6 +18,11 @@ class ConversionPatternRewriter;
 
 namespace mlir::triton::gpu::intel {
 
+// If the given type is a pointer of tensors, return the pointee type.
+// Otherwise, attempt to cast the given type to a ranked tensor and return the
+// dynamic cast result.
+RankedTensorType getRankedTensorType(Type type);
+
 // Check if given value is divisible by the divisor.
 bool isDivisible(Value value, unsigned divisor);
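
For readers unfamiliar with the helper, the sketch below illustrates the contract stated in the new comment on the three kinds of types that can reach it. It is an illustrative assumption, not part of the commit; the include list, the main() harness, and the 16x16 tensor shape are made up for the example.

    #include <cassert>

    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/IR/MLIRContext.h"
    #include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
    #include "triton/Dialect/Triton/IR/Dialect.h"
    #include "triton/Dialect/Triton/IR/Types.h"

    // Illustrative sketch (assumption, not part of the commit): expected
    // behavior of getRankedTensorType for a block pointer, a ranked tensor,
    // and a scalar type.
    int main() {
      mlir::MLIRContext ctx;
      ctx.loadDialect<mlir::triton::TritonDialect>();

      auto f32 = mlir::Float32Type::get(&ctx);
      auto tensorTy = mlir::RankedTensorType::get({16, 16}, f32);
      // A tensor ("block") pointer: !tt.ptr<tensor<16x16xf32>>.
      auto blockPtrTy =
          mlir::triton::PointerType::get(tensorTy, /*addressSpace=*/1);

      // Pointer to a tensor: the pointee tensor type is returned.
      assert(mlir::triton::gpu::intel::getRankedTensorType(blockPtrTy) == tensorTy);
      // Ranked tensor: the type itself is returned.
      assert(mlir::triton::gpu::intel::getRankedTensorType(tensorTy) == tensorTy);
      // Anything else (here a scalar f32): the dyn_cast fails and null is returned.
      assert(!mlir::triton::gpu::intel::getRankedTensorType(f32));
      return 0;
    }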

third_party/intel/lib/Analysis/AxisInfo.cpp

Lines changed: 10 additions & 13 deletions
@@ -4,13 +4,16 @@
 #include "llvm/Support/raw_ostream.h"
 
 #include "intel/include/Analysis/AxisInfo.h"
+#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 
 #define DEBUG_TYPE "intel-axis-info"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
 namespace mlir::triton::intel {
+
+namespace ttgi = mlir::triton::gpu::intel;
 namespace {
 
 int64_t gcdImpl(int64_t a, int64_t b, int64_t *x, int64_t *y) {
@@ -50,12 +53,6 @@ int64_t multiplyDivisor(int64_t lhs, int64_t rhs) {
   return lhs * rhs;
 }
 
-RankedTensorType getRankedTensorType(Type ptrTy) {
-  return isTensorPointerType(ptrTy)
-             ? cast<RankedTensorType>(cast<PointerType>(ptrTy).getPointeeType())
-             : dyn_cast<RankedTensorType>(ptrTy);
-}
-
 class AxisInfoVisitor {
 public:
   AxisInfoVisitor() = default;
@@ -415,7 +412,7 @@ class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
 
   int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                        int dim) override {
-    auto resTy = getRankedTensorType(op.getType());
+    auto resTy = ttgi::getRankedTensorType(op.getType());
     if (!resTy)
       return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
     auto shape = resTy.getShape();
@@ -470,7 +467,7 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
 private:
   int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                         int dim) override {
-    auto resTy = getRankedTensorType(op.getType());
+    auto resTy = ttgi::getRankedTensorType(op.getType());
     if (!resTy)
       return BinaryOpVisitorImpl<OpTy>::getContiguity(op, lhs, rhs, dim);
     auto shape = resTy.getShape();
@@ -504,7 +501,7 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
 
   int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                        int dim) override {
-    auto resTy = getRankedTensorType(op.getType());
+    auto resTy = ttgi::getRankedTensorType(op.getType());
     if (!resTy)
       return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
     auto shape = resTy.getShape();
@@ -653,7 +650,7 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
   AxisInfo
   getAxisInfo(OpTy op,
               ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
-    auto resTy = getRankedTensorType(op.getType());
+    auto resTy = ttgi::getRankedTensorType(op.getType());
     if (!resTy)
       return AxisInfo();
     auto shape = resTy.getShape();
@@ -1268,7 +1265,7 @@ void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp,
 }
 
 unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
-  auto tensorTy = getRankedTensorType(ptr.getType());
+  auto tensorTy = ttgi::getRankedTensorType(ptr.getType());
   if (!tensorTy)
     return 1;
   auto layout = tensorTy.getEncoding();
@@ -1290,7 +1287,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
 }
 
 unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
-  auto tensorTy = getRankedTensorType(ptr.getType());
+  auto tensorTy = ttgi::getRankedTensorType(ptr.getType());
   if (!tensorTy)
     return 1;
   auto *axisInfo = getAxisInfo(ptr);
@@ -1318,7 +1315,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
 }
 
 unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
-  auto tensorTy = getRankedTensorType(mask.getType());
+  auto tensorTy = ttgi::getRankedTensorType(mask.getType());
   if (!tensorTy)
     return 1;
   auto *axisInfo = getAxisInfo(mask);

third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp

Lines changed: 8 additions & 1 deletion
@@ -25,9 +25,16 @@ namespace ttgi = mlir::triton::gpu::intel;
 
 namespace mlir::triton::gpu::intel {
 
+RankedTensorType getRankedTensorType(Type ptrTy) {
+  return tt::isTensorPointerType(ptrTy)
+             ? cast<RankedTensorType>(
+                   cast<tt::PointerType>(ptrTy).getPointeeType())
+             : dyn_cast<RankedTensorType>(ptrTy);
+}
+
 static bool isSingleValue(Value value) {
   // Don't consider load as expensive if it is loading a scalar.
-  if (auto tensorTy = dyn_cast<RankedTensorType>(value.getType()))
+  if (auto tensorTy = getRankedTensorType(value.getType()))
     return tensorTy.getNumElements() == 1;
   // TODO: Handle other cases.
   // For example, when ptr is a tensor of single value.
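
The effect of the change, shown on a type rather than a Value, is sketched below. This is an illustration under the assumption that the rest of isSingleValue is unchanged; the standalone helper name isSingleValueType and its simplified fallback are made up for the example.

    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/IR/Types.h"
    #include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"

    // Illustrative assumption, not part of the commit: the core of the fix.
    // For a blocked (tensor) pointer such as !tt.ptr<tensor<1x1xf32>>, the old
    // dyn_cast<RankedTensorType>(type) yields null, so the single-element
    // pointee was never inspected; getRankedTensorType() unwraps the pointee
    // and its element count can be checked as before.
    static bool isSingleValueType(mlir::Type type) {
      if (auto tensorTy = mlir::triton::gpu::intel::getRankedTensorType(type))
        return tensorTy.getNumElements() == 1;
      // Simplified fallback; the real isSingleValue handles further cases
      // (see the TODO in the diff above).
      return false;
    }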
