@@ -2415,32 +2415,62 @@ DiagnosedSilenceableFailure
24152415transform::SplitHandleOp::apply (transform::TransformRewriter &rewriter,
24162416 transform::TransformResults &results,
24172417 transform::TransformState &state) {
2418- int64_t numPayloadOps = llvm::range_size (state.getPayloadOps (getHandle ()));
2418+ int64_t numPayloads =
2419+ llvm::TypeSwitch<Type, int64_t >(getHandle ().getType ())
2420+ .Case <TransformHandleTypeInterface>([&](auto x) {
2421+ return llvm::range_size (state.getPayloadOps (getHandle ()));
2422+ })
2423+ .Case <TransformValueHandleTypeInterface>([&](auto x) {
2424+ return llvm::range_size (state.getPayloadValues (getHandle ()));
2425+ })
2426+ .Case <TransformParamTypeInterface>([&](auto x) {
2427+ return llvm::range_size (state.getParams (getHandle ()));
2428+ })
2429+ .Default ([](auto x) {
2430+ llvm_unreachable (" unknown transform dialect type interface" );
2431+ return -1 ;
2432+ });
2433+
24192434 auto produceNumOpsError = [&]() {
24202435 return emitSilenceableError ()
24212436 << getHandle () << " expected to contain " << this ->getNumResults ()
2422- << " payload ops but it contains " << numPayloadOps
2423- << " payload ops" ;
2437+ << " payloads but it contains " << numPayloads << " payloads" ;
24242438 };
24252439
24262440 // Fail if there are more payload ops than results and no overflow result was
24272441 // specified.
2428- if (numPayloadOps > getNumResults () && !getOverflowResult ().has_value ())
2442+ if (numPayloads > getNumResults () && !getOverflowResult ().has_value ())
24292443 return produceNumOpsError ();
24302444
24312445 // Fail if there are more results than payload ops. Unless:
24322446 // - "fail_on_payload_too_small" is set to "false", or
24332447 // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
2434- if (numPayloadOps < getNumResults () && getFailOnPayloadTooSmall () &&
2435- (numPayloadOps != 0 || !getPassThroughEmptyHandle ()))
2448+ if (numPayloads < getNumResults () && getFailOnPayloadTooSmall () &&
2449+ (numPayloads != 0 || !getPassThroughEmptyHandle ()))
24362450 return produceNumOpsError ();
24372451
2438- // Distribute payload ops .
2439- SmallVector<SmallVector<Operation * , 1 >> resultHandles (getNumResults (), {});
2452+ // Distribute payloads .
2453+ SmallVector<SmallVector<MappedValue , 1 >> resultHandles (getNumResults (), {});
24402454 if (getOverflowResult ())
2441- resultHandles[*getOverflowResult ()].reserve (numPayloadOps -
2442- getNumResults ());
2443- for (auto &&en : llvm::enumerate (state.getPayloadOps (getHandle ()))) {
2455+ resultHandles[*getOverflowResult ()].reserve (numPayloads - getNumResults ());
2456+
2457+ auto container = [&]() {
2458+ if (isa<TransformHandleTypeInterface>(getHandle ().getType ())) {
2459+ return llvm::map_to_vector (
2460+ state.getPayloadOps (getHandle ()),
2461+ [](Operation *op) -> MappedValue { return op; });
2462+ }
2463+ if (isa<TransformValueHandleTypeInterface>(getHandle ().getType ())) {
2464+ return llvm::map_to_vector (state.getPayloadValues (getHandle ()),
2465+ [](Value v) -> MappedValue { return v; });
2466+ }
2467+ assert (isa<TransformParamTypeInterface>(getHandle ().getType ()) &&
2468+ " unsupported kind of transform dialect type" );
2469+ return llvm::map_to_vector (state.getParams (getHandle ()),
2470+ [](Attribute a) -> MappedValue { return a; });
2471+ }();
2472+
2473+ for (auto &&en : llvm::enumerate (container)) {
24442474 int64_t resultNum = en.index ();
24452475 if (resultNum >= getNumResults ())
24462476 resultNum = *getOverflowResult ();
@@ -2449,7 +2479,8 @@ transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
24492479
24502480 // Set transform op results.
24512481 for (auto &&it : llvm::enumerate (resultHandles))
2452- results.set (llvm::cast<OpResult>(getResult (it.index ())), it.value ());
2482+ results.setMappedValues (llvm::cast<OpResult>(getResult (it.index ())),
2483+ it.value ());
24532484
24542485 return DiagnosedSilenceableFailure::success ();
24552486}
@@ -2466,6 +2497,15 @@ LogicalResult transform::SplitHandleOp::verify() {
24662497 if (getOverflowResult ().has_value () &&
24672498 !(*getOverflowResult () < getNumResults ()))
24682499 return emitOpError (" overflow_result is not a valid result index" );
2500+
2501+ for (Type resultType : getResultTypes ()) {
2502+ if (implementSameTransformInterface (getHandle ().getType (), resultType))
2503+ continue ;
2504+
2505+ return emitOpError (" expects result types to implement the same transform "
2506+ " interface as the operand type" );
2507+ }
2508+
24692509 return success ();
24702510}
24712511
0 commit comments