Commit dee0846

[BACKEND] Fix wrong K dimension for dot_scaled op (#6269)
When the lhs is a 4-bit type, the calculated K dimension was wrong due to the ambiguous meaning of a variable.
1 parent 593a1b5 commit dee0846
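
For a concrete sense of the bug, the new test below feeds an e2m1 (4-bit) lhs stored as a 128x64xi8 shared-memory buffer. The following is a minimal standalone sketch of the arithmetic for the mixed-precision kind exercised by the test, assuming e2m1 occupies 4 bits and the i8 container 8 bits; the names are illustrative, not the repository's API.

#include <cstdio>

int main() {
  const unsigned aDim1 = 64;        // lhs memdesc is 128x64xi8 in the new test
  const unsigned e2m1Bits = 4;      // logical fp4 element width (assumed)
  const unsigned containerBits = 8; // i8 storage element width

  // Pre-fix behaviour: dividing by the container width under-counts packed fp4 data.
  unsigned kBuggy = aDim1 * 8 / containerBits; // 64
  // Fixed behaviour: dividing by the format width counts logical elements.
  unsigned kFixed = aDim1 * 8 / e2m1Bits;      // 128

  std::printf("K buggy = %u, K fixed = %u\n", kBuggy, kFixed);
  return 0;
}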

File tree

2 files changed: +48 -7 lines changed


test/Conversion/tritongpu_to_llvm_blackwell.mlir

Lines changed: 38 additions & 0 deletions
@@ -229,6 +229,44 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 
 // -----
 
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}>
+#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}>
+#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
+#tmem_scales = #ttng.tensor_memory_scales_encoding<>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK-LABEL: @tc_gen5_mma_block_scale_fp4_a
+  // CHECK: %[[DESC0:.+]] = llvm.mlir.constant(144769664 : i32) : i32
+  // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %{{.+}}, %{{.+}}, %{{.+}}, %[[DESC0]]
+  // CHECK: %[[DESC1:.+]] = llvm.mlir.constant(681640592 : i32) : i32
+  // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %{{.+}}, %{{.+}}, %{{.+}}, %[[DESC1]]
+  // CHECK: %[[DESC2:.+]] = llvm.mlir.constant(1218511520 : i32) : i32
+  // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %{{.+}}, %{{.+}}, %{{.+}}, %[[DESC2]]
+  // CHECK: %[[DESC3:.+]] = llvm.mlir.constant(1755382448 : i32) : i32
+  // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %{{.+}}, %{{.+}}, %{{.+}}, %[[DESC3]]
+  tt.func @tc_gen5_mma_block_scale_fp4_a(%a: !ttg.memdesc<128x64xi8, #shared1, #ttg.shared_memory>,
+                                         %b: !ttg.memdesc<128x128xi8, #shared, #ttg.shared_memory>,
+                                         %c: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>,
+                                         %scale_a: !ttg.memdesc<128x2xi8, #tmem_scales, #ttng.tensor_memory>,
+                                         %scale_b: !ttg.memdesc<128x2xi8, #tmem_scales, #ttng.tensor_memory>,
+                                         %useAcc: i1,
+                                         %pred: i1,
+                                         %barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>) {
+    ttng.tc_gen5_mma_scaled %a, %b, %c, %scale_a, %scale_b, %useAcc, %pred lhs = e2m1 rhs = e4m3, %barrier :
+      (!ttg.memdesc<128x64xi8, #shared1, #ttg.shared_memory>,
+       !ttg.memdesc<128x128xi8, #shared, #ttg.shared_memory>,
+       !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>,
+       !ttg.memdesc<128x2xi8, #tmem_scales, #ttng.tensor_memory>,
+       !ttg.memdesc<128x2xi8, #tmem_scales, #ttng.tensor_memory>,
+       i1,
+       i1,
+       !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>) -> ()
+    tt.return
+  }
+}
+
+// -----
+
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16, CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}>
 #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16, CTAsPerCGA = [1, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}>
 #shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [2], CTASplitNum = [1], CTAOrder = [0]}>
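
One way to read the four %[[DESC0..3]] constants checked above: assuming each tcgen05.mma with kind::mxf8f6f4 advances K by 32 elements per issue (an assumption about the instruction shape, not something stated in this diff), the corrected K of 128 unrolls into four MMA instructions, while the old K of 64 would have produced only two. A tiny sanity-check sketch under that assumption:

#include <cassert>

int main() {
  const unsigned kPerMma = 32;        // assumed K step per mxf8f6f4 tcgen05.mma
  const unsigned kFixed = 64 * 8 / 4; // 128: 64 bytes of packed e2m1 per row
  const unsigned kBuggy = 64 * 8 / 8; // 64: the pre-fix value

  assert(kFixed / kPerMma == 4); // matches the four DESC constants CHECKed above
  assert(kBuggy / kPerMma == 2); // the old K would have emitted only two MMAs
  return 0;
}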

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp

Lines changed: 10 additions & 7 deletions
@@ -511,11 +511,14 @@ struct TCGen5MMAScaledOpConversion
 
     unsigned int M = dTensorTy.getDimSize(0);
     unsigned int N = dTensorTy.getDimSize(1);
-    int numBitsPerElementA = opKindIsMXFP4 ? getFormatBitSize(op.getAType())
-                                           : aTensorTy.getElementTypeBitWidth();
-    int numBitsPerElementB = opKindIsMXFP4 ? getFormatBitSize(op.getBType())
-                                           : bTensorTy.getElementTypeBitWidth();
-    unsigned int K = (aTensorTy.getDimSize(1) * 8) / numBitsPerElementA;
+    int numBitsUnpackedPerElementA = opKindIsMXFP4
+                                         ? getFormatBitSize(op.getAType())
+                                         : aTensorTy.getElementTypeBitWidth();
+    int numBitsUnpackedPerElementB = opKindIsMXFP4
+                                         ? getFormatBitSize(op.getBType())
+                                         : bTensorTy.getElementTypeBitWidth();
+    unsigned int K =
+        (aTensorTy.getDimSize(1) * 8) / getFormatBitSize(op.getAType());
 
     // Get MMA size based on acc layout.
     auto tensorMemAttr = cast<triton::nvidia_gpu::TensorMemoryEncodingAttr>(
@@ -546,12 +549,12 @@
     } else {
      aLoader = std::make_unique<DotOpMmaV3SmemLoader>(
          op.getA(), baseA, shapeA, shapeA, zero, 1, transA, aOperandShape,
-          numBitsPerElementA, rewriter, loc);
+          numBitsUnpackedPerElementA, rewriter, loc);
     }
     DotOpMmaV3SmemLoader bLoader =
        DotOpMmaV3SmemLoader(op.getB(), baseB, shapeB, shapeB, zero, 1, transB,
                             {(unsigned)mmaSizeN, (unsigned)mmaSizeK},
-                            numBitsPerElementB, rewriter, loc);
+                            numBitsUnpackedPerElementB, rewriter, loc);
 
     // Only run mma on one thread. We currently use elect as ptxas is not able
     // to detect that tid.x == 0 is true only for 1 thread.
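
The rename separates two quantities that the old numBitsPerElementA conflated: K must be derived from the logical format width of A (getFormatBitSize), while the DotOpMmaV3SmemLoader still receives the per-element width of the shared-memory storage, which only matches the format width for the packed-fp4 MMA kind. Below is a condensed standalone sketch of that split; the helper names, the bit sizes, and the loader's interpretation are illustrative assumptions rather than the repository's exact API.

// Stand-ins for the real element-type queries.
enum class ElemFormat { E2M1 /*fp4*/, E4M3 /*fp8*/, E5M2 /*fp8*/ };

static unsigned formatBitSize(ElemFormat f) { // assumed: 4 for fp4, 8 otherwise
  return f == ElemFormat::E2M1 ? 4 : 8;
}

// K counts logical elements of A, so it is always derived from the format width.
static unsigned computeK(unsigned aDimSize1, ElemFormat aType) {
  return (aDimSize1 * 8) / formatBitSize(aType); // dim 1 is measured in i8 slots
}

// The shared-memory loader instead wants the storage width of one element slot:
// the format width for the packed-fp4 kind, the i8 container width otherwise.
static unsigned unpackedElemBits(bool opKindIsMXFP4, ElemFormat aType,
                                 unsigned containerBits) {
  return opKindIsMXFP4 ? formatBitSize(aType) : containerBits;
}

int main() {
  // Values for the new fp4-lhs test: a 128x64xi8 A operand under the mixed kind.
  unsigned K = computeK(/*aDimSize1=*/64, ElemFormat::E2M1);                   // 128
  unsigned w = unpackedElemBits(/*opKindIsMXFP4=*/false, ElemFormat::E2M1, 8); // 8
  return (K == 128 && w == 8) ? 0 : 1;
}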
