@@ -390,6 +390,8 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
390390    TransformMapKeyTy key = {m, r};
391391    int64_t  retRows = 1 ;
392392    Value matmulRetValue = extractFilter;
393+     Value zero = builder.create <arith::ConstantOp>(
394+         loc, rewriter.getZeroAttr (elementType));
393395    if  (leftTransform) {
394396      //  Get constant transform matrix G.
395397      auto  it = GMatrices.find (key);
@@ -399,8 +401,11 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
399401
400402      retRows = GMatrix.rows ;
401403      auto  matmulType = RankedTensorType::get ({retRows, filterW}, elementType);
402-       auto  init = builder.create <tensor::EmptyOp>(loc, matmulType.getShape (),
403-                                                   elementType);
404+       auto  empty =
405+           builder
406+               .create <tensor::EmptyOp>(loc, matmulType.getShape (), elementType)
407+               .getResult ();
408+       auto  init = builder.create <linalg::FillOp>(loc, zero, empty).getResult (0 );
404409
405410      Value G = create2DTransformMatrix (builder, loc, GMatrix, elementType);
406411      //  Multiply G x g.
@@ -418,8 +423,11 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
418423
419424      auto  matmulType =
420425          RankedTensorType::get ({retRows, GTMatrix.cols }, elementType);
421-       auto  init = builder.create <tensor::EmptyOp>(loc, matmulType.getShape (),
422-                                                   elementType);
426+       auto  empty =
427+           builder
428+               .create <tensor::EmptyOp>(loc, matmulType.getShape (), elementType)
429+               .getResult ();
430+       auto  init = builder.create <linalg::FillOp>(loc, zero, empty).getResult (0 );
423431
424432      Value GT = create2DTransformMatrix (builder, loc, GTMatrix, elementType);
425433      //  Multiply u = (G x g) x GT.
@@ -523,6 +531,8 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
523531    int64_t  retRows = 1 ;
524532    int64_t  retCols = 1 ;
525533    Value matmulRetValue = extractInput;
534+     Value zero = builder.create <arith::ConstantOp>(
535+         loc, rewriter.getZeroAttr (elementType));
526536    if  (leftTransform) {
527537      //  Get constant transform matrix BT.
528538      auto  it = BTMatrices.find (key);
@@ -532,8 +542,11 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
532542
533543      retRows = BTMatrix.rows ;
534544      auto  matmulType = RankedTensorType::get ({retRows, alphaW}, elementType);
535-       auto  init = builder.create <tensor::EmptyOp>(loc, matmulType.getShape (),
536-                                                   elementType);
545+       auto  empty =
546+           builder
547+               .create <tensor::EmptyOp>(loc, matmulType.getShape (), elementType)
548+               .getResult ();
549+       auto  init = builder.create <linalg::FillOp>(loc, zero, empty).getResult (0 );
537550
538551      Value BT =
539552          create2DTransformMatrix (builder, loc, BTMatrix, builder.getF32Type ());
@@ -552,8 +565,11 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
552565
553566      retCols = BMatrix.cols ;
554567      auto  matmulType = RankedTensorType::get ({retRows, retCols}, elementType);
555-       auto  init = builder.create <tensor::EmptyOp>(loc, matmulType.getShape (),
556-                                                   elementType);
568+       auto  empty =
569+           builder
570+               .create <tensor::EmptyOp>(loc, matmulType.getShape (), elementType)
571+               .getResult ();
572+       auto  init = builder.create <linalg::FillOp>(loc, zero, empty).getResult (0 );
557573      Value B =
558574          create2DTransformMatrix (builder, loc, BMatrix, builder.getF32Type ());
559575      //  Multiply v = (BT x d) x B.
@@ -636,8 +652,13 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc,
636652      {inputShape[0 ] * inputShape[1 ],
637653       inputShape[2 ] * inputShape[3 ] * inputShape[4 ], filterShape[3 ]},
638654      outputElementType);
639-   Value init = rewriter.create <tensor::EmptyOp>(loc, matmulType.getShape (),
640-                                                 outputElementType);
655+   Value empty = rewriter
656+                     .create <tensor::EmptyOp>(loc, matmulType.getShape (),
657+                                              outputElementType)
658+                     .getResult ();
659+   Value zero = rewriter.create <arith::ConstantOp>(
660+       loc, rewriter.getZeroAttr (outputElementType));
661+   Value init = rewriter.create <linalg::FillOp>(loc, zero, empty).getResult (0 );
641662
642663  auto  matmulOp = rewriter.create <linalg::BatchMatmulOp>(
643664      loc, matmulType, ValueRange ({collapseInput, collapseFilter}),
@@ -725,6 +746,8 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
725746    int64_t  leftScalarFactor = 1 ;
726747    int64_t  rightScalarFactor = 1 ;
727748    Value matmulRetValue = extractValue;
749+     Value zero = builder.create <arith::ConstantOp>(
750+         loc, rewriter.getZeroAttr (elementType));
728751    if  (leftTransform) {
729752      //  Get constant transform matrix AT.
730753      auto  it = ATMatrices.find (key);
@@ -735,8 +758,11 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
735758      leftScalarFactor = ATMatrix.scalarFactor ;
736759      retRows = ATMatrix.rows ;
737760      auto  matmulType = RankedTensorType::get ({retRows, valueW}, elementType);
738-       auto  init = builder.create <tensor::EmptyOp>(loc, matmulType.getShape (),
739-                                                   elementType);
761+       auto  empty =
762+           builder
763+               .create <tensor::EmptyOp>(loc, matmulType.getShape (), elementType)
764+               .getResult ();
765+       auto  init = builder.create <linalg::FillOp>(loc, zero, empty).getResult (0 );
740766
741767      Value AT = create2DTransformMatrix (builder, loc, ATMatrix, elementType);
742768      //  Multiply AT x m.
@@ -756,8 +782,11 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
756782      auto  matmulType =
757783          RankedTensorType::get ({retRows, AMatrix.cols }, elementType);
758784      retCols = AMatrix.cols ;
759-       auto  init = builder.create <tensor::EmptyOp>(loc, matmulType.getShape (),
760-                                                   elementType);
785+       auto  empty =
786+           builder
787+               .create <tensor::EmptyOp>(loc, matmulType.getShape (), elementType)
788+               .getResult ();
789+       auto  init = builder.create <linalg::FillOp>(loc, zero, empty).getResult (0 );
761790
762791      Value A = create2DTransformMatrix (builder, loc, AMatrix, elementType);
763792      //  Multiply y = (AT x m) x A.
0 commit comments