
Commit 23050e2

Merge pull request #472 from Xilinx/jrickert.default_attrs
Handle default attributes in decomposition
2 parents aaf7f92 + efe9701 commit 23050e2

File tree

2 files changed: +41 -11 lines changed


src/Dialect/ONNX/Transforms/Decompose.cpp

Lines changed: 20 additions & 5 deletions
@@ -3071,11 +3071,22 @@ struct SimplifiedLayerNorm : public CustomOpToOnnxOps {
       bias = customOp.getOperand(2);
 
     auto epsAttr = customOp->getAttrOfType<FloatAttr>("epsilon");
-    assert(epsAttr && "Expected Epsilon");
+    if (!epsAttr)
+      epsAttr =
+          rewriter.getF32FloatAttr(9.999999747378752e-06f); // default epsilon
+
     auto axisAttr = customOp->getAttrOfType<IntegerAttr>("axis");
-    assert(axisAttr && "Expected Axis");
+    if (!axisAttr) {
+      auto si64Type = rewriter.getIntegerType(64, /*isSigned=*/true);
+      axisAttr = rewriter.getIntegerAttr(si64Type, -1); // default axis
+    }
+
     auto stashTypeAttr = customOp->getAttrOfType<IntegerAttr>("stash_type");
-    assert(stashTypeAttr && "Expected Stash Type");
+    if (!stashTypeAttr) {
+      auto si64Type = rewriter.getIntegerType(64, /*isSigned=*/true);
+      stashTypeAttr =
+          rewriter.getIntegerAttr(si64Type, 1); // default stash_type
+    }
 
     SmallVector<Type, 2> resultTypes;
     resultTypes.push_back(customOp->getResultTypes()[0]);
@@ -3134,7 +3145,9 @@ struct MicrosoftSkipLayerNorm : public CustomOpToOnnxOps {
       bias = customOp.getOperand(4);
 
     auto epsAttr = customOp->getAttrOfType<FloatAttr>("epsilon");
-    assert(epsAttr && "Expected Epsilon");
+    if (!epsAttr)
+      epsAttr =
+          rewriter.getF32FloatAttr(9.999999747378752e-06f); // default epsilon
 
     Value skipAdd = create.onnx.add(input, skip);
     Value sumIS;
@@ -3209,7 +3222,9 @@ struct MicrosoftSkipSimplifiedLayerNorm : public CustomOpToOnnxOps {
       bias = customOp.getOperand(3);
 
     auto epsAttr = customOp->getAttrOfType<FloatAttr>("epsilon");
-    assert(epsAttr && "Expected Epsilon");
+    if (!epsAttr)
+      epsAttr =
+          rewriter.getF32FloatAttr(9.999999747378752e-06f); // default epsilon
 
     Value skipAdd = create.onnx.add(input, skip);
     Value sumIS;
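
Aside (not part of the diff): the same "use the attribute if present, otherwise fall back to the default" logic now appears in all three rewrites. A small helper could factor it out. The sketch below is only illustrative: the names getFloatAttrOrDefault and getSI64AttrOrDefault are hypothetical, and it assumes the MLIR builder APIs already used in the hunks above (getAttrOfType, getF32FloatAttr, getIntegerType, getIntegerAttr) and a PatternRewriter in scope.

// Hypothetical helpers (a sketch, not code from this commit): return the
// attribute if the op carries it, otherwise build the given default value.
static FloatAttr getFloatAttrOrDefault(Operation *op, StringRef name,
    PatternRewriter &rewriter, float defaultValue) {
  if (auto attr = op->getAttrOfType<FloatAttr>(name))
    return attr;
  return rewriter.getF32FloatAttr(defaultValue);
}

static IntegerAttr getSI64AttrOrDefault(Operation *op, StringRef name,
    PatternRewriter &rewriter, int64_t defaultValue) {
  if (auto attr = op->getAttrOfType<IntegerAttr>(name))
    return attr;
  auto si64Type = rewriter.getIntegerType(64, /*isSigned=*/true);
  return rewriter.getIntegerAttr(si64Type, defaultValue);
}

// Illustrative use at the three call sites touched above:
//   auto epsAttr = getFloatAttrOrDefault(
//       customOp, "epsilon", rewriter, 9.999999747378752e-06f);
//   auto axisAttr = getSI64AttrOrDefault(customOp, "axis", rewriter, -1);
//   auto stashTypeAttr = getSI64AttrOrDefault(customOp, "stash_type", rewriter, 1);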

test/mlir/onnx/onnx_decompose_customop.mlir

Lines changed: 21 additions & 6 deletions
@@ -374,13 +374,15 @@ func.func @skip_layernorm_basic(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf
 }
 
 
+
+
 // -----
 // SkipLayerNormalization: 4 inputs (beta), 1 output
 
-func.func @skip_layernorm_beta(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>, %beta: tensor<8xf32>) -> tensor<2x4x8xf32> {
-  %r = "onnx.Custom"(%input, %skip, %gamma, %beta) {domain_name = "com.microsoft", function_name = "SkipLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+func.func @skip_layernorm_beta_no_eps(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>, %beta: tensor<8xf32>) -> tensor<2x4x8xf32> {
+  %r = "onnx.Custom"(%input, %skip, %gamma, %beta) {domain_name = "com.microsoft", function_name = "SkipLayerNormalization"} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
   onnx.Return %r : tensor<2x4x8xf32>
-// CHECK-LABEL: func.func @skip_layernorm_beta
+// CHECK-LABEL: func.func @skip_layernorm_beta_no_eps
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>, [[PARAM_3_:%.+]]: tensor<8xf32>) -> tensor<2x4x8xf32> {
 // CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
 // CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_0_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, none, none)
@@ -461,6 +463,19 @@ func.func @simplified_layernorm_basic(%input: tensor<2x4x8xf32>, %scale: tensor<
 // CHECK: onnx.Return [[VAR_Y_]] : tensor<2x4x8xf32>
 }
 
+// -----
+
+func.func @simplified_layernorm_no_attrs(%input: tensor<2x4x8xf32>, %scale: tensor<8xf32>) -> tensor<2x4x8xf32> {
+  %r = "onnx.Custom"(%input, %scale) {domain_name = "", function_name = "SimplifiedLayerNormalization"} : (tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+  onnx.Return %r : tensor<2x4x8xf32>
+// CHECK-LABEL: func.func @simplified_layernorm_no_attrs
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<8xf32>) -> tensor<2x4x8xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK: [[VAR_Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[VAR_0_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, none) -> (tensor<2x4x8xf32>, none)
+// CHECK: onnx.Return [[VAR_Y_]] : tensor<2x4x8xf32>
+}
+
+
 
 // -----
 // SimplifiedLayerNormalization: 3 inputs (with bias), 1 output
@@ -517,10 +532,10 @@ func.func @simplified_layernorm_two_outputs_mean_used(%input: tensor<2x4x8xf32>,
 // -----
 // SkipSimplifiedLayerNormalization: 3 inputs, 1 output
 
-func.func @skip_simplified_layernorm_basic(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>) -> tensor<2x4x8xf32> {
-  %r = "onnx.Custom"(%input, %skip, %gamma) {domain_name = "com.microsoft", function_name = "SkipSimplifiedLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+func.func @skip_simplified_layernorm_basic_no_attr(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>) -> tensor<2x4x8xf32> {
+  %r = "onnx.Custom"(%input, %skip, %gamma) {domain_name = "com.microsoft", function_name = "SkipSimplifiedLayerNormalization"} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
   onnx.Return %r : tensor<2x4x8xf32>
-// CHECK-LABEL: func.func @skip_simplified_layernorm_basic
+// CHECK-LABEL: func.func @skip_simplified_layernorm_basic_no_attr
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>) -> tensor<2x4x8xf32> {
 // CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
 // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>

0 commit comments
