Skip to content

Commit 326287f

Browse files
RoboTux and Max191 authored
Add missing FillOp to winograd lowering (#108181)
Winograd lowering involves a number of matmul and batch_matmul ops which are currently passed a tensor.empty result as their out parameter, and therefore have undefined behaviour. This commit adds the necessary linalg.fill. --------- Co-authored-by: Max191 <[email protected]>
1 parent 4a9b6b0 commit 326287f

File tree

4 files changed

+213
-110
lines changed

4 files changed

+213
-110
lines changed

mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,8 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
390390
TransformMapKeyTy key = {m, r};
391391
int64_t retRows = 1;
392392
Value matmulRetValue = extractFilter;
393+
Value zero = builder.create<arith::ConstantOp>(
394+
loc, rewriter.getZeroAttr(elementType));
393395
if (leftTransform) {
394396
// Get constant transform matrix G.
395397
auto it = GMatrices.find(key);
@@ -399,8 +401,11 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
399401

400402
retRows = GMatrix.rows;
401403
auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
402-
auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
403-
elementType);
404+
auto empty =
405+
builder
406+
.create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
407+
.getResult();
408+
auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
404409

405410
Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
406411
// Multiply G x g.
@@ -418,8 +423,11 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
418423

419424
auto matmulType =
420425
RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
421-
auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
422-
elementType);
426+
auto empty =
427+
builder
428+
.create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
429+
.getResult();
430+
auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
423431

424432
Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
425433
// Multiply u = (G x g) x GT.
@@ -523,6 +531,8 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
523531
int64_t retRows = 1;
524532
int64_t retCols = 1;
525533
Value matmulRetValue = extractInput;
534+
Value zero = builder.create<arith::ConstantOp>(
535+
loc, rewriter.getZeroAttr(elementType));
526536
if (leftTransform) {
527537
// Get constant transform matrix BT.
528538
auto it = BTMatrices.find(key);
@@ -532,8 +542,11 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
532542

533543
retRows = BTMatrix.rows;
534544
auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType);
535-
auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
536-
elementType);
545+
auto empty =
546+
builder
547+
.create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
548+
.getResult();
549+
auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
537550

538551
Value BT =
539552
create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
@@ -552,8 +565,11 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
552565

553566
retCols = BMatrix.cols;
554567
auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
555-
auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
556-
elementType);
568+
auto empty =
569+
builder
570+
.create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
571+
.getResult();
572+
auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
557573
Value B =
558574
create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
559575
// Multiply v = (BT x d) x B.
@@ -636,8 +652,13 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc,
636652
{inputShape[0] * inputShape[1],
637653
inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
638654
outputElementType);
639-
Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
640-
outputElementType);
655+
Value empty = rewriter
656+
.create<tensor::EmptyOp>(loc, matmulType.getShape(),
657+
outputElementType)
658+
.getResult();
659+
Value zero = rewriter.create<arith::ConstantOp>(
660+
loc, rewriter.getZeroAttr(outputElementType));
661+
Value init = rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0);
641662

642663
auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
643664
loc, matmulType, ValueRange({collapseInput, collapseFilter}),
@@ -725,6 +746,8 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
725746
int64_t leftScalarFactor = 1;
726747
int64_t rightScalarFactor = 1;
727748
Value matmulRetValue = extractValue;
749+
Value zero = builder.create<arith::ConstantOp>(
750+
loc, rewriter.getZeroAttr(elementType));
728751
if (leftTransform) {
729752
// Get constant transform matrix AT.
730753
auto it = ATMatrices.find(key);
@@ -735,8 +758,11 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
735758
leftScalarFactor = ATMatrix.scalarFactor;
736759
retRows = ATMatrix.rows;
737760
auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
738-
auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
739-
elementType);
761+
auto empty =
762+
builder
763+
.create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
764+
.getResult();
765+
auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
740766

741767
Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
742768
// Multiply AT x m.
@@ -756,8 +782,11 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
756782
auto matmulType =
757783
RankedTensorType::get({retRows, AMatrix.cols}, elementType);
758784
retCols = AMatrix.cols;
759-
auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
760-
elementType);
785+
auto empty =
786+
builder
787+
.create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
788+
.getResult();
789+
auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
761790

762791
Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
763792
// Multiply y = (AT x m) x A.

0 commit comments

Comments
 (0)