 #include "src/Pass/Passes.hpp"
 #include "src/Support/TypeUtilities.hpp"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringSet.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "decompose"
@@ -2872,15 +2873,32 @@ struct CustomOpMicrosoftToOnnxOps : public OpRewritePattern<ONNXCustomOp> {
   static LogicalResult verifyOpsErasingOnError(
       ValueRange values, PatternRewriter &rewriter) {
     if (llvm::any_of(values, [](Value value) {
-          return failed(verifyOpValidity(value.getDefiningOp()));
+          return value && failed(verifyOpValidity(value.getDefiningOp()));
         })) {
       for (auto value : values)
-        rewriter.eraseOp(value.getDefiningOp());
+        if (value) {
+          rewriter.eraseOp(value.getDefiningOp());
+        }
       return failure();
     }
     return success();
   }
 
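+  // Strip the ONNXCustomOp bookkeeping attributes (domain/function names and
+  // shape-inference hints) plus any caller-named ones before forwarding.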
+  static SmallVector<NamedAttribute> getFilteredAttrs(
+      ArrayRef<NamedAttribute> attrs,
+      ArrayRef<StringRef> additionalAttrNamesToFilter = {}) {
+    static const llvm::StringSet<> commonFilter{"domain_name", "function_name",
+        "output_element_type", "shape_infer_pattern", "inputs_for_infer"};
+    return SmallVector<NamedAttribute>{llvm::make_filter_range(
+        attrs, [&additionalAttrNamesToFilter](NamedAttribute attr) {
+          return !llvm::is_contained(commonFilter, attr.getName()) &&
+                 !llvm::is_contained(
+                     additionalAttrNamesToFilter, attr.getName());
+        })};
+  }
+
   const std::string operationNameToRewrite;
 };
 
@@ -2896,8 +2912,6 @@ struct MicrosoftBiasGelu : public CustomOpMicrosoftToOnnxOps {
 
     auto input = customOp->getOperand(0);
     auto bias = customOp->getOperand(1);
-    auto inputType = cast<ShapedType>(input.getType());
-    auto biasType = cast<ShapedType>(bias.getType());
     MultiDialectBuilder<OnnxBuilder> create(rewriter, customOp->getLoc());
     Value biasedInput = create.onnx.add(input, bias);
     Value gelu = create.onnx.gelu(biasedInput,
@@ -2911,6 +2925,112 @@ struct MicrosoftBiasGelu : public CustomOpMicrosoftToOnnxOps {
   }
 };
 
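+// Decomposes the Microsoft "FusedConv" custom op into a plain ONNXConvOp
+// followed by the ONNX op implementing the fused activation.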
+struct MicrosoftFusedConv : public CustomOpMicrosoftToOnnxOps {
+  MicrosoftFusedConv(MLIRContext *context, PatternBenefit benefit = 1)
+      : CustomOpMicrosoftToOnnxOps(context, "FusedConv", benefit) {}
+
+  LogicalResult matchAndRewriteImpl(
+      ONNXCustomOp customOp, PatternRewriter &rewriter) const final {
+    using namespace onnx_mlir;
+    assert(customOp.getNumOperands() >= 2 && customOp.getNumOperands() <= 4 &&
+           "Expected 2 to 4 operands for FusedConv");
+    if (customOp.getNumOperands() > 3) {
+      return rewriter.notifyMatchFailure(
+          customOp, "Decomposition does not support 'Sum/Z'");
+    }
+
+    assert(customOp->hasAttrOfType<StringAttr>("activation"));
+    assert(customOp->hasAttrOfType<ArrayAttr>("activation_params"));
+
+    const SmallVector<NamedAttribute> filteredAttrs(getFilteredAttrs(
+        customOp->getAttrs(), {"activation", "activation_params"}));
+    SmallVector<Value> convOperands{customOp.getOperands()};
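+    // ONNXConvOp expects an explicit bias operand; substitute None when the
+    // custom op does not provide one.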
+    Value noneBias;
+    if (convOperands.size() < 3) {
+      noneBias = rewriter.create<ONNXNoneOp>(customOp->getLoc())->getResult(0);
+      convOperands.push_back(noneBias);
+    }
+
+    auto conv = rewriter.create<ONNXConvOp>(customOp->getLoc(),
+        customOp->getResultTypes(), convOperands, filteredAttrs);
+    Value convOpResult = conv.getResult();
+    const auto activation =
+        customOp->getAttrOfType<StringAttr>("activation").strref();
+    auto activationParams =
+        customOp->getAttrOfType<ArrayAttr>("activation_params");
+    SmallVector<FloatAttr> activationParamsValues;
+    for (auto attr : activationParams) {
+      auto asFloatAttr = dyn_cast<FloatAttr>(attr);
+      assert(asFloatAttr && asFloatAttr.getType().isF32() &&
+             "All activation params "
+             "must be f32");
+      activationParamsValues.push_back(asFloatAttr);
+    }
+    Value activationFunc;
+    Value castMin;
+    Value castMax;
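+    // Map the fused activation name onto the corresponding ONNX op.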
2972+ if (activation == " Relu" ) {
2973+ activationFunc = rewriter.create <ONNXReluOp>(
2974+ customOp->getLoc (), convOpResult.getType (), convOpResult);
2975+ } else if (activation == " Tanh" ) {
2976+ activationFunc = rewriter.create <ONNXTanhOp>(
2977+ customOp->getLoc (), convOpResult.getType (), convOpResult);
2978+ } else if (activation == " Sigmoid" ) {
2979+ activationFunc = rewriter.create <ONNXSigmoidOp>(
2980+ customOp->getLoc (), convOpResult.getType (), convOpResult);
2981+ } else if (activation == " LeakyRelu" ) {
2982+ assert (activationParamsValues.size () == 1 &&
2983+ " LeakyRelu must have exactly one parameter" );
2984+ activationFunc = rewriter.create <ONNXLeakyReluOp>(customOp->getLoc (),
2985+ convOpResult.getType (), convOpResult, activationParamsValues[0 ]);
2986+ } else if (activation == " Clip" ) {
2987+ assert (activationParamsValues.size () == 2 &&
2988+ " Clip must have exactly two parameters" );
+      MultiDialectBuilder<OnnxBuilder> create(rewriter, customOp->getLoc());
+      auto scalarType = RankedTensorType::get({}, rewriter.getF32Type());
+      auto minVal = create.onnx.constant(
+          DenseElementsAttr::get(scalarType, activationParamsValues[0]));
+      auto castToType =
+          cast<ShapedType>(convOpResult.getType()).getElementType();
+      castMin = create.onnx.cast(minVal, castToType);
+      auto maxVal = create.onnx.constant(
+          DenseElementsAttr::get(scalarType, activationParamsValues[1]));
+      castMax = create.onnx.cast(maxVal, castToType);
+      activationFunc = rewriter.create<ONNXClipOp>(customOp->getLoc(),
+          convOpResult.getType(), convOpResult, castMin, castMax);
+    } else if (activation == "HardSigmoid") {
+      assert(activationParamsValues.size() == 2 &&
+             "HardSigmoid must have exactly two parameters");
+      activationFunc = rewriter.create<ONNXHardSigmoidOp>(customOp->getLoc(),
+          convOpResult.getType(), convOpResult, activationParamsValues[0],
+          activationParamsValues[1]);
+    } else {
+      rewriter.eraseOp(conv);
+      if (noneBias) {
+        rewriter.eraseOp(noneBias.getDefiningOp());
+      }
+      return rewriter.notifyMatchFailure(customOp,
+          "Decomposition only supports Relu, Tanh, Sigmoid, LeakyRelu, Clip, "
+          "and HardSigmoid activations");
+    }
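+    // Verify the newly created ops; on failure verifyOpsErasingOnError erases
+    // them (null entries such as castMin/castMax outside Clip are skipped).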
+    if (failed(verifyOpsErasingOnError(
+            {noneBias, conv, castMin, castMax, activationFunc}, rewriter))) {
+      return rewriter.notifyMatchFailure(customOp, "Failed verification");
+    }
+    rewriter.replaceOp(customOp, activationFunc);
+    return success();
+  }
+};
+
 template <typename OpToCreate>
 struct CustomOpMicrosoftToSingleOnnxOp : public CustomOpMicrosoftToOnnxOps {
   using CustomOpMicrosoftToOnnxOps::CustomOpMicrosoftToOnnxOps;
@@ -2922,13 +3033,7 @@ struct CustomOpMicrosoftToSingleOnnxOp : public CustomOpMicrosoftToOnnxOps {
     }
 
     const SmallVector<NamedAttribute> filteredAttrs(
-        llvm::make_filter_range(customOp->getAttrs(), [](NamedAttribute attr) {
-          return attr.getName() != "domain_name" &&
-                 attr.getName() != "function_name" &&
-                 attr.getName() != "output_element_type" &&
-                 attr.getName() != "shape_infer_pattern" &&
-                 attr.getName() != "inputs_for_infer";
-        }));
+        getFilteredAttrs(customOp->getAttrs()));
 
     auto newOp = rewriter.create<OpToCreate>(customOp->getLoc(),
         customOp->getResultTypes(), customOp.getOperands(), filteredAttrs);
@@ -3323,6 +3428,7 @@ void onnx_mlir::getDecomposeONNXToONNXPatterns(
       context, "DequantizeLinear");
   patterns.insert<CustomOpMicrosoftToSingleOnnxOp<ONNXGeluOp>>(context, "Gelu");
   patterns.insert<MicrosoftBiasGelu>(context);
+  patterns.insert<MicrosoftFusedConv>(context);
   patterns.insert<DecomposeSlicePadPattern>(context);
   patterns.insert<DecomposeScatterNDPattern>(context);
   patterns.insert<SoftmaxCrossEntropyPattern>(context);