@@ -1907,7 +1907,8 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
19071907 ArrayRef<int64_t > padToMultipleOf,
19081908 ArrayRef<int64_t > nofoldFlags,
19091909 ArrayRef<Attribute> transposePaddings,
1910- StringRef copyBackOp) {
1910+ StringRef copyBackOp,
1911+ bool usePrescribedTensorShapes) {
19111912 auto resultType = transform::AnyOpType::get (b.getContext ());
19121913 return build (/* builder=*/ b,
19131914 /* result=*/ result,
@@ -1922,15 +1923,18 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
19221923 : b.getDenseI64ArrayAttr (padToMultipleOf)),
19231924 /* nofoldFlags=*/ b.getI64ArrayAttr (nofoldFlags),
19241925 /* transposePaddings=*/ b.getArrayAttr (transposePaddings),
1925- /* copyBackOp=*/ b.getStringAttr (copyBackOp));
1926+ /* copyBackOp=*/ b.getStringAttr (copyBackOp),
1927+ /* usePrescribedTensorShapes=*/
1928+ usePrescribedTensorShapes ? b.getUnitAttr () : nullptr );
19261929}
19271930
19281931void transform::PadOp::build (OpBuilder &b, OperationState &result, Value target,
19291932 ArrayRef<int64_t > paddingDimensions,
19301933 ArrayRef<OpFoldResult> mixedPadToMultipleOf,
19311934 ArrayRef<int64_t > nofoldFlags,
19321935 ArrayRef<Attribute> transposePaddings,
1933- StringRef copyBackOp) {
1936+ StringRef copyBackOp,
1937+ bool usePrescribedTensorShapes) {
19341938 auto resultType = transform::AnyOpType::get (b.getContext ());
19351939 SmallVector<int64_t > staticPadToMultipleOf;
19361940 SmallVector<Value> dynamicPadToMultipleOf;
@@ -1946,7 +1950,8 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
19461950 /* padToMultipleOf=*/ staticPadToMultipleOf,
19471951 /* nofoldFlags=*/ b.getI64ArrayAttr (nofoldFlags),
19481952 /* transposePaddings=*/ b.getArrayAttr (transposePaddings),
1949- /* copyBackOp=*/ b.getStringAttr (copyBackOp));
1953+ /* copyBackOp=*/ copyBackOp,
1954+ /* usePrescribedTensorShapes=*/ usePrescribedTensorShapes);
19501955}
19511956
19521957void PadOp::getEffects (
@@ -2051,11 +2056,32 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
20512056 } else {
20522057 llvm_unreachable (" unsupported copy_back op" );
20532058 }
2059+ // Populate `sizeToPadTo` with the dynamic tensor sizes for each operand.
2060+ bool irChanged = false ;
2061+ if (getUsePrescribedTensorShapes () &&
2062+ linalgTarget.hasPureTensorSemantics ()) {
2063+ for (OpOperand &operand : linalgTarget->getOpOperands ()) {
2064+ for (auto [i, dim] : llvm::enumerate (linalgTarget.getShape (&operand))) {
2065+ if (ShapedType::isDynamic (dim))
2066+ continue ;
2067+ options.setSizeToPadTo (operand.getOperandNumber (), i,
2068+ tensor::getMixedSize (rewriter,
2069+ operand.get ().getLoc (),
2070+ operand.get (), i));
2071+ irChanged = true ;
2072+ }
2073+ }
2074+ }
20542075
20552076 SmallVector<Value> replacements;
20562077 SmallVector<tensor::PadOp> newPadOps;
20572078 if (failed (rewriteAsPaddedOp (rewriter, linalgTarget, options, paddedOp,
20582079 replacements, newPadOps))) {
2080+ if (irChanged) {
2081+ auto diag = emitDefiniteFailure () << " failed to pad op" ;
2082+ diag.attachNote (target->getLoc ()) << " target op" ;
2083+ return diag;
2084+ }
20592085 auto diag = emitSilenceableError () << " failed to pad op" ;
20602086 diag.attachNote (target->getLoc ()) << " target op" ;
20612087 return diag;
0 commit comments