Skip to content

Commit fca399f

Browse files
authored
1 parent 33b2823 commit fca399f

File tree

3 files changed

+21
-3
lines changed

3 files changed

+21
-3
lines changed

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,24 @@ class DotOpMmaV3SmemLoader : public DotOpMmaMemLoader {
7272
Value descriptor;
7373
};
7474

75+
// Helper class to load shared memory slices following MMAv5 layout. Inherits DotOpMmaV3SmemLoader (including its constructors) and only post-processes the descriptor returned by smemLoad.
76+
class DotOpMmaV5SmemLoader : public DotOpMmaV3SmemLoader {
77+
public:
78+
using DotOpMmaV3SmemLoader::DotOpMmaV3SmemLoader;
79+
80+
// Return a descriptor pointing to the shared memory slice at coordinates (a,
81+
// b): the descriptor built by DotOpMmaV3SmemLoader::smemLoad with bit 46 OR-ed in. NOTE(review): not marked `override` — confirm the base smemLoad is virtual if this is reached through a base-class pointer.
82+
Value smemLoad(int a, int b, ConversionPatternRewriter &rewriter,
83+
Location loc) const {
84+
auto tb = TritonLLVMOpBuilder(loc, rewriter);
85+
Value desc = DotOpMmaV3SmemLoader::smemLoad(a, b, rewriter, loc);
86+
// Set bit 46 of the 64-bit descriptor (presumably the low bit of the fixed 0b001 field at bits 46-48 — confirm against the link). The OR below is flagged disjoint, which presumably assumes the V3 descriptor never has bit 46 set — TODO confirm. As per
87+
// https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-shared-memory-descriptor
88+
Value mask = tb.int_val(64, 1ULL << 46);
89+
return tb.or_(desc, mask, /*disjoint*/ true);
90+
}
91+
};
92+
7593
// Helper class to load tensor memory following MMAv5 layout.
7694
class DotOpMmaV5TmemLoader : public DotOpMmaMemLoader {
7795
public:

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -454,13 +454,13 @@ void convertDotImpl(const LLVMTypeConverter &typeConverter,
454454
interleaved, transA);
455455
} else {
456456
auto allocShapeA = getAllocShape(aTensorTy, 1);
457-
aLoader = std::make_unique<DotOpMmaV3SmemLoader>(
457+
aLoader = std::make_unique<DotOpMmaV5SmemLoader>(
458458
a, baseA, shapeA, allocShapeA, zero, 1, transA, aOperandShape,
459459
op.numBitsPerElementA, rewriter, loc);
460460
}
461461

462462
auto allocShapeB = getAllocShape(bTensorTy, 0);
463-
DotOpMmaV3SmemLoader bLoader = DotOpMmaV3SmemLoader(
463+
DotOpMmaV5SmemLoader bLoader = DotOpMmaV5SmemLoader(
464464
b, baseB, shapeB, allocShapeB, zero, 1, transB, {mmaSizeN, mmaSizeK},
465465
op.numBitsPerElementB, rewriter, loc);
466466

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1016,7 +1016,7 @@ static void copySharedToTmem(ConversionPatternRewriter &rewriter, Location loc,
10161016
auto createCopy = [&](int repM, int repN) {
10171017
Value zero = b.i32_val(0);
10181018
SmallVector<int64_t> shape(op.getSrc().getType().getShape());
1019-
DotOpMmaV3SmemLoader smemLoader = DotOpMmaV3SmemLoader(
1019+
DotOpMmaV5SmemLoader smemLoader = DotOpMmaV5SmemLoader(
10201020
op.getSrc(), baseSrc, shape, op.getSrc().getType().getAllocShape(),
10211021
zero, 1, /*trans=*/false, {128, 8},
10221022
op.getSrc().getType().getElementType().getIntOrFloatBitWidth(),

0 commit comments

Comments
 (0)