@@ -218,9 +218,10 @@ struct DotOpMFMAConverter {
     bool allowXF32 =
         dotOp.getInputPrecision() == InputPrecision::TF32 && mfmaVersion == 3;
 
-    FailureOr<MfmaIntrinsic> maybeMfmaInsn = MfmaIntrinsic::selectFor(
-        mfmaVersion, mDim, nDim, kDimOperandSize, elemTyA, elemTyB,
-        /*withScale=*/false, allowXF32);
+    FailureOr<MfmaIntrinsic> maybeMfmaInsn =
+        MfmaIntrinsic::selectFor(dotOp->getLoc(), mfmaVersion, mDim, nDim,
+                                 kDimOperandSize, elemTyA, elemTyB,
+                                 /*withScale=*/false, allowXF32);
 
     SmallVector<unsigned> mfmaShape = {16, 16, 16};
     if (failed(maybeMfmaInsn)) {
@@ -265,15 +266,15 @@ struct DotOpMFMAConverter {
     auto extractSliceTypeA =
         RankedTensorType::get(refinedShapeA, elemTyA, encodeA);
     rewriter.setInsertionPointAfter(dotOp);
-    SmallVector<SmallVector<amdgpu::ExtractSliceOp>> subtilesA;
+    SmallVector<SmallVector<triton::amdgpu::ExtractSliceOp>> subtilesA;
     unsigned tileIdx = 0;
     for (int32_t k = 0; k < numRepK; ++k) {
-      SmallVector<amdgpu::ExtractSliceOp> subtilesK;
+      SmallVector<triton::amdgpu::ExtractSliceOp> subtilesK;
       for (int32_t i = 0; i < numRepM; ++i) {
         LDBG("local_load_a[" << i << "][" << k << "] extract_slice");
         int32_t shiftM = i * elementsPerSliceM;
         int32_t shiftK = k * elementsPerSliceK;
-        auto extract = rewriter.create<amdgpu::ExtractSliceOp>(
+        auto extract = rewriter.create<triton::amdgpu::ExtractSliceOp>(
             loc, Type{extractSliceTypeA}, Value{a},
             DenseI64ArrayAttr::get(ctx, {shiftM, shiftK}));
         // Add dot-tile info to local_load's slice;
@@ -304,15 +305,15 @@ struct DotOpMFMAConverter {
     // Extract slices for B operands.
     auto extractSliceTypeB =
         RankedTensorType::get(refinedShapeB, elemTyB, encodeB);
-    SmallVector<SmallVector<amdgpu::ExtractSliceOp>> subtilesB;
+    SmallVector<SmallVector<triton::amdgpu::ExtractSliceOp>> subtilesB;
     tileIdx = 0;
     for (int32_t k = 0; k < numRepK; ++k) {
-      SmallVector<amdgpu::ExtractSliceOp> subtilesK;
+      SmallVector<triton::amdgpu::ExtractSliceOp> subtilesK;
       for (int32_t j = 0; j < numRepN; ++j) {
         LDBG("local_load_b[" << k << "][" << j << "] extact_slice");
         int32_t shiftN = j * elementsPerSliceN;
         int32_t shiftK = k * elementsPerSliceK;
-        auto extract = rewriter.create<amdgpu::ExtractSliceOp>(
+        auto extract = rewriter.create<triton::amdgpu::ExtractSliceOp>(
             loc, Type{extractSliceTypeB}, Value{b},
             DenseI64ArrayAttr::get(ctx, {shiftK, shiftN}));
         // Add dot-tile info to local_load's slice;
@@ -404,9 +405,8 @@ struct DotOpMFMAConverter {
       }
     }
 
-    auto concatDims = DenseI64ArrayAttr::get(ctx, {numRepM, numRepN});
     auto joinedDotsResult = rewriter.create<triton::amdgpu::ConcatOp>(
-        loc, dTensorTy, refinedDotValues, concatDims);
+        loc, dTensorTy, refinedDotValues);
 
     d.replaceAllUsesWith(joinedDotsResult);
 
@@ -545,17 +545,8 @@ struct LocalLoadOpPattern
       }
     }
 
-    // concat dims is correct shape 8x1 vs 1x8, else gives wrong output shape.
-    std::vector<int64_t> loweringOrder(numReps2D.size());
-    int64_t counter = 0;
-    auto increment = [&counter](int64_t &val) { val = counter++; };
-    if (opIdx == 0)
-      std::for_each(loweringOrder.rbegin(), loweringOrder.rend(), increment);
-    else
-      std::for_each(loweringOrder.begin(), loweringOrder.end(), increment);
-
-    auto joinedResult = rewriter.create<triton::amdgpu::ConcatOp>(
-        loc, resultType, subtiles, numReps2D, loweringOrder);
+    auto joinedResult =
+        rewriter.create<triton::amdgpu::ConcatOp>(loc, resultType, subtiles);
     LDBG("ConcatOp: " << *joinedResult);
     rewriter.replaceOp(op, joinedResult);
 
@@ -617,9 +608,8 @@ struct LoadOpPattern : public RefineRewritePattern<triton::LoadOp> {
       refinedTensors.push_back(refinedTensor);
     }
 
-    auto concatDims = DenseI64ArrayAttr::get(ctx, refinedBlock.numPerDims);
     auto joinedResult = rewriter.create<triton::amdgpu::ConcatOp>(
-        loc, origResultType, refinedTensors, concatDims);
+        loc, origResultType, refinedTensors);
 
     origResult.replaceAllUsesWith(joinedResult);
     return success();
@@ -831,10 +821,8 @@ struct ReduceOpPattern : public RefineRewritePattern<triton::ReduceOp> {
 
     // Concat reduce slices.
     auto reduceResultType = op.getResultTypes()[0];
-    SmallVector<int64_t> concatDimShape = {numReps};
-    auto concatDims = DenseI64ArrayAttr::get(ctx, concatDimShape);
     auto concatOp = rewriter.create<triton::amdgpu::ConcatOp>(
-        loc, reduceResultType, refinedReduces, concatDims);
+        loc, reduceResultType, refinedReduces);
     auto origOpResult = op.getResult();
     origOpResult.replaceAllUsesWith(concatOp);
     rewriter.eraseOp(op);
@@ -952,9 +940,8 @@ struct ElementWiseOpPattern : public RefineRewritePattern<OpTy> {
 
     // Concat slices.
     auto resultType = op->getResultTypes()[0];
-    auto concatDims = DenseI64ArrayAttr::get(op->getContext(), numReps);
     auto concatOp = rewriter.create<triton::amdgpu::ConcatOp>(
-        op.getLoc(), resultType, refinedOps, concatDims);
+        op.getLoc(), resultType, refinedOps);
 
     auto origOpResult = op.getResult();
     origOpResult.replaceAllUsesWith(concatOp);
@@ -1076,11 +1063,8 @@ struct ExpandDimsOpPattern : public RefineRewritePattern<triton::ExpandDimsOp> {
 
     // Concat refined ops.
     auto reduceResultType = op->getResultTypes()[0];
-    // Expand dims of numReps also before concat.
-    numReps.insert(numReps.begin() + op.getAxis(), 1);
-    auto concatDims = DenseI64ArrayAttr::get(op->getContext(), numReps);
     auto concatOp = rewriter.create<triton::amdgpu::ConcatOp>(
-        op.getLoc(), reduceResultType, refinedReduces, concatDims);
+        op.getLoc(), reduceResultType, refinedReduces);
     auto origOpResult = op.getResult();
 
     auto checkLL = triton::gpu::toLinearEncoding(
@@ -1195,9 +1179,8 @@ struct BroadcastOpPattern : public RefineRewritePattern<BroadcastOp> {
 
     // Concat refined ops.
     auto reduceResultType = op->getResultTypes()[0];
-    auto concatDims = DenseI64ArrayAttr::get(op->getContext(), numReps);
     auto concatOp = rewriter.create<triton::amdgpu::ConcatOp>(
-        op.getLoc(), reduceResultType, refinedBroadcasts, concatDims);
+        op.getLoc(), reduceResultType, refinedBroadcasts);
 
     auto origOpResult = op.getResult();
     origOpResult.replaceAllUsesWith(concatOp);