Skip to content

Commit 2dbf4f3

Browse files
committed
Add decomposition for microsoft FusedConv
Signed-off-by: Rickert, Jonas <[email protected]>
1 parent a9a2cec commit 2dbf4f3

File tree

2 files changed

+298
-11
lines changed

2 files changed

+298
-11
lines changed

src/Dialect/ONNX/Transforms/Decompose.cpp

Lines changed: 117 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
#include "src/Pass/Passes.hpp"
4747
#include "src/Support/TypeUtilities.hpp"
4848
#include "llvm/ADT/ArrayRef.h"
49+
#include "llvm/ADT/StringSet.h"
4950
#include "llvm/Support/Debug.h"
5051

5152
#define DEBUG_TYPE "decompose"
@@ -2872,15 +2873,30 @@ struct CustomOpMicrosoftToOnnxOps : public OpRewritePattern<ONNXCustomOp> {
28722873
static LogicalResult verifyOpsErasingOnError(
28732874
ValueRange values, PatternRewriter &rewriter) {
28742875
if (llvm::any_of(values, [](Value value) {
2875-
return failed(verifyOpValidity(value.getDefiningOp()));
2876+
return value && failed(verifyOpValidity(value.getDefiningOp()));
28762877
})) {
28772878
for (auto value : values)
2878-
rewriter.eraseOp(value.getDefiningOp());
2879+
if (value) {
2880+
rewriter.eraseOp(value.getDefiningOp());
2881+
}
28792882
return failure();
28802883
}
28812884
return success();
28822885
}
28832886

2887+
static SmallVector<NamedAttribute> getFilteredAttrs(
2888+
ArrayRef<NamedAttribute> attrs,
2889+
ArrayRef<StringRef> additionalAttrNamesToFilter = {}) {
2890+
static const llvm::StringSet<> commonFilter{"domain_name", "function_name",
2891+
"output_element_type", "shape_infer_pattern", "inputs_for_infer"};
2892+
return SmallVector<NamedAttribute>{llvm::make_filter_range(
2893+
attrs, [&additionalAttrNamesToFilter](NamedAttribute attr) {
2894+
return !llvm::is_contained(commonFilter, attr.getName()) &&
2895+
!llvm::is_contained(
2896+
additionalAttrNamesToFilter, attr.getName());
2897+
})};
2898+
}
2899+
28842900
const std::string operationNameToRewrite;
28852901
};
28862902

@@ -2896,8 +2912,6 @@ struct MicrosoftBiasGelu : public CustomOpMicrosoftToOnnxOps {
28962912

28972913
auto input = customOp->getOperand(0);
28982914
auto bias = customOp->getOperand(1);
2899-
auto inputType = cast<ShapedType>(input.getType());
2900-
auto biasType = cast<ShapedType>(bias.getType());
29012915
MultiDialectBuilder<OnnxBuilder> create(rewriter, customOp->getLoc());
29022916
Value biasedInput = create.onnx.add(input, bias);
29032917
Value gelu = create.onnx.gelu(biasedInput,
@@ -2911,6 +2925,103 @@ struct MicrosoftBiasGelu : public CustomOpMicrosoftToOnnxOps {
29112925
}
29122926
};
29132927

2928+
struct MicrosoftFusedConv : public CustomOpMicrosoftToOnnxOps {
2929+
MicrosoftFusedConv(MLIRContext *context, PatternBenefit benefit = 1)
2930+
: CustomOpMicrosoftToOnnxOps(context, "FusedConv", benefit) {}
2931+
2932+
LogicalResult matchAndRewriteImpl(
2933+
ONNXCustomOp customOp, PatternRewriter &rewriter) const final {
2934+
using namespace onnx_mlir;
2935+
assert(customOp.getNumOperands() >= 2 && customOp.getNumOperands() <= 4 &&
2936+
"Expected 2 to 4 operands for FusedConv");
2937+
if (customOp.getNumOperands() > 3) {
2938+
return rewriter.notifyMatchFailure(
2939+
customOp, "Decomposition does not support 'Sum/Z'");
2940+
}
2941+
2942+
assert(customOp->hasAttrOfType<StringAttr>("activation"));
2943+
assert(customOp->hasAttrOfType<ArrayAttr>("activation_params"));
2944+
2945+
const SmallVector<NamedAttribute> filteredAttrs(getFilteredAttrs(
2946+
customOp->getAttrs(), {"activation", "activation_params"}));
2947+
SmallVector<Value> convOperands{customOp.getOperands()};
2948+
Value noneBias;
2949+
if (convOperands.size() < 3) {
2950+
noneBias = rewriter.create<ONNXNoneOp>(customOp->getLoc())->getResult(0);
2951+
convOperands.push_back(noneBias);
2952+
}
2953+
2954+
auto conv = rewriter.create<ONNXConvOp>(customOp->getLoc(),
2955+
customOp->getResultTypes(), convOperands, filteredAttrs);
2956+
Value convOpResult = conv.getResult();
2957+
const auto activation =
2958+
customOp->getAttrOfType<StringAttr>("activation").strref();
2959+
auto activationParams =
2960+
customOp->getAttrOfType<ArrayAttr>("activation_params");
2961+
SmallVector<FloatAttr> activationParamsValues;
2962+
for (auto attr : activationParams) {
2963+
auto asFloatAttr = dyn_cast<FloatAttr>(attr);
2964+
assert(asFloatAttr && asFloatAttr.getType().isF32() &&
2965+
"All activation params "
2966+
"must be f32");
2967+
activationParamsValues.push_back(asFloatAttr);
2968+
}
2969+
Value activationFunc;
2970+
Value castMin;
2971+
Value castMax;
2972+
if (activation == "Relu") {
2973+
activationFunc = rewriter.create<ONNXReluOp>(
2974+
customOp->getLoc(), convOpResult.getType(), convOpResult);
2975+
} else if (activation == "Tanh") {
2976+
activationFunc = rewriter.create<ONNXTanhOp>(
2977+
customOp->getLoc(), convOpResult.getType(), convOpResult);
2978+
} else if (activation == "Sigmoid") {
2979+
activationFunc = rewriter.create<ONNXSigmoidOp>(
2980+
customOp->getLoc(), convOpResult.getType(), convOpResult);
2981+
} else if (activation == "LeakyRelu") {
2982+
assert(activationParamsValues.size() == 1 &&
2983+
"LeakyRelu must have exactly one parameter");
2984+
activationFunc = rewriter.create<ONNXLeakyReluOp>(customOp->getLoc(),
2985+
convOpResult.getType(), convOpResult, activationParamsValues[0]);
2986+
} else if (activation == "Clip") {
2987+
assert(activationParamsValues.size() == 2 &&
2988+
"Clip must have exactly two parameters");
2989+
MultiDialectBuilder<OnnxBuilder> create(rewriter, customOp->getLoc());
2990+
auto scalarType = RankedTensorType::get({}, rewriter.getF32Type());
2991+
auto minVal = create.onnx.constant(
2992+
DenseElementsAttr::get(scalarType, activationParamsValues[0]));
2993+
auto castToType =
2994+
cast<ShapedType>(convOpResult.getType()).getElementType();
2995+
castMin = create.onnx.cast(minVal, castToType);
2996+
auto maxVal = create.onnx.constant(
2997+
DenseElementsAttr::get(scalarType, activationParamsValues[1]));
2998+
castMax = create.onnx.cast(maxVal, castToType);
2999+
activationFunc = rewriter.create<ONNXClipOp>(customOp->getLoc(),
3000+
convOpResult.getType(), convOpResult, castMin, castMax);
3001+
} else if (activation == "HardSigmoid") {
3002+
assert(activationParamsValues.size() == 2 &&
3003+
"HardSigmoid must have exactly two parameters");
3004+
activationFunc = rewriter.create<ONNXHardSigmoidOp>(customOp->getLoc(),
3005+
convOpResult.getType(), convOpResult, activationParamsValues[0],
3006+
activationParamsValues[1]);
3007+
} else {
3008+
rewriter.eraseOp(conv);
3009+
if (noneBias) {
3010+
rewriter.eraseOp(noneBias.getDefiningOp());
3011+
}
3012+
return rewriter.notifyMatchFailure(customOp,
3013+
"Decomposition only supports Relu, Tanh, Sigmoid, LeakyRelu, Clip, "
3014+
"and HardSigmoid activations");
3015+
}
3016+
if (failed(verifyOpsErasingOnError(
3017+
{noneBias, conv, castMin, castMax, activationFunc}, rewriter))) {
3018+
return rewriter.notifyMatchFailure(customOp, "Failed verification");
3019+
}
3020+
rewriter.replaceOp(customOp, activationFunc);
3021+
return success();
3022+
}
3023+
};
3024+
29143025
template <typename OpToCreate>
29153026
struct CustomOpMicrosoftToSingleOnnxOp : public CustomOpMicrosoftToOnnxOps {
29163027
using CustomOpMicrosoftToOnnxOps::CustomOpMicrosoftToOnnxOps;
@@ -2922,13 +3033,7 @@ struct CustomOpMicrosoftToSingleOnnxOp : public CustomOpMicrosoftToOnnxOps {
29223033
}
29233034

29243035
const SmallVector<NamedAttribute> filteredAttrs(
2925-
llvm::make_filter_range(customOp->getAttrs(), [](NamedAttribute attr) {
2926-
return attr.getName() != "domain_name" &&
2927-
attr.getName() != "function_name" &&
2928-
attr.getName() != "output_element_type" &&
2929-
attr.getName() != "shape_infer_pattern" &&
2930-
attr.getName() != "inputs_for_infer";
2931-
}));
3036+
getFilteredAttrs(customOp->getAttrs()));
29323037

29333038
auto newOp = rewriter.create<OpToCreate>(customOp->getLoc(),
29343039
customOp->getResultTypes(), customOp.getOperands(), filteredAttrs);
@@ -3323,6 +3428,7 @@ void onnx_mlir::getDecomposeONNXToONNXPatterns(
33233428
context, "DequantizeLinear");
33243429
patterns.insert<CustomOpMicrosoftToSingleOnnxOp<ONNXGeluOp>>(context, "Gelu");
33253430
patterns.insert<MicrosoftBiasGelu>(context);
3431+
patterns.insert<MicrosoftFusedConv>(context);
33263432
patterns.insert<DecomposeSlicePadPattern>(context);
33273433
patterns.insert<DecomposeScatterNDPattern>(context);
33283434
patterns.insert<SoftmaxCrossEntropyPattern>(context);

0 commit comments

Comments
 (0)