
Commit 3a832d6

[BACKEND] Fix memdesc of pointers (#8515)
Prevent a crash when lowering a memdesc of pointers.
1 parent 4d6ce4e commit 3a832d6
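
The crash came from querying the bit width of element types that are neither integer nor float, such as !tt.ptr<f16>: Type::getIntOrFloatBitWidth() (also reached via getElementTypeBitWidth()) asserts on such types. The fix adds a pointer-aware helper, getIntOrFloatOrPtrBitWidth(), which treats LLVM-dialect and Triton pointer types as 64-bit, routes the affected call sites through it, and makes SplatOp lowering use ptrtoint instead of bitcast when the splatted constant is a pointer.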

13 files changed (+57 −19 lines)

include/triton/Dialect/Triton/IR/Utility.h

Lines changed: 10 additions & 0 deletions
@@ -1,6 +1,8 @@
 #ifndef TRITON_IR_UTILITY_H_
 #define TRITON_IR_UTILITY_H_
 
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include <algorithm>
 #include <numeric>
@@ -10,6 +12,14 @@ namespace mlir {
 // Bitwidth of pointers
 constexpr int kPtrBitWidth = 64;
 
+// Returns the bit width of a type, treating pointer-like types as 64-bit.
+// This handles LLVM dialect pointer types.
+inline int getIntOrFloatOrPtrBitWidth(Type type) {
+  if (isa<LLVM::LLVMPointerType, triton::PointerType>(type))
+    return kPtrBitWidth;
+  return type.getIntOrFloatBitWidth();
+}
+
 template <typename T, typename U> SmallVector<T> convertType(ArrayRef<U> in) {
   SmallVector<T> out;
   for (const auto &i : in)
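
For reference, a minimal sketch (not part of the commit) of how the new helper behaves; the context and concrete types below are illustrative only:

// Illustrative only: assumes an MLIRContext *ctx is in scope.
Type f16Ty = Float16Type::get(ctx);
Type ptrTy = triton::PointerType::get(f16Ty, /*addressSpace=*/1);
int wScalar = getIntOrFloatOrPtrBitWidth(f16Ty); // 16, falls through to getIntOrFloatBitWidth()
int wPtr = getIntOrFloatOrPtrBitWidth(ptrTy);    // 64, i.e. kPtrBitWidth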

lib/Analysis/Allocation.cpp

Lines changed: 2 additions & 1 deletion
@@ -152,7 +152,8 @@ class AllocationAnalysis {
       auto shapePerCTA = gpu::getAllocationShapePerCTA(allocType);
       numElems = product<int64_t>(shapePerCTA);
     }
-    int64_t bytes = numElems * allocType.getElementTypeBitWidth() / 8;
+    int64_t bytes =
+        numElems * getIntOrFloatOrPtrBitWidth(allocType.getElementType()) / 8;
 
     auto alignment = alloc.getAlignmentOrDefault();
     allocation->addBuffer<BufferT::BufferKind::Explicit>(alloc, bytes,
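
As a worked check against the new test case added below: a 32x16 memdesc of !tt.ptr<f16> holds 512 elements and pointers count as 64-bit, so bytes = 512 * 64 / 8 = 4096; the 1x16x16 memdesc gives 256 * 64 / 8 = 2048, and the two live buffers add up to the expected total of 6144. The replaced expression went through allocType.getElementTypeBitWidth(), which only accepts integer or float element types.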

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 1 addition & 2 deletions
@@ -271,8 +271,7 @@ struct ConvertLayoutOpConversion
     StringAttr kReg = str_attr("register");
     StringAttr kLane = str_attr("lane");
     auto elemTy = getTypeConverter()->convertType(srcTy.getElementType());
-    int bitwidth =
-        elemTy.isIntOrFloat() ? elemTy.getIntOrFloatBitWidth() : kPtrBitWidth;
+    int bitwidth = getIntOrFloatOrPtrBitWidth(elemTy);
 
     auto factors = getWarpLayoutConvertDecomposition(srcTy, dstTy, bitwidth);
     auto &[pReg, pLane, mixedTranspositions, nPack] = factors;

lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
@@ -276,7 +276,7 @@ struct ElementwiseInlineAsmOpConversion
     auto ty = getTypeConverter()->convertType(getElementType(result));
 
     // Pack return elements into 32-bits.
-    unsigned bitWidth = ty.isIntOrFloat() ? ty.getIntOrFloatBitWidth() : 64;
+    unsigned bitWidth = getIntOrFloatOrPtrBitWidth(ty);
     unsigned numElemsPerReg =
         std::min(std::max(32 / bitWidth, 1u), op.getPackedElement());
     assert(op.getPackedElement() % numElemsPerReg == 0);
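
With 64-bit pointer elements, 32 / bitWidth evaluates to 0 and the std::max clamps numElemsPerReg to 1, so pointer results are simply not packed and the assert holds trivially.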

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 2 additions & 2 deletions
@@ -540,7 +540,7 @@ SmallVector<Value> lowerLdSt(
   auto kLane = str_attr("lane");
   auto kWarp = str_attr("warp");
   auto kOffset = str_attr("offset");
-  auto bitwidth = llvmElemTy.getIntOrFloatBitWidth();
+  auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy);
 
   auto [elemsPerVec, permutation] =
       largestVectorisation(ctx, cvt, bitwidth, maybeMaxVecElems);
@@ -625,7 +625,7 @@ lowerLocalLdSt(Location loc, MLIRContext *ctx,
   assert(*cvt.getOutDimNames().begin() == str_attr("offset"));
   auto calcPaddedOffset = [&](Value smemOffset) {
     TritonLLVMOpBuilder b(loc, rewriter);
-    auto bitwidth = llvmElemTy.getIntOrFloatBitWidth();
+    auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy);
     if (auto paddedEnc = dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(
             srcTy.getEncoding())) {
       // Apply the offset needed for padding.

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 12 additions & 2 deletions
@@ -11,6 +11,16 @@ using namespace mlir::triton;
 using namespace mlir::triton::gpu;
 using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
 namespace {
+
+Value bitOrPtrCast(Value val, Type type, TritonLLVMOpBuilder &b) {
+  if (isa<LLVM::LLVMPointerType>(val.getType()) &&
+      !isa<LLVM::LLVMPointerType>(type)) {
+    return b.ptrtoint(type, val);
+  } else {
+    return b.bitcast(val, type);
+  }
+}
+
 struct SplatOpConversion : public ConvertOpToLLVMPattern<triton::SplatOp> {
   using ConvertOpToLLVMPattern<triton::SplatOp>::ConvertOpToLLVMPattern;
   // Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a
@@ -39,13 +49,13 @@ struct SplatOpConversion : public ConvertOpToLLVMPattern<triton::SplatOp> {
       unsigned ratio = srcBitWidth / cstBitWidth;
       Type intTy = IntegerType::get(elemType.getContext(), cstBitWidth);
       VectorType vecType = VectorType::get(ratio, intTy);
-      Value intCst = b.bitcast(constVal, intTy);
+      Value intCst = bitOrPtrCast(constVal, intTy, b);
       Value vec = b.undef(vecType);
       for (unsigned i = 0; i < ratio; ++i)
        vec = b.insert_element(vecType, vec, intCst, b.int_val(32, i));
       constVal = vec;
     }
-    auto llSrc = b.bitcast(constVal, srcType);
+    Value llSrc = bitOrPtrCast(constVal, srcType, b);
     size_t elemsPerThread = getTotalElemsPerThread(tensorTy);
     llvm::SmallVector<Value> elems(elemsPerThread, llSrc);
     return packLLElements(loc, typeConverter, elems, rewriter, resType);
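
The bitOrPtrCast helper is needed because the LLVM dialect, like LLVM IR, does not allow a bitcast between pointer and integer types; converting a pointer value to an integer requires ptrtoint. Splatting a pointer constant therefore goes through ptrtoint before the value is broadcast, while every other element type keeps the original bitcast path.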

test/Analysis/test-allocation.mlir

Lines changed: 14 additions & 2 deletions
@@ -155,6 +155,18 @@ tt.func @preallocate(%A : !tt.ptr<f16>) {
   tt.return
 }
 
+// expected-remark @below {{memdesc_ptr}}
+// expected-remark @below {{size = 6144}}
+tt.func @memdesc_ptr() {
+  // expected-remark @below {{offset = 0, size = 4096}}
+  %a0 = ttg.local_alloc : () -> !ttg.memdesc<32x16x!tt.ptr<f16>, #A_SHARED, #ttg.shared_memory, mutable>
+  // expected-remark @below {{offset = 4096, size = 2048}}
+  %a1 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16x!tt.ptr<f16>, #A_SHARED, #ttg.shared_memory, mutable>
+  ttg.local_dealloc %a0 : !ttg.memdesc<32x16x!tt.ptr<f16>, #A_SHARED, #ttg.shared_memory, mutable>
+  ttg.local_dealloc %a1 : !ttg.memdesc<1x16x16x!tt.ptr<f16>, #A_SHARED, #ttg.shared_memory, mutable>
+  tt.return
+}
+
 // Unused tensors are immediately released
 // expected-remark @below {{unused}}
 // expected-remark @below {{size = 1024}}
@@ -279,9 +291,9 @@ tt.func @multi_color_multi_rounds(%arg0: !tt.ptr<f16>) {
 }
 
 
-// expected-remark @below {{alloc}}
+// expected-remark @below {{alloc_ptr}}
 // expected-remark @below {{size = 512}}
-tt.func @alloc(%A : !tt.ptr<f16>) {
+tt.func @alloc_ptr(%A : !tt.ptr<f16>) {
   // expected-remark @below {{offset = 0, size = 512}}
   %cst0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable>
   %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>

third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 #include "Utility.h"
 
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Tools/LayoutUtils.h"
 
 #include <limits>
@@ -159,7 +160,7 @@ ttg::PaddedSharedEncodingAttr composePaddedLayoutForAsyncCopyCDNA4(
     return {};
   }
 
-  unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
+  unsigned bitWidth = getIntOrFloatOrPtrBitWidth(srcTy.getElementType());
  unsigned elemByteWidth = std::max(bitWidth / 8u, 1u);
  auto loadBytes = shape[0] * shape[1] * elemByteWidth;
  if (loadBytes < 16384) {

third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertAref.cpp

Lines changed: 1 addition & 1 deletion
@@ -145,7 +145,7 @@ int getTxCount(Operation *descOp) {
   auto encoding = getEncodingFromDescriptor(descOp, tensorType, desc);
   auto shapePerCTA = getShapePerCTA(encoding, tensorType.getShape());
   return product(shapePerCTA) *
-         tensorType.getElementType().getIntOrFloatBitWidth() / 8;
+         getIntOrFloatOrPtrBitWidth(tensorType.getElementType()) / 8;
 }
 
 void createNVWSDescriptorLoadOp(OpBuilder &builder, Operation *ttDescLoadOp,

third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp

Lines changed: 2 additions & 1 deletion
@@ -256,7 +256,8 @@ class LoadAcquireOpPattern : public OpRewritePattern<ttn::LoadAcquireOp> {
     auto loc = op->getLoc();
     auto b = TritonLLVMOpBuilder(loc, rewriter);
     Type valueTy = op.getType();
-    const unsigned valueNBits = std::max(8u, valueTy.getIntOrFloatBitWidth());
+    const unsigned valueNBits =
+        std::max(8u, (unsigned)getIntOrFloatOrPtrBitWidth(valueTy));
     const size_t maxWordWidth = std::max<size_t>(32, valueNBits);
     const size_t width = std::min((size_t)valueNBits, maxWordWidth);
