@@ -390,6 +390,8 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
390
390
TransformMapKeyTy key = {m, r};
391
391
int64_t retRows = 1 ;
392
392
Value matmulRetValue = extractFilter;
393
+ Value zero = builder.create <arith::ConstantOp>(
394
+ loc, rewriter.getZeroAttr (elementType));
393
395
if (leftTransform) {
394
396
// Get constant transform matrix G.
395
397
auto it = GMatrices.find (key);
@@ -399,8 +401,11 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
399
401
400
402
retRows = GMatrix.rows ;
401
403
auto matmulType = RankedTensorType::get ({retRows, filterW}, elementType);
402
- auto init = builder.create <tensor::EmptyOp>(loc, matmulType.getShape (),
403
- elementType);
404
+ auto empty =
405
+ builder
406
+ .create <tensor::EmptyOp>(loc, matmulType.getShape (), elementType)
407
+ .getResult ();
408
+ auto init = builder.create <linalg::FillOp>(loc, zero, empty).getResult (0 );
404
409
405
410
Value G = create2DTransformMatrix (builder, loc, GMatrix, elementType);
406
411
// Multiply G x g.
@@ -418,8 +423,11 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
418
423
419
424
auto matmulType =
420
425
RankedTensorType::get ({retRows, GTMatrix.cols }, elementType);
421
- auto init = builder.create <tensor::EmptyOp>(loc, matmulType.getShape (),
422
- elementType);
426
+ auto empty =
427
+ builder
428
+ .create <tensor::EmptyOp>(loc, matmulType.getShape (), elementType)
429
+ .getResult ();
430
+ auto init = builder.create <linalg::FillOp>(loc, zero, empty).getResult (0 );
423
431
424
432
Value GT = create2DTransformMatrix (builder, loc, GTMatrix, elementType);
425
433
// Multiply u = (G x g) x GT.
@@ -523,6 +531,8 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
523
531
int64_t retRows = 1 ;
524
532
int64_t retCols = 1 ;
525
533
Value matmulRetValue = extractInput;
534
+ Value zero = builder.create <arith::ConstantOp>(
535
+ loc, rewriter.getZeroAttr (elementType));
526
536
if (leftTransform) {
527
537
// Get constant transform matrix BT.
528
538
auto it = BTMatrices.find (key);
@@ -532,8 +542,11 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
532
542
533
543
retRows = BTMatrix.rows ;
534
544
auto matmulType = RankedTensorType::get ({retRows, alphaW}, elementType);
535
- auto init = builder.create <tensor::EmptyOp>(loc, matmulType.getShape (),
536
- elementType);
545
+ auto empty =
546
+ builder
547
+ .create <tensor::EmptyOp>(loc, matmulType.getShape (), elementType)
548
+ .getResult ();
549
+ auto init = builder.create <linalg::FillOp>(loc, zero, empty).getResult (0 );
537
550
538
551
Value BT =
539
552
create2DTransformMatrix (builder, loc, BTMatrix, builder.getF32Type ());
@@ -552,8 +565,11 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
552
565
553
566
retCols = BMatrix.cols ;
554
567
auto matmulType = RankedTensorType::get ({retRows, retCols}, elementType);
555
- auto init = builder.create <tensor::EmptyOp>(loc, matmulType.getShape (),
556
- elementType);
568
+ auto empty =
569
+ builder
570
+ .create <tensor::EmptyOp>(loc, matmulType.getShape (), elementType)
571
+ .getResult ();
572
+ auto init = builder.create <linalg::FillOp>(loc, zero, empty).getResult (0 );
557
573
Value B =
558
574
create2DTransformMatrix (builder, loc, BMatrix, builder.getF32Type ());
559
575
// Multiply v = (BT x d) x B.
@@ -636,8 +652,13 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc,
636
652
{inputShape[0 ] * inputShape[1 ],
637
653
inputShape[2 ] * inputShape[3 ] * inputShape[4 ], filterShape[3 ]},
638
654
outputElementType);
639
- Value init = rewriter.create <tensor::EmptyOp>(loc, matmulType.getShape (),
640
- outputElementType);
655
+ Value empty = rewriter
656
+ .create <tensor::EmptyOp>(loc, matmulType.getShape (),
657
+ outputElementType)
658
+ .getResult ();
659
+ Value zero = rewriter.create <arith::ConstantOp>(
660
+ loc, rewriter.getZeroAttr (outputElementType));
661
+ Value init = rewriter.create <linalg::FillOp>(loc, zero, empty).getResult (0 );
641
662
642
663
auto matmulOp = rewriter.create <linalg::BatchMatmulOp>(
643
664
loc, matmulType, ValueRange ({collapseInput, collapseFilter}),
@@ -725,6 +746,8 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
725
746
int64_t leftScalarFactor = 1 ;
726
747
int64_t rightScalarFactor = 1 ;
727
748
Value matmulRetValue = extractValue;
749
+ Value zero = builder.create <arith::ConstantOp>(
750
+ loc, rewriter.getZeroAttr (elementType));
728
751
if (leftTransform) {
729
752
// Get constant transform matrix AT.
730
753
auto it = ATMatrices.find (key);
@@ -735,8 +758,11 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
735
758
leftScalarFactor = ATMatrix.scalarFactor ;
736
759
retRows = ATMatrix.rows ;
737
760
auto matmulType = RankedTensorType::get ({retRows, valueW}, elementType);
738
- auto init = builder.create <tensor::EmptyOp>(loc, matmulType.getShape (),
739
- elementType);
761
+ auto empty =
762
+ builder
763
+ .create <tensor::EmptyOp>(loc, matmulType.getShape (), elementType)
764
+ .getResult ();
765
+ auto init = builder.create <linalg::FillOp>(loc, zero, empty).getResult (0 );
740
766
741
767
Value AT = create2DTransformMatrix (builder, loc, ATMatrix, elementType);
742
768
// Multiply AT x m.
@@ -756,8 +782,11 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
756
782
auto matmulType =
757
783
RankedTensorType::get ({retRows, AMatrix.cols }, elementType);
758
784
retCols = AMatrix.cols ;
759
- auto init = builder.create <tensor::EmptyOp>(loc, matmulType.getShape (),
760
- elementType);
785
+ auto empty =
786
+ builder
787
+ .create <tensor::EmptyOp>(loc, matmulType.getShape (), elementType)
788
+ .getResult ();
789
+ auto init = builder.create <linalg::FillOp>(loc, zero, empty).getResult (0 );
761
790
762
791
Value A = create2DTransformMatrix (builder, loc, AMatrix, elementType);
763
792
// Multiply y = (AT x m) x A.
0 commit comments