1515#include " mlir/Dialect/Affine/IR/AffineOps.h"
1616#include " mlir/Dialect/Arith/IR/Arith.h"
1717#include " mlir/Dialect/Bufferization/IR/Bufferization.h"
18+ #include " mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
1819#include " mlir/Dialect/Func/IR/FuncOps.h"
1920#include " mlir/Dialect/GPU/IR/GPUDialect.h"
2021#include " mlir/Dialect/Linalg/IR/Linalg.h"
@@ -2911,8 +2912,9 @@ static func::FuncOp createGpuAttentionKernel(ModuleOp module,
29112912 MemRefType resMemRefType =
29122913 MemRefType::get ({qShape[0 ], sequenceLengthQ, sequenceLengthK},
29132914 cast<ShapedType>(qkTensor.getType ()).getElementType ());
2914- Value resMemref =
2915- builder.create <bufferization::ToBufferOp>(loc, resMemRefType, qkTensor);
2915+ Value resMemref = builder.create <bufferization::ToBufferOp>(
2916+ loc, cast<mlir::bufferization::BufferLikeType>(resMemRefType),
2917+ qkTensor);
29162918 Value outMemref = preSoftmaxElemwiseBlock->addArgument (resMemRefType, loc);
29172919 builder.create <memref::CopyOp>(loc, resMemref, outMemref);
29182920 builder.create <rock::YieldOp>(loc);
@@ -3002,8 +3004,9 @@ createGpuConvElementwiseGemmKernel(ModuleOp module, const GenParams ¶ms) {
30023004 MemRefType resMemRefType =
30033005 MemRefType::get ({aShape[0 ], firstGemmSize.m , firstGemmSize.n },
30043006 cast<ShapedType>(abTensor.getType ()).getElementType ());
3005- Value resMemref =
3006- builder.create <bufferization::ToBufferOp>(loc, resMemRefType, abTensor);
3007+ Value resMemref = builder.create <bufferization::ToBufferOp>(
3008+ loc, cast<mlir::bufferization::BufferLikeType>(resMemRefType),
3009+ abTensor);
30073010 Value outMemref = preSecondGemmBlock->addArgument (resMemRefType, loc);
30083011 builder.create <memref::CopyOp>(loc, resMemref, outMemref);
30093012 builder.create <rock::YieldOp>(loc);
@@ -3098,8 +3101,9 @@ createGpuGemmElementwiseGemmKernel(ModuleOp module, const GenParams ¶ms) {
30983101 MemRefType resMemRefType =
30993102 MemRefType::get ({aShape[0 ], gemmM, gemmN},
31003103 cast<ShapedType>(abTensor.getType ()).getElementType ());
3101- Value resMemref =
3102- builder.create <bufferization::ToBufferOp>(loc, resMemRefType, abTensor);
3104+ Value resMemref = builder.create <bufferization::ToBufferOp>(
3105+ loc, cast<mlir::bufferization::BufferLikeType>(resMemRefType),
3106+ abTensor);
31033107 Value outMemref = preSecondGemmBlock->addArgument (resMemRefType, loc);
31043108 builder.create <memref::CopyOp>(loc, resMemref, outMemref);
31053109 builder.create <rock::YieldOp>(loc);
@@ -3280,7 +3284,7 @@ createCpuConvElementwiseGemmKernelWithMlir(ModuleOp module,
32803284 bool isWritable = false ) {
32813285 constexpr bool isRestrict{true };
32823286 Value flatTensor = builder.create <bufferization::ToTensorOp>(
3283- loc, block->getArgument (blockArgIndex).getType (),
3287+ loc, memref::getTensorTypeFromMemRefType ( block->getArgument (blockArgIndex).getType () ),
32843288 block->getArgument (blockArgIndex), isRestrict, isWritable);
32853289 ArrayRef<int64_t > origShape =
32863290 cast<ShapedType>(argTypes[blockArgIndex]).getShape ();
@@ -3418,11 +3422,11 @@ createCpuConvElementwiseGemmKernelWithMlir(ModuleOp module,
34183422 }
34193423
34203424 Value output = block->getArguments ().back ();
3421- auto outputType = cast<MemRefType >(output.getType ());
3425+ auto outputType = cast<bufferization::BufferLikeType >(output.getType ());
34223426
34233427 ImplicitLocOpBuilder implicitBuilder (loc, builder);
3424- auto shapeValue =
3425- tosa::getTosaConstShape ( implicitBuilder, outputType.getShape ());
3428+ auto shapeValue = tosa::getTosaConstShape (
3429+ implicitBuilder, cast<ShapedType>( outputType) .getShape ());
34263430 auto flatResultTensor =
34273431 builder.create <tosa::ReshapeOp>(loc, resultTensor, shapeValue);
34283432
@@ -3460,7 +3464,7 @@ createCpuGemmElementwiseGemmKernelWithMlir(ModuleOp module,
34603464 bool isWritable = false ) {
34613465 constexpr bool isRestrict{true };
34623466 Value flatTensor = builder.create <bufferization::ToTensorOp>(
3463- loc, block->getArgument (blockArgIndex).getType (),
3467+ loc, memref::getTensorTypeFromMemRefType ( block->getArgument (blockArgIndex).getType () ),
34643468 block->getArgument (blockArgIndex), isRestrict, isWritable);
34653469 ArrayRef<int64_t > origShape =
34663470 cast<ShapedType>(argTypes[blockArgIndex]).getShape ();
@@ -3534,11 +3538,11 @@ createCpuGemmElementwiseGemmKernelWithMlir(ModuleOp module,
35343538 }
35353539
35363540 Value output = block->getArguments ().back ();
3537- auto outputType = cast<MemRefType >(output.getType ());
3541+ auto outputType = cast<mlir::bufferization::BufferLikeType >(output.getType ());
35383542
35393543 ImplicitLocOpBuilder implicitBuilder (loc, builder);
3540- auto shapeValue =
3541- tosa::getTosaConstShape ( implicitBuilder, outputType.getShape ());
3544+ auto shapeValue = tosa::getTosaConstShape (
3545+ implicitBuilder, cast<ShapedType>( outputType) .getShape ());
35423546 auto flatResultTensor =
35433547 builder.create <tosa::ReshapeOp>(loc, resultTensor, shapeValue);
35443548
@@ -3576,7 +3580,7 @@ static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module,
35763580 bool isWritable = false ) {
35773581 constexpr bool isRestrict{true };
35783582 Value flatTensor = builder.create <bufferization::ToTensorOp>(
3579- loc, block->getArgument (blockArgIndex).getType (),
3583+ loc, memref::getTensorTypeFromMemRefType ( block->getArgument (blockArgIndex).getType () ),
35803584 block->getArgument (blockArgIndex), isRestrict, isWritable);
35813585 ArrayRef<int64_t > origShape =
35823586 cast<ShapedType>(argTypes[blockArgIndex]).getShape ();
@@ -3792,10 +3796,10 @@ static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module,
37923796 }
37933797
37943798 Value output = block->getArguments ().back ();
3795- auto outputType = cast<MemRefType >(output.getType ());
3799+ auto outputType = cast<mlir::bufferization::BufferLikeType >(output.getType ());
37963800 ImplicitLocOpBuilder implicitBuilder (loc, builder);
3797- auto shapeValue =
3798- tosa::getTosaConstShape ( implicitBuilder, outputType.getShape ());
3801+ auto shapeValue = tosa::getTosaConstShape (
3802+ implicitBuilder, cast<ShapedType>( outputType) .getShape ());
37993803 auto flatResultTensor =
38003804 builder.create <tosa::ReshapeOp>(loc, resultTensor, shapeValue);
38013805
@@ -3806,9 +3810,9 @@ static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module,
38063810
38073811 // return LSE (log-sum-exp)
38083812 if (returnLSE) {
3809- auto lseOutType = cast<MemRefType >(lseOut.getType ());
3810- auto lseShapeValue =
3811- tosa::getTosaConstShape ( implicitBuilder, lseOutType.getShape ());
3813+ auto lseOutType = cast<bufferization::BufferLikeType >(lseOut.getType ());
3814+ auto lseShapeValue = tosa::getTosaConstShape (
3815+ implicitBuilder, cast<ShapedType>( lseOutType) .getShape ());
38123816 auto flatLseTensor =
38133817 builder.create <tosa::ReshapeOp>(loc, lseTensor, lseShapeValue);
38143818
0 commit comments