Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
fe57b25
[AxisInfo] Make unrealized_conversion_cast handling more robust (#8507)
matthias-springer Oct 22, 2025
56c6468
[GLUON] add device-side TMA (#8505)
hgl71964 Oct 22, 2025
bad2576
[AMD] Fix branch condition in BufferLoadToLocalOpConversion (#8501)
kelesvol Oct 23, 2025
ecd33fe
[BACKEND] Improve constant analysis in AxisInfo (#8502)
lezcano Oct 23, 2025
0257c4c
[Tests] Using device fixure instead of cuda in tensor descriptor test…
red1bluelost Oct 23, 2025
c07886c
[AMD] Update shared memory size for gfx1250 from TargetInfo (#8517)
AlexAUT Oct 23, 2025
4d6ce4e
[AMD] Fix wmma scaled with small k dim on gfx1250 (#8487)
borontion Oct 23, 2025
3a832d6
[BACKEND] Fix memdesc of pointers (#8515)
ThomasRaoux Oct 23, 2025
1c72fb6
[NFC] Remove legacy TODO (#8520)
Jokeren Oct 23, 2025
00cf53f
[BENCH] Incorporate EP sharding and deprecate the legacy communicatio…
Jokeren Oct 24, 2025
a2fdd73
[AMD][BACKEND] Support of ttg.async_wait on gfx1250 (#8510)
AlexAUT Oct 24, 2025
314a622
[AMD] Fix deduceTilesPerWarp boundary cases (#8467)
Dewei-Wang-sh Oct 24, 2025
39eec89
[AMD] Lower `ttg.async_copy_global_to_local` on gfx1250 (#8509)
AlexAUT Oct 24, 2025
3f4ac9f
[KERNELS][NFC] Remove the redundant `reduce.py` file (#8524)
Jokeren Oct 24, 2025
4734af3
Fix AxisInfo handling of PoisonOp producing MemDesc (#8489)
neildhar Oct 24, 2025
4d85824
[NVIDIA] Enable TMA gather4 on sm_120 and sm_121 (#8498)
ita9naiwa Oct 24, 2025
7c59c1d
[AMD][GLUON] Expose get wmma/mfma scale layout (#8496)
borontion Oct 24, 2025
7bdcc6b
[triton_kernels][opt_flags] Add function to reset opt_flags (#8453)
matkle Oct 24, 2025
cbab5f4
[Gluon] Change `gl.warp_specialize` API (#8527)
Mogball Oct 24, 2025
869733f
[AMD] NFC: rename Gluon example directory (#8530)
antiagainst Oct 24, 2025
d703656
[GLUON] Set proper location on restoring the insert point in gluon (#…
pawelszczerbuk Oct 24, 2025
a6e7434
[SWP] Dedup the code that checks if LoadOp can be converted to cpasyn…
masahi Oct 24, 2025
11af53c
[AMD][GLUON] Expose buffer ops to gfx1250 (#8532)
borontion Oct 25, 2025
7578e3e
[mxfp] support EXPT_IS_INNER for MX (#8385)
jongsoo-openai Oct 25, 2025
40dd0c4
[Frontend] Make sure aggregate members are added to the cache key (#8…
Mogball Oct 25, 2025
4caa032
[mxfp4] disable swapping block{k,n} for bf16 x mx4 (#8538)
jongsoo-openai Oct 25, 2025
b900855
Remove constexprs from params in `jit.py` (#8536)
ita9naiwa Oct 26, 2025
e8bc90c
[frontend] Disable cache when interpreter is enabled (#8499)
rpelke Oct 26, 2025
bf9fea9
[FRONTEND] support multidimensional batches in tl.trans and tl.dot (#…
apgoucher Oct 26, 2025
50d10bd
[triton_kernels] revert a100 default layout change (#8549)
ptillet Oct 27, 2025
0a32cff
[BACKEND] Reset alloc_shape when doing memdesc_index (#8537)
ThomasRaoux Oct 27, 2025
9f21c06
[Docs] Update gl.warp_specialize docs + use from_tensor in persistent…
peterbell10 Oct 27, 2025
bdb9a7a
Merge commit '9f21c06d55b5c2eccd872d92e9335c4eb13969c5'
whitneywhtsang Oct 29, 2025
31c76e4
[TEST] Skip test_tensor_descriptor_padding
whitneywhtsang Oct 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/triton/Dialect/Triton/IR/Utility.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#ifndef TRITON_IR_UTILITY_H_
#define TRITON_IR_UTILITY_H_

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include <algorithm>
#include <numeric>
Expand All @@ -10,6 +12,14 @@ namespace mlir {
// Bitwidth of pointers
constexpr int kPtrBitWidth = 64;

// Returns the bit width of a type. Pointer-like types (LLVM dialect pointers
// and Triton pointers) do not report an int/float width, so they are mapped
// to the canonical 64-bit pointer width (kPtrBitWidth).
inline int getIntOrFloatOrPtrBitWidth(Type type) {
  const bool isPointerLike =
      isa<LLVM::LLVMPointerType, triton::PointerType>(type);
  return isPointerLike ? kPtrBitWidth : type.getIntOrFloatBitWidth();
}

template <typename T, typename U> SmallVector<T> convertType(ArrayRef<U> in) {
SmallVector<T> out;
for (const auto &i : in)
Expand Down
7 changes: 3 additions & 4 deletions include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,9 @@ LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
ArrayRef<unsigned> tilesPerWarp,
ArrayRef<unsigned> warpsPerCTA);

LinearLayout chooseScaledWmmaScaleLayout(
MLIRContext *ctx, int dotOperandIdx,
const std::vector<std::vector<int32_t>> &dotOperandWarpBasis,
ArrayRef<int64_t> dotOperandShape);
LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
ArrayRef<unsigned> warpsPerCTA,
ArrayRef<int64_t> dotOperandShape);

LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx,
ArrayRef<int64_t> shape, int opIdx,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,13 @@ getLastUseOfPipelinedOp(ArrayRef<Operation *> ops, scf::ForOp forOp,

// Clean up attributes passing over schedules across stages in pipelining
void removePipeliningAttributes(ModuleOp moduleOp);

// For LoadOp, DescriptorLoad, and DescriptorGather ops, determine if
// they should be pipelined.
bool isPipeliningBeneficial(Operation *op,
triton::ModuleAxisInfoAnalysis &axisInfoAnalysis,
bool filterSmall = true);

} // namespace triton
} // namespace mlir

Expand Down
3 changes: 2 additions & 1 deletion lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ class AllocationAnalysis {
auto shapePerCTA = gpu::getAllocationShapePerCTA(allocType);
numElems = product<int64_t>(shapePerCTA);
}
int64_t bytes = numElems * allocType.getElementTypeBitWidth() / 8;
int64_t bytes =
numElems * getIntOrFloatOrPtrBitWidth(allocType.getElementType()) / 8;

auto alignment = alloc.getAlignmentOrDefault();
allocation->addBuffer<BufferT::BufferKind::Explicit>(alloc, bytes,
Expand Down
135 changes: 76 additions & 59 deletions lib/Analysis/AxisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,23 +91,26 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {
auto lhsInfo = operands[0]->getValue();
auto rhsInfo = operands[1]->getValue();
auto rank = lhsInfo.getRank();
assert(isa<RankedTensorType>(op.getType()) ||
rank == 1 && "Expected ranked tensor or scalar");
assert(operands.size() == 2 && "Expected two operands");
auto constantValue = getConstantValue(op, lhsInfo, rhsInfo);
if (constantValue.has_value()) {
auto resTy = dyn_cast<RankedTensorType>(op.getType());
AxisInfo::DimVectorT constancy =
resTy ? to_vector(resTy.getShape()) : AxisInfo::DimVectorT(rank, 1);
AxisInfo::DimVectorT contiguity(rank, 1);
AxisInfo::DimVectorT divisibility(
rank, highestPowOf2Divisor<int64_t>(constantValue.value()));
return AxisInfo(contiguity, divisibility, constancy, constantValue);
}
AxisInfo::DimVectorT contiguity;
AxisInfo::DimVectorT divisibility;
AxisInfo::DimVectorT constancy;
auto constantValue = getConstantValue(op, lhsInfo, rhsInfo);
for (auto d = 0; d < rank; ++d) {
if (constantValue.has_value()) {
contiguity.push_back(1);
constancy.push_back(
std::max(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d)));
divisibility.push_back(
highestPowOf2Divisor<int64_t>(constantValue.value()));
} else {
contiguity.push_back(getContiguity(op, lhsInfo, rhsInfo, d));
constancy.push_back(getConstancy(op, lhsInfo, rhsInfo, d));
divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, d));
}
contiguity.push_back(getContiguity(op, lhsInfo, rhsInfo, d));
constancy.push_back(getConstancy(op, lhsInfo, rhsInfo, d));
divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, d));
}
return AxisInfo(contiguity, divisibility, constancy, constantValue);
}
Expand All @@ -125,9 +128,8 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {

virtual int64_t getConstancy(OpTy op, const AxisInfo &lhs,
const AxisInfo &rhs, int dim) {
return 1;
return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
}

virtual std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
const AxisInfo &rhs) {
return {};
Expand Down Expand Up @@ -192,6 +194,26 @@ class CastOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
}
};

class UnrealizedConversionCastOpAxisInfoVisitor final
    : public AxisInfoVisitorImpl<mlir::UnrealizedConversionCastOp> {
public:
  using AxisInfoVisitorImpl<
      mlir::UnrealizedConversionCastOp>::AxisInfoVisitorImpl;

  // Forwards the first operand's AxisInfo through the cast. When the first
  // result is a ranked tensor whose rank disagrees with the operand's
  // AxisInfo, forwarding would hand later visitors a mismatched rank (and
  // can crash them), so the pessimistic lattice state is returned instead.
  AxisInfo
  getAxisInfo(mlir::UnrealizedConversionCastOp op,
              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
    const AxisInfo &srcInfo = operands[0]->getValue();
    if (auto resultTy = dyn_cast<RankedTensorType>(op.getResultTypes()[0]))
      if (resultTy.getRank() != srcInfo.getRank())
        return AxisInfo::getPessimisticValueState(op->getResult(0));
    return srcInfo;
  }
};

class MakeRangeOpAxisInfoVisitor final
: public AxisInfoVisitorImpl<triton::MakeRangeOp> {
public:
Expand Down Expand Up @@ -254,7 +276,7 @@ class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl<ub::PoisonOp> {
getAxisInfo(ub::PoisonOp op,
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
unsigned rank = 1;
if (auto shape = dyn_cast<mlir::ShapedType>(op.getType()))
if (auto shape = dyn_cast<RankedTensorType>(op.getType()))
rank = shape.getRank();

// Poison values are never accessed, thus assume optimistic values.
Expand Down Expand Up @@ -308,11 +330,6 @@ class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
return gcd(lhs.getDivisibility(dim), rhsDivisibility);
}

int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
int dim) override {
return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
}

std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
const AxisInfo &rhs) override {
if (lhs.getConstantValue().has_value() &&
Expand Down Expand Up @@ -355,11 +372,6 @@ class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::MulIOp> {
return std::max(lhsContiguity, rhsContiguity);
}

int64_t getConstancy(arith::MulIOp op, const AxisInfo &lhs,
const AxisInfo &rhs, int dim) override {
return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
}

int64_t getDivisibility(arith::MulIOp op, const AxisInfo &lhs,
const AxisInfo &rhs, int dim) override {
auto lhsDivisibility = lhs.getDivisibility(dim);
Expand All @@ -379,9 +391,13 @@ class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::MulIOp> {

std::optional<int64_t> getConstantValue(arith::MulIOp op, const AxisInfo &lhs,
const AxisInfo &rhs) override {
if (lhs.getConstantValue().has_value() &&
rhs.getConstantValue().has_value())
return {lhs.getConstantValue().value() * rhs.getConstantValue().value()};
auto lhsConst = lhs.getConstantValue();
auto rhsConst = rhs.getConstantValue();
if (lhsConst.has_value() && rhsConst.has_value())
return {lhsConst.value() * rhsConst.value()};
if ((lhsConst.has_value() && lhsConst.value() == 0) ||
(rhsConst.has_value() && rhsConst.value() == 0))
return 0;
return {};
}
};
Expand All @@ -404,12 +420,11 @@ class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
int dim) override {
auto resTy = dyn_cast<RankedTensorType>(op.getType());
auto constancy = BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
if (!resTy)
return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
return constancy;
auto shape = resTy.getShape();
// Case 1: both lhs and rhs are constants.
auto constancy = gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
// Case 2: lhs contiguous, rhs constant.
// Case: lhs contiguous, rhs constant.
// lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n
// rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p
// lhs / rhs = d_lhs * k / (d_rhs * p), (d_lhs * k + 1) / (d_rhs * p),
Expand Down Expand Up @@ -506,15 +521,15 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {

int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
int dim) override {
auto constancy = BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
auto resTy = dyn_cast<RankedTensorType>(op.getType());
if (!resTy)
return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
auto shape = resTy.getShape();
// lhs % 1 = 0
return rhs.getConstantValue().has_value() &&
rhs.getConstantValue().value() == 1
? shape[dim]
: gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
return constancy;
// Case: lhs % 1 = 0
if (rhs.getConstantValue().has_value() &&
rhs.getConstantValue().value() == 1)
return resTy.getDimSize(dim);
return constancy;
}

std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
Expand Down Expand Up @@ -669,7 +684,7 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
int64_t constHint = 1;
if (lhsInfo.getConstantValue().has_value() &&
rhsInfo.getConstantValue().has_value()) {
constHint = lhsInfo.getConstancy(d);
constHint = shape[d];
constantValue =
compare(getPredicate(op), lhsInfo.getConstantValue().value(),
rhsInfo.getConstantValue().value())
Expand Down Expand Up @@ -828,6 +843,13 @@ class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
rhsInfo.getConstantValue().has_value() &&
lhsInfo.getConstantValue() == rhsInfo.getConstantValue())
constantValue = lhsInfo.getConstantValue();

if (constantValue.has_value()) {
auto resTy = dyn_cast<RankedTensorType>(op.getType());
assert(resTy || rank == 1);
constancy =
resTy ? to_vector(resTy.getShape()) : AxisInfo::DimVectorT(rank, 1);
}
}

return AxisInfo(contiguity, divisibility, constancy, constantValue);
Expand All @@ -840,11 +862,6 @@ class LogicalOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;

private:
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
int dim) override {
return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
}

std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
const AxisInfo &rhs) override {
if (lhs.getConstantValue().has_value() &&
Expand Down Expand Up @@ -890,11 +907,6 @@ class ShLIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::ShLIOp> {
return multiplyDivisor(lhsDivisibility, 1ll << shift);
}

int64_t getConstancy(arith::ShLIOp op, const AxisInfo &lhs,
const AxisInfo &rhs, int dim) override {
return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
}

std::optional<int64_t> getConstantValue(arith::ShLIOp op, const AxisInfo &lhs,
const AxisInfo &rhs) override {
if (lhs.getConstantValue().has_value() &&
Expand Down Expand Up @@ -932,11 +944,6 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
}

int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
int dim) override {
return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
}

std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
const AxisInfo &rhs) override {
if (lhs.getConstantValue().has_value() &&
Expand Down Expand Up @@ -969,9 +976,15 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
constantValue = {std::min(lhsInfo.getConstantValue().value(),
rhsInfo.getConstantValue().value())};
}
auto resTy = dyn_cast<RankedTensorType>(op.getType());
assert(resTy || rank == 1);
AxisInfo::DimVectorT constancy =
resTy ? to_vector(resTy.getShape()) : AxisInfo::DimVectorT(rank, 1);
AxisInfo::DimVectorT divisibility(
rank, highestPowOf2Divisor<int64_t>(constantValue.value()));
return AxisInfo(/*knownContiguity=*/AxisInfo::DimVectorT(rank, 1),
/*knownDivisibility=*/AxisInfo::DimVectorT(rank, 1),
/*knownConstancy=*/AxisInfo::DimVectorT(rank, 1),
/*knownDivisibility=*/divisibility,
/*knownConstancy=*/constancy,
/*constantValue=*/constantValue);
} else {
AxisInfo::DimVectorT contiguity, divisibility, constancy;
Expand Down Expand Up @@ -1029,11 +1042,11 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver,
// This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
// in the process of a PartialConversion, where UnrealizedConversionCast
// may exist
visitors.append<UnrealizedConversionCastOpAxisInfoVisitor>();
visitors.append<CastOpAxisInfoVisitor<arith::ExtSIOp>,
CastOpAxisInfoVisitor<arith::ExtUIOp>,
CastOpAxisInfoVisitor<arith::TruncIOp>,
CastOpAxisInfoVisitor<triton::gpu::ConvertLayoutOp>,
CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
CastOpAxisInfoVisitor<triton::BitcastOp>>();
visitors.append<MakeRangeOpAxisInfoVisitor>();
visitors.append<PoisonOpAxisInfoVisitor>();
Expand Down Expand Up @@ -1214,6 +1227,7 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) {
return rhs;
if (rhs.getRank() == 0)
return lhs;
assert(lhs.getRank() == rhs.getRank() && "Mismatched ranks");
DimVectorT contiguity;
DimVectorT divisibility;
DimVectorT constancy;
Expand Down Expand Up @@ -1384,7 +1398,10 @@ void ModuleAxisInfoAnalysis::update(CallOpInterface callOp,
callee.setArgAttr(index, attrName, attr);
};
auto axisInfo = axisInfoMap->lookup(value);
assert(axisInfo.getRank() == 1 && "only scalar arguments are supported");
// Only scalar arguments are supported. Do not forward multi-dimensional
// AxisInfo to the callee.
if (axisInfo.getRank() != 1)
continue;
setAttrFn("tt.contiguity", axisInfo.getContiguity(0));
setAttrFn("tt.divisibility", axisInfo.getDivisibility(0));
setAttrFn("tt.constancy", axisInfo.getConstancy(0));
Expand Down
5 changes: 1 addition & 4 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ struct ConvertLayoutOpConversion
: public ConvertOpToLLVMPattern<ConvertLayoutOp> {
const TargetInfoBase &targetInfo;

// Set benefit to 2 so that this pattern applies before other convert-layout
// conversions. TODO(jlebar): Eventually we want this to be the only pattern.
explicit ConvertLayoutOpConversion(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo,
PatternBenefit benefit = 1)
Expand Down Expand Up @@ -277,8 +275,7 @@ struct ConvertLayoutOpConversion
StringAttr kReg = str_attr("register");
StringAttr kLane = str_attr("lane");
auto elemTy = getTypeConverter()->convertType(srcTy.getElementType());
int bitwidth =
elemTy.isIntOrFloat() ? elemTy.getIntOrFloatBitWidth() : kPtrBitWidth;
int bitwidth = getIntOrFloatOrPtrBitWidth(elemTy);

auto factors = getWarpLayoutConvertDecomposition(srcTy, dstTy, bitwidth);
auto &[pReg, pLane, mixedTranspositions, nPack] = factors;
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ struct ElementwiseInlineAsmOpConversion
auto ty = getTypeConverter()->convertType(getElementType(result));

// Pack return elements into 32-bits.
unsigned bitWidth = ty.isIntOrFloat() ? ty.getIntOrFloatBitWidth() : 64;
unsigned bitWidth = getIntOrFloatOrPtrBitWidth(ty);
unsigned numElemsPerReg =
std::min(std::max(32 / bitWidth, 1u), op.getPackedElement());
assert(op.getPackedElement() % numElemsPerReg == 0);
Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ SmallVector<Value> lowerLdSt(
auto kLane = str_attr("lane");
auto kWarp = str_attr("warp");
auto kOffset = str_attr("offset");
auto bitwidth = llvmElemTy.getIntOrFloatBitWidth();
auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy);

auto [elemsPerVec, permutation] =
largestVectorisation(ctx, cvt, bitwidth, maybeMaxVecElems);
Expand Down Expand Up @@ -625,7 +625,7 @@ lowerLocalLdSt(Location loc, MLIRContext *ctx,
assert(*cvt.getOutDimNames().begin() == str_attr("offset"));
auto calcPaddedOffset = [&](Value smemOffset) {
TritonLLVMOpBuilder b(loc, rewriter);
auto bitwidth = llvmElemTy.getIntOrFloatBitWidth();
auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy);
if (auto paddedEnc = dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(
srcTy.getEncoding())) {
// Apply the offset needed for padding.
Expand Down
Loading
Loading