diff --git a/src/passes/Flatten.cpp b/src/passes/Flatten.cpp index 1c2cfbcd536..38020e6a3b8 100644 --- a/src/passes/Flatten.cpp +++ b/src/passes/Flatten.cpp @@ -37,6 +37,7 @@ // local, but in principle there's no reason it couldn't be. For now, error on // this. +#include "wasm-type.h" #include #include #include @@ -308,6 +309,118 @@ struct Flatten } } + } else if (auto* br = curr->dynCast()) { + if (br->op == BrOnOp::BrOnNull || br->op == BrOnOp::BrOnNonNull) { + auto nullableType = br->ref->type; + + Index nullableTemp = builder.addVar(getFunction(), nullableType); + ourPreludes.push_back(builder.makeLocalSet(nullableTemp, br->ref)); + + Index isNullTemp = builder.addVar(getFunction(), Type::i32); + ourPreludes.push_back( + builder.makeLocalSet(isNullTemp, + builder.makeRefIsNull(builder.makeLocalGet( + nullableTemp, nullableType)))); + + if (br->op == BrOnOp::BrOnNull) { + ourPreludes.push_back(builder.makeBreak( + br->name, nullptr, builder.makeLocalGet(isNullTemp, Type::i32))); + + replaceCurrent(builder.makeRefAs( + RefAsOp::RefAsNonNull, + builder.makeLocalGet(nullableTemp, nullableType))); + } else { // br_on_non_null + Index isNotNullTemp = builder.addVar(getFunction(), Type::i32); + ourPreludes.push_back(builder.makeLocalSet( + isNotNullTemp, + builder.makeUnary(UnaryOp::EqZInt32, + builder.makeLocalGet(isNullTemp, Type::i32)))); + + Index breakTargetTemp = getTempForBreakTarget( + br->name, nullableType.with(Nullability::NonNullable)); + + std::vector successBlock; + successBlock.push_back(builder.makeLocalSet( + breakTargetTemp, + builder.makeRefAs( + RefAsOp::RefAsNonNull, + builder.makeLocalGet(nullableTemp, nullableType)))); + successBlock.push_back(builder.makeBreak(br->name)); + + replaceCurrent( + builder.makeIf(builder.makeLocalGet(isNotNullTemp, Type::i32), + builder.makeBlock(successBlock))); + } + } else if (br->op == BrOnCast || br->op == BrOnCastFail) { + auto sourceType = br->ref->type; + auto targetType = br->castType; + + Index sourceTypeTemp = builder.addVar(getFunction(), sourceType); + ourPreludes.push_back(builder.makeLocalSet(sourceTypeTemp, br->ref)); + + Index typeTestTemp = builder.addVar(getFunction(), Type::i32); + ourPreludes.push_back(builder.makeLocalSet( + typeTestTemp, + builder.makeRefTest( + builder.makeLocalGet(sourceTypeTemp, sourceType), targetType))); + + // On cast failure the source type is made non-nullable if the target + // type is nullable. + Expression* failValue = + builder.makeLocalGet(sourceTypeTemp, sourceType); + if (br->castType.isNullable()) { + failValue = builder.makeRefAs(RefAsOp::RefAsNonNull, failValue); + } + + if (br->op == BrOnCast) { + Index breakTargetTemp = getTempForBreakTarget(br->name, targetType); + + std::vector successBlock; + successBlock.push_back(builder.makeLocalSet( + breakTargetTemp, + builder.makeRefCast( + builder.makeLocalGet(sourceTypeTemp, sourceType), targetType))); + successBlock.push_back(builder.makeBreak(br->name)); + + ourPreludes.push_back( + builder.makeIf(builder.makeLocalGet(typeTestTemp, Type::i32), + builder.makeBlock(successBlock))); + + Index failTemp = builder.addVar(getFunction(), failValue->type); + ourPreludes.push_back(builder.makeLocalSet(failTemp, failValue)); + replaceCurrent(builder.makeLocalGet(failTemp, failValue->type)); + } else { // br_on_cast_fail + Index breakTargetTemp = + getTempForBreakTarget(br->name, failValue->type); + + std::vector failureBlock; + failureBlock.push_back( + builder.makeLocalSet(breakTargetTemp, failValue)); + failureBlock.push_back(builder.makeBreak(br->name)); + + Index typeTestFailTemp = builder.addVar(getFunction(), Type::i32); + ourPreludes.push_back(builder.makeLocalSet( + typeTestFailTemp, + builder.makeUnary( + UnaryOp::EqZInt32, + builder.makeLocalGet(typeTestTemp, Type::i32)))); + + ourPreludes.push_back( + builder.makeIf(builder.makeLocalGet(typeTestFailTemp, Type::i32), + builder.makeBlock(failureBlock))); + + Index targetTypeTemp = builder.addVar(getFunction(), targetType); + ourPreludes.push_back(builder.makeLocalSet( + targetTypeTemp, + builder.makeRefCast( + builder.makeLocalGet(sourceTypeTemp, sourceType), targetType))); + + replaceCurrent(builder.makeLocalGet(targetTypeTemp, targetType)); + } + + } else { + Fatal() << "Unsupported instruction for Flatten: BrOn " << br->op; + } } else if (auto* sw = curr->dynCast()) { if (sw->value) { auto type = sw->value->type; @@ -333,7 +446,7 @@ struct Flatten } } - if (curr->is() || curr->is()) { + if (curr->is()) { Fatal() << "Unsupported instruction for Flatten: " << getExpressionName(curr); } diff --git a/test/lit/passes/flatten_br_on_cast.wast b/test/lit/passes/flatten_br_on_cast.wast new file mode 100644 index 00000000000..e8308ef5e6a --- /dev/null +++ b/test/lit/passes/flatten_br_on_cast.wast @@ -0,0 +1,84 @@ +;; NOTE: Assertions have been generated by update_lit_checks.py --all-items and should not be edited. +;; NOTE: This test was ported using port_passes_tests_to_lit.py and could be cleaned up. + +;; RUN: foreach %s %t wasm-opt --flatten --all-features -S -o - | filecheck %s + +(module + ;; CHECK: (type $s (sub (struct))) + (type $s (sub (struct))) + ;; CHECK: (type $t (sub $s (struct))) + (type $t (sub $s (struct))) + + ;; CHECK: (type $2 (func (param (ref $s)) (result (ref $t)))) + + ;; CHECK: (func $br_on_cast (type $2) (param $x (ref $s)) (result (ref $t)) + ;; CHECK-NEXT: (local $1 (ref $s)) + ;; CHECK-NEXT: (local $2 (ref $s)) + ;; CHECK-NEXT: (local $3 i32) + ;; CHECK-NEXT: (local $4 (ref null $t)) + ;; CHECK-NEXT: (local $5 (ref $s)) + ;; CHECK-NEXT: (local $6 (ref $s)) + ;; CHECK-NEXT: (local $7 (ref (exact $t))) + ;; CHECK-NEXT: (local $8 (ref $t)) + ;; CHECK-NEXT: (block $label0 + ;; CHECK-NEXT: (local.set $1 + ;; CHECK-NEXT: (local.get $x) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $2 + ;; CHECK-NEXT: (local.get $1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $3 + ;; CHECK-NEXT: (ref.test (ref $t) + ;; CHECK-NEXT: (local.get $2) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (if + ;; CHECK-NEXT: (local.get $3) + ;; CHECK-NEXT: (then + ;; CHECK-NEXT: (local.set $4 + ;; CHECK-NEXT: (ref.cast (ref $t) + ;; CHECK-NEXT: (local.get $2) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (br $label0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $5 + ;; CHECK-NEXT: (local.get $2) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $6 + ;; CHECK-NEXT: (local.get $5) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (local.get $6) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $7 + ;; CHECK-NEXT: (struct.new_default $t) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $4 + ;; CHECK-NEXT: (local.get $7) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $8 + ;; CHECK-NEXT: (ref.as_non_null + ;; CHECK-NEXT: (local.get $4) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (return + ;; CHECK-NEXT: (local.get $8) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (unreachable) + ;; CHECK-NEXT: ) + (func $br_on_cast (param $x (ref $s)) (result (ref $t)) + (return + (block $label0 (result (ref $t)) + (drop + (br_on_cast $label0 (ref $s) (ref $t) + (local.get $x) + ) + ) + (struct.new $t)) + ) + ) +) + diff --git a/test/lit/passes/flatten_br_on_cast_fail.wast b/test/lit/passes/flatten_br_on_cast_fail.wast new file mode 100644 index 00000000000..8c13eed805a --- /dev/null +++ b/test/lit/passes/flatten_br_on_cast_fail.wast @@ -0,0 +1,100 @@ +;; NOTE: Assertions have been generated by update_lit_checks.py --all-items and should not be edited. +;; NOTE: This test was ported using port_passes_tests_to_lit.py and could be cleaned up. + +;; RUN: foreach %s %t wasm-opt --flatten --all-features -S -o - | filecheck %s + +(module + ;; CHECK: (type $s (sub (struct))) + (type $s (sub (struct))) + ;; CHECK: (type $t (sub $s (struct))) + (type $t (sub $s (struct))) + + ;; CHECK: (type $2 (func (param (ref $s)) (result (ref $t)))) + + ;; CHECK: (func $br_on_cast_fail (type $2) (param $x (ref $s)) (result (ref $t)) + ;; CHECK-NEXT: (local $1 (ref $s)) + ;; CHECK-NEXT: (local $2 (ref $s)) + ;; CHECK-NEXT: (local $3 i32) + ;; CHECK-NEXT: (local $4 (ref null $s)) + ;; CHECK-NEXT: (local $5 i32) + ;; CHECK-NEXT: (local $6 (ref $t)) + ;; CHECK-NEXT: (local $7 (ref $t)) + ;; CHECK-NEXT: (local $8 (ref $s)) + ;; CHECK-NEXT: (local $9 (ref (exact $t))) + ;; CHECK-NEXT: (local $10 (ref $t)) + ;; CHECK-NEXT: (local $11 (ref $t)) + ;; CHECK-NEXT: (block + ;; CHECK-NEXT: (block $label0 + ;; CHECK-NEXT: (local.set $1 + ;; CHECK-NEXT: (local.get $x) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $2 + ;; CHECK-NEXT: (local.get $1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $3 + ;; CHECK-NEXT: (ref.test (ref $t) + ;; CHECK-NEXT: (local.get $2) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $5 + ;; CHECK-NEXT: (i32.eqz + ;; CHECK-NEXT: (local.get $3) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (if + ;; CHECK-NEXT: (local.get $5) + ;; CHECK-NEXT: (then + ;; CHECK-NEXT: (local.set $4 + ;; CHECK-NEXT: (local.get $2) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (br $label0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $6 + ;; CHECK-NEXT: (ref.cast (ref $t) + ;; CHECK-NEXT: (local.get $2) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $7 + ;; CHECK-NEXT: (local.get $6) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (return + ;; CHECK-NEXT: (local.get $7) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (unreachable) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $8 + ;; CHECK-NEXT: (ref.as_non_null + ;; CHECK-NEXT: (local.get $4) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (local.get $8) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $9 + ;; CHECK-NEXT: (struct.new_default $t) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $10 + ;; CHECK-NEXT: (local.get $9) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $11 + ;; CHECK-NEXT: (local.get $10) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (return + ;; CHECK-NEXT: (local.get $11) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $br_on_cast_fail (param $x (ref $s)) (result (ref $t)) + (drop + (block $label0 (result (ref $s)) + (return + (br_on_cast_fail $label0 (ref $s) (ref $t) + (local.get $x) + ) + ) + ) + ) + (struct.new $t) + ) +) diff --git a/test/lit/passes/flatten_br_on_non_null.wast b/test/lit/passes/flatten_br_on_non_null.wast new file mode 100644 index 00000000000..b58d6a49f23 --- /dev/null +++ b/test/lit/passes/flatten_br_on_non_null.wast @@ -0,0 +1,77 @@ +;; NOTE: Assertions have been generated by update_lit_checks.py --all-items and should not be edited. +;; NOTE: This test was ported using port_passes_tests_to_lit.py and could be cleaned up. + +;; RUN: foreach %s %t wasm-opt --flatten --all-features -S -o - | filecheck %s + +(module + ;; CHECK: (type $s (sub (struct))) + (type $s (sub (struct))) + (type $t (sub $s (struct))) + + ;; CHECK: (type $1 (func (param (ref null $s)) (result (ref $s)))) + + ;; CHECK: (func $br_on_non_null (type $1) (param $x (ref null $s)) (result (ref $s)) + ;; CHECK-NEXT: (local $1 (ref null $s)) + ;; CHECK-NEXT: (local $2 (ref null $s)) + ;; CHECK-NEXT: (local $3 i32) + ;; CHECK-NEXT: (local $4 i32) + ;; CHECK-NEXT: (local $5 (ref null $s)) + ;; CHECK-NEXT: (local $6 (ref (exact $s))) + ;; CHECK-NEXT: (local $7 (ref $s)) + ;; CHECK-NEXT: (block $label0 + ;; CHECK-NEXT: (local.set $1 + ;; CHECK-NEXT: (local.get $x) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $2 + ;; CHECK-NEXT: (local.get $1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $3 + ;; CHECK-NEXT: (ref.is_null + ;; CHECK-NEXT: (local.get $2) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $4 + ;; CHECK-NEXT: (i32.eqz + ;; CHECK-NEXT: (local.get $3) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (if + ;; CHECK-NEXT: (local.get $4) + ;; CHECK-NEXT: (then + ;; CHECK-NEXT: (local.set $5 + ;; CHECK-NEXT: (ref.as_non_null + ;; CHECK-NEXT: (local.get $2) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (br $label0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $6 + ;; CHECK-NEXT: (struct.new_default $s) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $5 + ;; CHECK-NEXT: (local.get $6) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $7 + ;; CHECK-NEXT: (ref.as_non_null + ;; CHECK-NEXT: (local.get $5) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (return + ;; CHECK-NEXT: (local.get $7) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (unreachable) + ;; CHECK-NEXT: ) + (func $br_on_non_null (param $x (ref null $s)) (result (ref $s)) + (return + (block $label0 (result (ref $s)) + (br_on_non_null $label0 + (local.get $x) + ) + (struct.new $s) + ) + ) + ) +) + diff --git a/test/lit/passes/flatten_br_on_null.wast b/test/lit/passes/flatten_br_on_null.wast new file mode 100644 index 00000000000..1edc253f11f --- /dev/null +++ b/test/lit/passes/flatten_br_on_null.wast @@ -0,0 +1,69 @@ +;; NOTE: Assertions have been generated by update_lit_checks.py --all-items and should not be edited. +;; NOTE: This test was ported using port_passes_tests_to_lit.py and could be cleaned up. + +;; RUN: foreach %s %t wasm-opt --flatten --all-features -S -o - | filecheck %s + +(module + ;; CHECK: (type $s (sub (struct))) + (type $s (sub (struct))) + (type $t (sub $s (struct))) + + ;; CHECK: (type $1 (func (param (ref null $s)) (result (ref $s)))) + + ;; CHECK: (func $br_on_null (type $1) (param $x (ref null $s)) (result (ref $s)) + ;; CHECK-NEXT: (local $1 (ref null $s)) + ;; CHECK-NEXT: (local $2 (ref null $s)) + ;; CHECK-NEXT: (local $3 i32) + ;; CHECK-NEXT: (local $4 (ref $s)) + ;; CHECK-NEXT: (local $5 (ref (exact $s))) + ;; CHECK-NEXT: (local $6 (ref $s)) + ;; CHECK-NEXT: (local $7 (ref $s)) + ;; CHECK-NEXT: (block + ;; CHECK-NEXT: (block $label0 + ;; CHECK-NEXT: (local.set $1 + ;; CHECK-NEXT: (local.get $x) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $2 + ;; CHECK-NEXT: (local.get $1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $3 + ;; CHECK-NEXT: (ref.is_null + ;; CHECK-NEXT: (local.get $2) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (br_if $label0 + ;; CHECK-NEXT: (local.get $3) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $4 + ;; CHECK-NEXT: (ref.as_non_null + ;; CHECK-NEXT: (local.get $2) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (return + ;; CHECK-NEXT: (local.get $4) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (unreachable) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $5 + ;; CHECK-NEXT: (struct.new_default $s) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $6 + ;; CHECK-NEXT: (local.get $5) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $7 + ;; CHECK-NEXT: (local.get $6) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (return + ;; CHECK-NEXT: (local.get $7) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $br_on_null (param $x (ref null $s)) (result (ref $s)) + (block $label0 + (return + (br_on_null $label0 (local.get $x)) + ) + ) + (struct.new $s) + ) +)