@@ -390,6 +390,8 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
390390 TransformMapKeyTy key = {m, r};
391391 int64_t retRows = 1 ;
392392 Value matmulRetValue = extractFilter;
393+ Value zero = builder.create <arith::ConstantOp>(
394+ loc, rewriter.getZeroAttr (elementType));
393395 if (leftTransform) {
394396 // Get constant transform matrix G.
395397 auto it = GMatrices.find (key);
@@ -399,8 +401,11 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
399401
400402 retRows = GMatrix.rows ;
401403 auto matmulType = RankedTensorType::get ({retRows, filterW}, elementType);
402- auto init = builder.create <tensor::EmptyOp>(loc, matmulType.getShape (),
403- elementType);
404+ auto empty =
405+ builder
406+ .create <tensor::EmptyOp>(loc, matmulType.getShape (), elementType)
407+ .getResult ();
408+ auto init = builder.create <linalg::FillOp>(loc, zero, empty).getResult (0 );
404409
405410 Value G = create2DTransformMatrix (builder, loc, GMatrix, elementType);
406411 // Multiply G x g.
@@ -418,8 +423,11 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
418423
419424 auto matmulType =
420425 RankedTensorType::get ({retRows, GTMatrix.cols }, elementType);
421- auto init = builder.create <tensor::EmptyOp>(loc, matmulType.getShape (),
422- elementType);
426+ auto empty =
427+ builder
428+ .create <tensor::EmptyOp>(loc, matmulType.getShape (), elementType)
429+ .getResult ();
430+ auto init = builder.create <linalg::FillOp>(loc, zero, empty).getResult (0 );
423431
424432 Value GT = create2DTransformMatrix (builder, loc, GTMatrix, elementType);
425433 // Multiply u = (G x g) x GT.
@@ -523,6 +531,8 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
523531 int64_t retRows = 1 ;
524532 int64_t retCols = 1 ;
525533 Value matmulRetValue = extractInput;
534+ Value zero = builder.create <arith::ConstantOp>(
535+ loc, rewriter.getZeroAttr (elementType));
526536 if (leftTransform) {
527537 // Get constant transform matrix BT.
528538 auto it = BTMatrices.find (key);
@@ -532,8 +542,11 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
532542
533543 retRows = BTMatrix.rows ;
534544 auto matmulType = RankedTensorType::get ({retRows, alphaW}, elementType);
535- auto init = builder.create <tensor::EmptyOp>(loc, matmulType.getShape (),
536- elementType);
545+ auto empty =
546+ builder
547+ .create <tensor::EmptyOp>(loc, matmulType.getShape (), elementType)
548+ .getResult ();
549+ auto init = builder.create <linalg::FillOp>(loc, zero, empty).getResult (0 );
537550
538551 Value BT =
539552 create2DTransformMatrix (builder, loc, BTMatrix, builder.getF32Type ());
@@ -552,8 +565,11 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
552565
553566 retCols = BMatrix.cols ;
554567 auto matmulType = RankedTensorType::get ({retRows, retCols}, elementType);
555- auto init = builder.create <tensor::EmptyOp>(loc, matmulType.getShape (),
556- elementType);
568+ auto empty =
569+ builder
570+ .create <tensor::EmptyOp>(loc, matmulType.getShape (), elementType)
571+ .getResult ();
572+ auto init = builder.create <linalg::FillOp>(loc, zero, empty).getResult (0 );
557573 Value B =
558574 create2DTransformMatrix (builder, loc, BMatrix, builder.getF32Type ());
559575 // Multiply v = (BT x d) x B.
@@ -636,8 +652,13 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc,
636652 {inputShape[0 ] * inputShape[1 ],
637653 inputShape[2 ] * inputShape[3 ] * inputShape[4 ], filterShape[3 ]},
638654 outputElementType);
639- Value init = rewriter.create <tensor::EmptyOp>(loc, matmulType.getShape (),
640- outputElementType);
655+ Value empty = rewriter
656+ .create <tensor::EmptyOp>(loc, matmulType.getShape (),
657+ outputElementType)
658+ .getResult ();
659+ Value zero = rewriter.create <arith::ConstantOp>(
660+ loc, rewriter.getZeroAttr (outputElementType));
661+ Value init = rewriter.create <linalg::FillOp>(loc, zero, empty).getResult (0 );
641662
642663 auto matmulOp = rewriter.create <linalg::BatchMatmulOp>(
643664 loc, matmulType, ValueRange ({collapseInput, collapseFilter}),
@@ -725,6 +746,8 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
725746 int64_t leftScalarFactor = 1 ;
726747 int64_t rightScalarFactor = 1 ;
727748 Value matmulRetValue = extractValue;
749+ Value zero = builder.create <arith::ConstantOp>(
750+ loc, rewriter.getZeroAttr (elementType));
728751 if (leftTransform) {
729752 // Get constant transform matrix AT.
730753 auto it = ATMatrices.find (key);
@@ -735,8 +758,11 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
735758 leftScalarFactor = ATMatrix.scalarFactor ;
736759 retRows = ATMatrix.rows ;
737760 auto matmulType = RankedTensorType::get ({retRows, valueW}, elementType);
738- auto init = builder.create <tensor::EmptyOp>(loc, matmulType.getShape (),
739- elementType);
761+ auto empty =
762+ builder
763+ .create <tensor::EmptyOp>(loc, matmulType.getShape (), elementType)
764+ .getResult ();
765+ auto init = builder.create <linalg::FillOp>(loc, zero, empty).getResult (0 );
740766
741767 Value AT = create2DTransformMatrix (builder, loc, ATMatrix, elementType);
742768 // Multiply AT x m.
@@ -756,8 +782,11 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
756782 auto matmulType =
757783 RankedTensorType::get ({retRows, AMatrix.cols }, elementType);
758784 retCols = AMatrix.cols ;
759- auto init = builder.create <tensor::EmptyOp>(loc, matmulType.getShape (),
760- elementType);
785+ auto empty =
786+ builder
787+ .create <tensor::EmptyOp>(loc, matmulType.getShape (), elementType)
788+ .getResult ();
789+ auto init = builder.create <linalg::FillOp>(loc, zero, empty).getResult (0 );
761790
762791 Value A = create2DTransformMatrix (builder, loc, AMatrix, elementType);
763792 // Multiply y = (AT x m) x A.
0 commit comments