@@ -815,19 +815,18 @@ void mlir::configureImitateUnsupportedTypesLegality(
815815// ===----------------------------------------------------------------------===//
816816
817817namespace {
818+
818819struct GpuImitateUnsupportedTypesPass
819820 : public impl::GpuImitateUnsupportedTypesBase<
820821 GpuImitateUnsupportedTypesPass> {
821822 using Base::Base;
822823
823- void runOnOperation () override {
824- MLIRContext *ctx = &getContext ();
825- Operation *op = getOperation ();
826-
827- SmallVector<Type> sourceTypes;
828- SmallVector<Type> targetTypes;
824+ SmallVector<Type> sourceTypes;
825+ SmallVector<Type> targetTypes;
826+ TypeConverter typeConverter;
829827
830- // Parse source types
828+ LogicalResult initialize (MLIRContext *ctx) override {
829+ // Parse source types.
831830 for (StringRef sourceTypeStr : sourceTypeStrs) {
832831 std::optional<Type> maybeSourceType =
833832 arith::parseIntOrFloatType (ctx, sourceTypeStr);
@@ -836,7 +835,7 @@ struct GpuImitateUnsupportedTypesPass
836835 emitError (UnknownLoc::get (ctx),
837836 " could not map source type '" + sourceTypeStr +
838837 " ' to a known integer or floating-point type." );
839- return signalPassFailure ();
838+ return failure ();
840839 }
841840 sourceTypes.push_back (*maybeSourceType);
842841 }
@@ -847,7 +846,7 @@ struct GpuImitateUnsupportedTypesPass
847846 " nothing" );
848847 }
849848
850- // Parse target types
849+ // Parse target types.
851850 for (StringRef targetTypeStr : targetTypeStrs) {
852851 std::optional<Type> maybeTargetType =
853852 arith::parseIntOrFloatType (ctx, targetTypeStr);
@@ -856,14 +855,14 @@ struct GpuImitateUnsupportedTypesPass
856855 emitError (UnknownLoc::get (ctx),
857856 " could not map target type '" + targetTypeStr +
858857 " ' to a known integer or floating-point type" );
859- return signalPassFailure ();
858+ return failure ();
860859 }
861860 targetTypes.push_back (*maybeTargetType);
862861
863862 if (llvm::is_contained (sourceTypes, *maybeTargetType)) {
864863 emitError (UnknownLoc::get (ctx),
865864 " target type cannot be an unsupported source type" );
866- return signalPassFailure ();
865+ return failure ();
867866 }
868867 }
869868 if (targetTypes.empty ()) {
@@ -872,44 +871,50 @@ struct GpuImitateUnsupportedTypesPass
872871 " no target types specified, type imitation will do nothing" );
873872 }
874873
875- // Set up the type converter
876- TypeConverter typeConverter;
874+ if (sourceTypes.size () != targetTypes.size ()) {
875+ emitError (UnknownLoc::get (ctx),
876+ " source and target types must have the same size" );
877+ return failure ();
878+ }
879+ // Set up the type converter.
877880 populateImitateUnsupportedTypesTypeConverter (typeConverter, sourceTypes,
878881 targetTypes);
882+ return success ();
883+ }
879884
880- // Populate the conversion patterns
885+ void runOnOperation () override {
886+ MLIRContext *ctx = &getContext ();
887+ Operation *op = getOperation ();
888+
889+ // Populate the conversion patterns.
881890 RewritePatternSet patterns (ctx);
882891 DenseMap<StringAttr, FunctionType> convertedFuncTypes;
883892 populateImitateUnsupportedTypesConversionPatterns (
884893 patterns, typeConverter, sourceTypes, targetTypes, convertedFuncTypes);
885894
886- // Set up conversion target and configure the legality of the conversion
895+ // Set up conversion target and configure the legality of the conversion.
887896 ConversionTarget target (*ctx);
888897 configureImitateUnsupportedTypesLegality (target, typeConverter);
889898
890- // Apply the conversion
899+ // Apply the conversion.
891900 if (failed (applyPartialConversion (op, target, std::move (patterns))))
892- signalPassFailure ();
901+ return signalPassFailure ();
893902
894903 // Post-conversion validation: check for any remaining
895- // unrealized_conversion_cast
896- bool hasUnresolvedCast = false ;
904+ // unrealized_conversion_cast.
897905 op->walk ([&](UnrealizedConversionCastOp op) {
898- // Check if the cast is from a source type to a target type
906+ // Check if the cast is from a source type to a target type.
899907 for (auto [sourceType, targetType] :
900908 llvm::zip_equal (sourceTypes, targetTypes)) {
901909 if (getElementTypeOrSelf (op.getOperand (0 ).getType ()) == sourceType &&
902910 getElementTypeOrSelf (op.getResult (0 ).getType ()) == targetType) {
903911 op->emitError (" unresolved unrealized_conversion_cast left in IR "
904912 " after conversion" );
905- hasUnresolvedCast = true ;
913+ return signalPassFailure () ;
906914 }
907915 }
908916 });
909-
910- if (hasUnresolvedCast) {
911- signalPassFailure ();
912- }
913917 }
914918};
919+
915920} // namespace
0 commit comments