Skip to content

Commit 0cb8c96

Browse files
committed
Address review comments.
Move common pass logics to initialize() from runOnOperation().
1 parent 3966b5d commit 0cb8c96

File tree

1 file changed

+30
-25
lines changed

1 file changed

+30
-25
lines changed

mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -815,19 +815,18 @@ void mlir::configureImitateUnsupportedTypesLegality(
815815
//===----------------------------------------------------------------------===//
816816

817817
namespace {
818+
818819
struct GpuImitateUnsupportedTypesPass
819820
: public impl::GpuImitateUnsupportedTypesBase<
820821
GpuImitateUnsupportedTypesPass> {
821822
using Base::Base;
822823

823-
void runOnOperation() override {
824-
MLIRContext *ctx = &getContext();
825-
Operation *op = getOperation();
826-
827-
SmallVector<Type> sourceTypes;
828-
SmallVector<Type> targetTypes;
824+
SmallVector<Type> sourceTypes;
825+
SmallVector<Type> targetTypes;
826+
TypeConverter typeConverter;
829827

830-
// Parse source types
828+
LogicalResult initialize(MLIRContext *ctx) override {
829+
// Parse source types.
831830
for (StringRef sourceTypeStr : sourceTypeStrs) {
832831
std::optional<Type> maybeSourceType =
833832
arith::parseIntOrFloatType(ctx, sourceTypeStr);
@@ -836,7 +835,7 @@ struct GpuImitateUnsupportedTypesPass
836835
emitError(UnknownLoc::get(ctx),
837836
"could not map source type '" + sourceTypeStr +
838837
"' to a known integer or floating-point type.");
839-
return signalPassFailure();
838+
return failure();
840839
}
841840
sourceTypes.push_back(*maybeSourceType);
842841
}
@@ -847,7 +846,7 @@ struct GpuImitateUnsupportedTypesPass
847846
"nothing");
848847
}
849848

850-
// Parse target types
849+
// Parse target types.
851850
for (StringRef targetTypeStr : targetTypeStrs) {
852851
std::optional<Type> maybeTargetType =
853852
arith::parseIntOrFloatType(ctx, targetTypeStr);
@@ -856,14 +855,14 @@ struct GpuImitateUnsupportedTypesPass
856855
emitError(UnknownLoc::get(ctx),
857856
"could not map target type '" + targetTypeStr +
858857
"' to a known integer or floating-point type");
859-
return signalPassFailure();
858+
return failure();
860859
}
861860
targetTypes.push_back(*maybeTargetType);
862861

863862
if (llvm::is_contained(sourceTypes, *maybeTargetType)) {
864863
emitError(UnknownLoc::get(ctx),
865864
"target type cannot be an unsupported source type");
866-
return signalPassFailure();
865+
return failure();
867866
}
868867
}
869868
if (targetTypes.empty()) {
@@ -872,44 +871,50 @@ struct GpuImitateUnsupportedTypesPass
872871
"no target types specified, type imitation will do nothing");
873872
}
874873

875-
// Set up the type converter
876-
TypeConverter typeConverter;
874+
if (sourceTypes.size() != targetTypes.size()) {
875+
emitError(UnknownLoc::get(ctx),
876+
"source and target types must have the same size");
877+
return failure();
878+
}
879+
// Set up the type converter.
877880
populateImitateUnsupportedTypesTypeConverter(typeConverter, sourceTypes,
878881
targetTypes);
882+
return success();
883+
}
879884

880-
// Populate the conversion patterns
885+
void runOnOperation() override {
886+
MLIRContext *ctx = &getContext();
887+
Operation *op = getOperation();
888+
889+
// Populate the conversion patterns.
881890
RewritePatternSet patterns(ctx);
882891
DenseMap<StringAttr, FunctionType> convertedFuncTypes;
883892
populateImitateUnsupportedTypesConversionPatterns(
884893
patterns, typeConverter, sourceTypes, targetTypes, convertedFuncTypes);
885894

886-
// Set up conversion target and configure the legality of the conversion
895+
// Set up conversion target and configure the legality of the conversion.
887896
ConversionTarget target(*ctx);
888897
configureImitateUnsupportedTypesLegality(target, typeConverter);
889898

890-
// Apply the conversion
899+
// Apply the conversion.
891900
if (failed(applyPartialConversion(op, target, std::move(patterns))))
892-
signalPassFailure();
901+
return signalPassFailure();
893902

894903
// Post-conversion validation: check for any remaining
895-
// unrealized_conversion_cast
896-
bool hasUnresolvedCast = false;
904+
// unrealized_conversion_cast.
897905
op->walk([&](UnrealizedConversionCastOp op) {
898-
// Check if the cast is from a source type to a target type
906+
// Check if the cast is from a source type to a target type.
899907
for (auto [sourceType, targetType] :
900908
llvm::zip_equal(sourceTypes, targetTypes)) {
901909
if (getElementTypeOrSelf(op.getOperand(0).getType()) == sourceType &&
902910
getElementTypeOrSelf(op.getResult(0).getType()) == targetType) {
903911
op->emitError("unresolved unrealized_conversion_cast left in IR "
904912
"after conversion");
905-
hasUnresolvedCast = true;
913+
return signalPassFailure();
906914
}
907915
}
908916
});
909-
910-
if (hasUnresolvedCast) {
911-
signalPassFailure();
912-
}
913917
}
914918
};
919+
915920
} // namespace

0 commit comments

Comments
 (0)