@@ -52,98 +52,28 @@ struct LocalLoadOpConversion
       auto kOrder = dotEnc.getOpIdx() == 0 ? rank - 1 : rank - 2;
       auto nonKOrder = dotEnc.getOpIdx() == 0 ? rank - 2 : rank - 1;
       auto needTrans = kOrder != sharedEnc.getOrder()[0];
-      // Limitation 1: Cannot use ldmatrix if we need to transpose a non-fp16
-      // matrix
-      // Limitation 2: If kWidth is greater than the vector width of the dot
-      // operands of MMA, we don't use ldmatrix
-      // Limitation 3 [TODO: remove]: Shared memory with leading offset is not
-      // supported yet
-      auto canUseLdmatrixLegacy =
+      // Limitation 1 [TODO: remove]: Check LL bases to verify register and
+      // address alignment
+      auto canUseLdmatrix =
           (kWidth == vecWidth) && (!sharedEnc.getHasLeadingOffset());
-      if (mmaEnc.isHopper()) {
-        // Limitation 4 [TODO: remove]:
-        // I think we should be able to remove this condition, but it's here
-        // as the legacy ldmatrix path does not support it
-        canUseLdmatrixLegacy &= srcTy.getElementTypeBitWidth() * kWidth == 32 &&
-                                dotEnc.getOpIdx() == 0;
-      }
-      // Limitation 5: If we perform swizzling, it must be done within a single
-      // ldmatrix tile
-      auto maxPhase = sharedEnc.getMaxPhase();
-      auto perPhase = sharedEnc.getPerPhase();
-      auto vecSize = sharedEnc.getVec();
-      canUseLdmatrixLegacy &=
-          (maxPhase == 1) ||
-          ((maxPhase / perPhase <= 8) && (vecSize * bitwidth >= 8 * 16));
+      canUseLdmatrix &= (sharedEnc.getMaxPhase() == 1) ||
+                        (sharedEnc.getVec() * bitwidth >= 8 * 16);
       auto shape = srcTy.getShape();
-      auto allocShape = srcTy.getAllocShape();
-      // Limitation 6 [TODO: remove]: Only support 2d matrices now but we should
+      // Limitation 2 [TODO: remove]: Only support 2d matrices now but we should
       // be able to support 3D with minor changes
-      auto canUseLdmatrixLL = (bitwidth <= 16 || (!needTrans)) &&
-                              shape.size() <= 2 && canUseLdmatrixLegacy;
-      canUseLdmatrixLegacy &=
-          (bitwidth == 16 || (!needTrans)) && shape.size() <= 2;
-      if (dotEnc.getOpIdx() == 0) {
-        canUseLdmatrixLL &=
-            shape[kOrder] >= (16 * 16 / bitwidth) && shape[nonKOrder] >= 16;
-      } else {
-        // Limitation 8 [TODO: remove]: Due to the use of ldmatrix.x4, we need
-        // to read 4 tiles. For opIdx=1, a single warp loads four consecutive
-        // tiles along the K dimension, so the minimum K size is 4 * 8 = 32.
-        // The legacy path doesn't have this limitation because it reads
-        // duplicated elements from shared memory and throws them away.
-        // It might be better to use ldmatrix.x2 in such a case instead of
-        // abandoning elements.
-        canUseLdmatrixLL &=
-            shape[kOrder] >= (32 * 16 / bitwidth) && shape[nonKOrder] >= 16;
-      }
-      // Limitation 9 [TODO: remove]:
-      // If we remove this one, ldmatrix will IMA. It can probably be relaxed
-      // though. Remove this constraint after all other limitations have been
-      // resolved
-      canUseLdmatrixLegacy &=
-          srcTy.getShape()[0] >= 8 && srcTy.getShape()[1] >= 4 * kWidth;
-      if (canUseLdmatrixLL) {
+      canUseLdmatrix &= (bitwidth <= 16 || !needTrans) && shape.size() <= 2;
+      // Limitation 3: Minimum tile size (8)x(8x16 bits)
+      canUseLdmatrix &=
+          shape[kOrder] >= (8 * 16 / bitwidth) && shape[nonKOrder] >= 8;
+      if (canUseLdmatrix) {
         return lowerSharedToDotOperandLL(op, adaptor, getTypeConverter(),
                                          rewriter);
-      } else if (canUseLdmatrixLegacy) {
-        return lowerSharedToDotOperandLegacy(op, adaptor, getTypeConverter(),
-                                             rewriter);
       }
     }
     return failure();
   }
 
 private:
-  LogicalResult
-  lowerSharedToDotOperandLegacy(triton::gpu::LocalLoadOp op,
-                                triton::gpu::LocalLoadOpAdaptor adaptor,
-                                const LLVMTypeConverter *typeConverter,
-                                ConversionPatternRewriter &rewriter) const {
-    auto loc = op.getLoc();
-    auto src = op.getSrc();
-    auto dstLayout = cast<DotOperandEncodingAttr>(op.getType().getEncoding());
-    auto mmaLayout = cast<NvidiaMmaEncodingAttr>(dstLayout.getParent());
-    auto llvmElemTy =
-        typeConverter->convertType(src.getType().getElementType());
-    auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
-                                                         llvmElemTy, rewriter);
-    Value res;
-    if (mmaLayout.isHopper() || mmaLayout.isAmpere()) { // tensor core v2 or v3
-      if (mmaLayout.isHopper())
-        assert(dstLayout.getOpIdx() == 0 &&
-               "Operand $b in MMAv3 can only be in shared memory");
-
-      res = SharedToDotOperandMMAv2OrV3::convertLayout(
-          dstLayout.getOpIdx(), rewriter, loc, src, dstLayout, smemObj,
-          typeConverter, getThreadId(rewriter, loc));
-    } else {
-      llvm_unreachable("Unsupported mma layout found");
-    }
-    rewriter.replaceOp(op, res);
-    return success();
-  }
-
   LogicalResult
   lowerSharedToDotOperandLL(triton::gpu::LocalLoadOp op,
                             triton::gpu::LocalLoadOpAdaptor adaptor,
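For intuition about what the consolidated `canUseLdmatrix` gate accepts, here is a minimal standalone sketch of the same arithmetic. It is hypothetical: plain integers stand in for the tensor/layout getters, and `kDim`/`nonKDim` stand in for `shape[kOrder]`/`shape[nonKOrder]`; only the checks mirror the hunk above. An ldmatrix tile is 8 rows of 8x16 bits, so the K dimension must hold at least `8 * 16 / bitwidth` elements.

```cpp
#include <cassert>
#include <cstdio>

// Hypothetical standalone mirror of the canUseLdmatrix checks above.
bool canUseLdmatrix(int bitwidth, int kWidth, int vecWidth, int maxPhase,
                    int vec, bool hasLeadingOffset, bool needTrans, int rank,
                    int kDim, int nonKDim) {
  bool ok = (kWidth == vecWidth) && !hasLeadingOffset;
  // Swizzling, if any, must stay within one 8x16-bit ldmatrix row.
  ok &= (maxPhase == 1) || (vec * bitwidth >= 8 * 16);
  // Only 2-D operands, and transposed loads only for <=16-bit elements.
  ok &= (bitwidth <= 16 || !needTrans) && rank <= 2;
  // Minimum tile: 8 x (8x16 bits), i.e. 8*16/bitwidth elements along K.
  ok &= kDim >= (8 * 16 / bitwidth) && nonKDim >= 8;
  return ok;
}

int main() {
  // fp16 (bitwidth=16): needs at least 8 elements along K.
  assert(canUseLdmatrix(16, 8, 8, 1, 8, false, false, 2, 8, 8));
  // int8 (bitwidth=8): the same 128-bit row now spans 16 elements.
  assert(!canUseLdmatrix(8, 16, 16, 1, 16, false, false, 2, 8, 8));
  assert(canUseLdmatrix(8, 16, 16, 1, 16, false, false, 2, 16, 8));
  // Transposed fp32 loads are rejected (bitwidth > 16 with needTrans).
  assert(!canUseLdmatrix(32, 4, 4, 1, 4, false, true, 2, 8, 8));
  std::puts("all checks passed");
}
```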
@@ -158,6 +88,7 @@ struct LocalLoadOpConversion
     auto shape = dstTy.getShape();
     auto rank = dstTy.getRank();
     auto kOrder = dotEnc.getOpIdx() == 0 ? rank - 1 : rank - 2;
+    auto nonKOrder = dotEnc.getOpIdx() == 0 ? rank - 2 : rank - 1;
     auto needTrans = kOrder != sharedEnc.getOrder()[0];
 
     auto llvmElemTy = typeConverter->convertType(dstTy.getElementType());
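As context for the `nonKOrder` addition: operand A (opIdx=0) is MxK, so K is the innermost dimension; operand B (opIdx=1) is KxN, so K is the outer one, and a transposed ldmatrix is needed whenever shared memory's fastest-varying dimension is not K. A hypothetical standalone sketch of that bookkeeping:

```cpp
#include <array>
#include <cassert>

// Hypothetical mirror of the index bookkeeping above, for rank-2 operands.
struct DotOperandDims {
  int kOrder;    // dimension index holding K
  int nonKOrder; // dimension index holding M (opIdx=0) or N (opIdx=1)
};

DotOperandDims getDims(int opIdx, int rank) {
  return opIdx == 0 ? DotOperandDims{rank - 1, rank - 2}
                    : DotOperandDims{rank - 2, rank - 1};
}

int main() {
  // A is MxK: K is the innermost dimension.
  auto a = getDims(/*opIdx=*/0, /*rank=*/2);
  assert(a.kOrder == 1 && a.nonKOrder == 0);
  // B is KxN: K is the outer dimension.
  auto b = getDims(/*opIdx=*/1, /*rank=*/2);
  assert(b.kOrder == 0 && b.nonKOrder == 1);
  // needTrans: shared memory's order[0] (fastest dim) must be K,
  // otherwise the .trans variant of ldmatrix is required.
  std::array<int, 2> sharedOrder = {1, 0}; // dim 1 varies fastest
  bool needTransA = a.kOrder != sharedOrder[0]; // 1 != 1 -> false
  bool needTransB = b.kOrder != sharedOrder[0]; // 0 != 1 -> true
  assert(!needTransA && needTransB);
}
```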
@@ -169,22 +100,25 @@ struct LocalLoadOpConversion
 
     // Emit ldmatrix load operations for values packed in i32s
     SmallVector<Value> elemsI32;
+    // Typically we load 32x8 to use ldmatrix.x4, but the minimum tile size for
+    // opIdx=1 is 16x8. Therefore, we use ldmatrix.x2 instead of
+    // ldmatrix.x4 in this case.
+    auto shift = dotEnc.getOpIdx() == 1 && shape[kOrder] < (32 * 16 / bitwidth);
     auto maxVecElems = 8 * 16 / bitwidth;
     bool valid = emitTransferBetweenRegistersAndShared(
         ldmatrixLayout, srcTy, llvmElemTy,
         /*maxVecElems=*/maxVecElems, smemObj, loc, rewriter, targetInfo,
         [&](VectorType vecTy, Value vecAddr) {
           auto numElems = vecTy.getNumElements();
-          auto numElemsI32 = numElems * bitwidth / 32;
+          auto numElemsI32 = (numElems * bitwidth / 32) >> shift;
           auto matTy = LLVM::LLVMStructType::getLiteral(
               ctx, SmallVector<Type>(numElemsI32, i32_ty));
           auto ldMatrixOp = rewriter.create<nvgpu::LoadMatrixOp>(
               loc, matTy, vecAddr, /*needTrans=*/needTrans);
-          auto resV4 = ldMatrixOp.getResult();
-          elemsI32.push_back(extract_val(i32_ty, resV4, 0));
-          elemsI32.push_back(extract_val(i32_ty, resV4, 1));
-          elemsI32.push_back(extract_val(i32_ty, resV4, 2));
-          elemsI32.push_back(extract_val(i32_ty, resV4, 3));
+          auto res = ldMatrixOp.getResult();
+          for (auto i = 0; i < numElemsI32; ++i) {
+            elemsI32.push_back(extract_val(i32_ty, res, i));
+          }
         });
     assert(valid && "Failed to emit ldmatrix load operations");
 
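The `shift` introduced above halves the per-thread i32 result count when a B operand's K extent only supports ldmatrix.x2 rather than .x4. A hedged standalone sketch of that register arithmetic (the function name and inputs are illustrative, not from the patch):

```cpp
#include <cassert>

// Hypothetical mirror of the numElemsI32 computation in the lambda above.
// ldmatrix.x4 yields 4 i32 registers per thread; with shift == 1 we model
// ldmatrix.x2, which yields half as many.
int numResultRegs(int numElems, int bitwidth, bool useX2) {
  int shift = useX2 ? 1 : 0;
  return (numElems * bitwidth / 32) >> shift;
}

int main() {
  // fp16, 8 elements per thread: 8 * 16 / 32 = 4 i32 regs with .x4.
  assert(numResultRegs(8, 16, /*useX2=*/false) == 4);
  // Same transfer when K is too small for opIdx=1 (< 32*16/bitwidth):
  // .x2 returns half the registers.
  assert(numResultRegs(8, 16, /*useX2=*/true) == 2);
  // int8: 16 elements pack into the same four i32 registers.
  assert(numResultRegs(16, 8, /*useX2=*/false) == 4);
}
```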