@@ -171,6 +171,54 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
171171 return DiagnosedSilenceableFailure::success ();
172172}
173173
174+ // / When possible, converts each `OpFoldResult` in `mixedResult` to
175+ // / an integer if the value can be statically inferred. If a result
176+ // / is a `Value` then it must be either a `ParamType` or a handle
177+ // / to an a constant like op.
178+ static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults (
179+ TransformState &state, TransformOpInterface &transformOp,
180+ ArrayRef<OpFoldResult> mixedResults, SmallVectorImpl<int64_t > &reified) {
181+ for (OpFoldResult paramOrHandle : mixedResults) {
182+ if (isa<Attribute>(paramOrHandle)) {
183+ reified.push_back (
184+ cast<IntegerAttr>(paramOrHandle.get <Attribute>()).getInt ());
185+ continue ;
186+ } else if (isa<ParamType>(paramOrHandle.get <Value>().getType ())) {
187+ ArrayRef<Attribute> params = state.getParams (paramOrHandle.get <Value>());
188+ if (params.size () != 1 )
189+ return transformOp.emitSilenceableError () << " expected a single param" ;
190+ reified.push_back (
191+ cast<IntegerAttr>(params.front ()).getValue ().getSExtValue ());
192+ continue ;
193+ }
194+
195+ Value handle = paramOrHandle.get <Value>();
196+ if (!isa<TransformHandleTypeInterface>(handle.getType ()))
197+ return transformOp.emitSilenceableError () << " unexpected value handle" ;
198+ auto payload = state.getPayloadOps (handle);
199+ if (!llvm::hasSingleElement (payload))
200+ return transformOp.emitSilenceableError ()
201+ << " requires param or handle that is mapped to 1 payload op" ;
202+
203+ Operation *paramOrHandlePayloadOp = *payload.begin ();
204+ if (paramOrHandlePayloadOp->getNumResults () != 1 ||
205+ !paramOrHandlePayloadOp->getResult (0 ).getType ().isIndex ()) {
206+ return transformOp.emitSilenceableError ()
207+ << " requires param or handle to be result of op with 1 index "
208+ " result" ;
209+ }
210+
211+ IntegerAttr attr;
212+ if (!matchPattern (paramOrHandlePayloadOp->getResult (0 ), m_Constant (&attr)))
213+ return transformOp.emitSilenceableError ()
214+ << " requires param or handle to be the result of a constant like "
215+ " op" ;
216+
217+ reified.push_back (attr.getInt ());
218+ }
219+ return DiagnosedSilenceableFailure::success ();
220+ }
221+
174222// ===----------------------------------------------------------------------===//
175223// Apply...PatternsOp
176224// ===----------------------------------------------------------------------===//
@@ -1664,6 +1712,8 @@ transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
16641712// PadOp
16651713// ===---------------------------------------------------------------------===//
16661714
1715+ static const StringLiteral kPadToMultipleOfKeyword = " pad_to_multiple_of" ;
1716+
16671717void transform::PadOp::build (OpBuilder &b, OperationState &result, Value target,
16681718 ArrayRef<int64_t > paddingDimensions,
16691719 ArrayRef<int64_t > padToMultipleOf,
@@ -1677,18 +1727,60 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
16771727 /* target=*/ target,
16781728 /* paddingValues=*/ ArrayAttr (), // let inference handle this
16791729 /* paddingDimensions=*/ b.getI64ArrayAttr (paddingDimensions),
1730+ /* padToMultipleOf=*/ ValueRange{},
16801731 /* padToMultipleOf=*/
1681- (padToMultipleOf.empty () ? ArrayAttr ()
1682- : b.getI64ArrayAttr (padToMultipleOf)),
1732+ (padToMultipleOf.empty ()
1733+ ? DenseI64ArrayAttr ()
1734+ : b.getDenseI64ArrayAttr (padToMultipleOf)),
1735+ /* packPaddings=*/ b.getI64ArrayAttr (packPaddings),
1736+ /* transposePaddings=*/ b.getArrayAttr (transposePaddings),
1737+ /* copyBackOp=*/ b.getStringAttr (copyBackOp));
1738+ }
1739+
1740+ void transform::PadOp::build (OpBuilder &b, OperationState &result, Value target,
1741+ ArrayRef<int64_t > paddingDimensions,
1742+ ArrayRef<OpFoldResult> mixedPadToMultipleOf,
1743+ ArrayRef<int64_t > packPaddings,
1744+ ArrayRef<Attribute> transposePaddings,
1745+ StringRef copyBackOp) {
1746+ auto resultType = transform::AnyOpType::get (b.getContext ());
1747+ SmallVector<int64_t > staticPadToMultipleOf;
1748+ SmallVector<Value> dynamicPadToMultipleOf;
1749+ dispatchIndexOpFoldResults (mixedPadToMultipleOf, dynamicPadToMultipleOf,
1750+ staticPadToMultipleOf);
1751+ return build (/* builder=*/ b,
1752+ /* result=*/ result,
1753+ /* types=*/ TypeRange{resultType, resultType},
1754+ /* target=*/ target,
1755+ /* paddingValues=*/ ArrayAttr (), // let inference handle this
1756+ /* paddingDimensions=*/ b.getI64ArrayAttr (paddingDimensions),
1757+ /* padToMultipleOf=*/ dynamicPadToMultipleOf,
1758+ /* padToMultipleOf=*/ staticPadToMultipleOf,
16831759 /* packPaddings=*/ b.getI64ArrayAttr (packPaddings),
16841760 /* transposePaddings=*/ b.getArrayAttr (transposePaddings),
16851761 /* copyBackOp=*/ b.getStringAttr (copyBackOp));
16861762}
16871763
1764+ void PadOp::getEffects (
1765+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1766+ consumesHandle (getTarget (), effects);
1767+ onlyReadsHandle (getPadToMultipleOf (), effects);
1768+ producesHandle (getPadded (), effects);
1769+ producesHandle (getPad (), effects);
1770+ producesHandle (getCopy (), effects);
1771+ modifiesPayload (effects);
1772+ }
1773+
1774+ SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf () {
1775+ Builder b (getContext ());
1776+ return getMixedValues (getStaticPadToMultipleOf (), getPadToMultipleOf (), b);
1777+ }
1778+
16881779DiagnosedSilenceableFailure
16891780transform::PadOp::apply (transform::TransformRewriter &rewriter,
16901781 transform::TransformResults &results,
16911782 transform::TransformState &state) {
1783+ auto transformOp = cast<TransformOpInterface>(getOperation ());
16921784 SmallVector<Operation *> paddedOps, padOps, copyBackOps;
16931785
16941786 for (Operation *target : state.getPayloadOps (getTarget ())) {
@@ -1749,10 +1841,16 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
17491841 LinalgPaddingOptions options;
17501842 options.paddingDimensions =
17511843 extractFromIntegerArrayAttr<int64_t >(getPaddingDimensions ());
1752- SmallVector<int64_t > padToMultipleOf (options.paddingDimensions .size (), 1 );
1753- if (getPadToMultipleOf ().has_value ())
1844+
1845+ SmallVector<int64_t > padToMultipleOf;
1846+ DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults (
1847+ state, transformOp, getMixedPadToMultipleOf (), padToMultipleOf);
1848+ if (!status.succeeded ())
1849+ return status;
1850+ if (padToMultipleOf.empty ())
17541851 padToMultipleOf =
1755- extractFromIntegerArrayAttr<int64_t >(*getPadToMultipleOf ());
1852+ SmallVector<int64_t >(options.paddingDimensions .size (), 1 );
1853+
17561854 options.padToMultipleOf = padToMultipleOf;
17571855 options.paddingValues = paddingValues;
17581856 options.packPaddings = packPaddings;
@@ -1819,8 +1917,8 @@ LogicalResult transform::PadOp::verify() {
18191917 " integers, found "
18201918 << getPaddingDimensions ();
18211919 }
1822- if (getPadToMultipleOf ().has_value ()) {
1823- if (getPadToMultipleOf ()-> size () != paddingDimensions.size ()) {
1920+ if (! getMixedPadToMultipleOf ().empty ()) {
1921+ if (getMixedPadToMultipleOf (). size () != paddingDimensions.size ()) {
18241922 return emitOpError () << " expects as many multiples as padding_dimensions" ;
18251923 }
18261924 }
@@ -3204,49 +3302,12 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
32043302 auto targets = state.getPayloadOps (getTarget ());
32053303 if (std::empty (targets))
32063304 return DiagnosedSilenceableFailure::success ();
3207-
3305+ auto transformOp = cast<TransformOpInterface>( getOperation ());
32083306 SmallVector<int64_t > vectorSizes;
3209- for (OpFoldResult sz : getMixedVectorSizes ()) {
3210- if (sz.is <Attribute>()) {
3211- auto attr = sz.get <Attribute>();
3212- vectorSizes.push_back (cast<IntegerAttr>(attr).getInt ());
3213- continue ;
3214- } else if (sz.is <Value>() && isa<ParamType>(sz.get <Value>().getType ())) {
3215- ArrayRef<Attribute> params = state.getParams (sz.get <Value>());
3216- if (params.size () != 1 )
3217- return emitSilenceableFailure (getLoc ()) << " expected a single param" ;
3218- vectorSizes.push_back (
3219- cast<IntegerAttr>(params.front ()).getValue ().getSExtValue ());
3220- continue ;
3221- }
3222-
3223- auto szPayloads = state.getPayloadOps (sz.get <Value>());
3224- if (!llvm::hasSingleElement (szPayloads)) {
3225- auto diag = this ->emitOpError (
3226- " requires vector size handle that is mapped to 1 payload op" );
3227- diag.attachNote (sz.get <Value>().getLoc ())
3228- << " mapped to " << llvm::range_size (szPayloads) << " payload ops" ;
3229- return DiagnosedSilenceableFailure::definiteFailure ();
3230- }
3231-
3232- Operation *szPayloadOp = *szPayloads.begin ();
3233- if (szPayloadOp->getNumResults () != 1 ||
3234- !szPayloadOp->getResult (0 ).getType ().isIndex ()) {
3235- auto diag = this ->emitOpError (
3236- " requires vector size payload op with 1 index result" );
3237- diag.attachNote (szPayloadOp->getLoc ()) << " vector size payload op" ;
3238- return DiagnosedSilenceableFailure::definiteFailure ();
3239- }
3240-
3241- IntegerAttr attr;
3242- if (!matchPattern (szPayloadOp->getResult (0 ), m_Constant (&attr))) {
3243- auto diag = this ->emitOpError (" requires constant vector size" );
3244- diag.attachNote (szPayloadOp->getLoc ()) << " vector size payload op" ;
3245- return DiagnosedSilenceableFailure::definiteFailure ();
3246- }
3247-
3248- vectorSizes.push_back (attr.getInt ());
3249- }
3307+ DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults (
3308+ state, transformOp, getMixedVectorSizes (), vectorSizes);
3309+ if (!status.succeeded ())
3310+ return status;
32503311
32513312 // TODO: Check that the correct number of vectorSizes was provided.
32523313 for (Operation *target : targets) {
0 commit comments