diff --git a/src/ir/type-updating.cpp b/src/ir/type-updating.cpp
index 4870152d74f..d402cd221cc 100644
--- a/src/ir/type-updating.cpp
+++ b/src/ir/type-updating.cpp
@@ -26,24 +26,37 @@ namespace wasm {
 
-GlobalTypeRewriter::GlobalTypeRewriter(Module& wasm) : wasm(wasm) {}
-
-void GlobalTypeRewriter::update() {
-  mapTypes(rebuildTypes(getSortedTypes(getPrivatePredecessors())));
-}
-
-GlobalTypeRewriter::PredecessorGraph
-GlobalTypeRewriter::getPrivatePredecessors() {
+GlobalTypeRewriter::GlobalTypeRewriter(Module& wasm)
+  : wasm(wasm), publicGroups(wasm.features) {
   // Find the heap types that are not publicly observable. Even in a closed
   // world scenario, don't modify public types because we assume that they may
   // be reflected on or used for linking. Figure out where each private type
   // will be located in the builder.
-  auto typeInfo = ModuleUtils::collectHeapTypeInfo(
+  typeInfo = ModuleUtils::collectHeapTypeInfo(
     wasm,
     ModuleUtils::TypeInclusion::UsedIRTypes,
     ModuleUtils::VisibilityHandling::FindVisibility);
 
-  // Check if a type is private, by looking up its info.
+  std::unordered_set<RecGroup> seenGroups;
+  for (auto& [type, info] : typeInfo) {
+    if (info.visibility == ModuleUtils::Visibility::Public) {
+      auto group = type.getRecGroup();
+      if (seenGroups.insert(type.getRecGroup()).second) {
+        std::vector<HeapType> groupTypes(group.begin(), group.end());
+        publicGroups.insert(std::move(groupTypes));
+      }
+    }
+  }
+}
+
+void GlobalTypeRewriter::update() {
+  mapTypes(rebuildTypes(getSortedTypes(getPrivatePredecessors())));
+}
+
+GlobalTypeRewriter::PredecessorGraph
+GlobalTypeRewriter::getPrivatePredecessors() {
+  // Check if a type is private, looking for its info (if there is none, it is
+  // not private).
   auto isPublic = [&](HeapType type) {
     auto it = typeInfo.find(type);
     assert(it != typeInfo.end());
@@ -185,11 +198,8 @@ GlobalTypeRewriter::rebuildTypes(std::vector<HeapType> types) {
            << " at index " << err->index;
   }
 #endif
-  auto& newTypes = *buildResults;
-
-  // TODO: It is possible that the newly built rec group matches some public rec
-  // group. If that is the case, we need to try a different permutation of the
-  // types or add a brand type to distinguish the private types.
+  // Ensure the new types are different from any public rec group.
+  const auto& newTypes = publicGroups.insert(*buildResults);
 
   // Map the old types to the new ones.
   TypeMap oldToNewTypes;
diff --git a/src/ir/type-updating.h b/src/ir/type-updating.h
index c01a97a0977..3333b288461 100644
--- a/src/ir/type-updating.h
+++ b/src/ir/type-updating.h
@@ -18,8 +18,11 @@
 #define wasm_ir_type_updating_h
 
 #include "ir/branch-utils.h"
+#include "ir/module-utils.h"
 #include "support/insert_ordered.h"
 #include "wasm-traversal.h"
+#include "wasm-type-shape.h"
+#include "wasm-type.h"
 
 namespace wasm {
 
@@ -348,6 +351,13 @@ class GlobalTypeRewriter {
 
   Module& wasm;
 
+  // The module's types and their visibilities.
+  InsertOrderedMap<HeapType, ModuleUtils::HeapTypeInfo> typeInfo;
+
+  // The shapes of public rec groups, so we can be sure that the rewritten
+  // private types do not conflict with public types.
+  UniqueRecGroups publicGroups;
+
   GlobalTypeRewriter(Module& wasm);
   virtual ~GlobalTypeRewriter() {}
 
diff --git a/src/passes/MinimizeRecGroups.cpp b/src/passes/MinimizeRecGroups.cpp
index 13f569e464a..6a3c3223034 100644
--- a/src/passes/MinimizeRecGroups.cpp
+++ b/src/passes/MinimizeRecGroups.cpp
@@ -100,84 +100,6 @@ struct TypeSCCs
   }
 };
 
-// After all their permutations with distinct shapes have been used, different
-// groups with the same shapes must be differentiated by adding in a "brand"
-// type. Even with a brand mixed in, we might run out of permutations with
-// distinct shapes, in which case we need a new brand type. This iterator
-// provides an infinite sequence of possible brand types, prioritizing those
-// with the most compact encoding.
-struct BrandTypeIterator {
-  static constexpr Index optionCount = 18;
-  static constexpr std::array<Field, optionCount> fieldOptions = {{
-    Field(Field::i8, Mutable),
-    Field(Field::i16, Mutable),
-    Field(Type::i32, Mutable),
-    Field(Type::i64, Mutable),
-    Field(Type::f32, Mutable),
-    Field(Type::f64, Mutable),
-    Field(Type(HeapType::any, Nullable), Mutable),
-    Field(Type(HeapType::func, Nullable), Mutable),
-    Field(Type(HeapType::ext, Nullable), Mutable),
-    Field(Type(HeapType::none, Nullable), Mutable),
-    Field(Type(HeapType::nofunc, Nullable), Mutable),
-    Field(Type(HeapType::noext, Nullable), Mutable),
-    Field(Type(HeapType::any, NonNullable), Mutable),
-    Field(Type(HeapType::func, NonNullable), Mutable),
-    Field(Type(HeapType::ext, NonNullable), Mutable),
-    Field(Type(HeapType::none, NonNullable), Mutable),
-    Field(Type(HeapType::nofunc, NonNullable), Mutable),
-    Field(Type(HeapType::noext, NonNullable), Mutable),
-  }};
-
-  struct FieldInfo {
-    uint8_t index = 0;
-    bool immutable = false;
-
-    operator Field() const {
-      auto field = fieldOptions[index];
-      if (immutable) {
-        field.mutable_ = Immutable;
-      }
-      return field;
-    }
-
-    bool advance() {
-      if (!immutable) {
-        immutable = true;
-        return true;
-      }
-      immutable = false;
-      index = (index + 1) % optionCount;
-      return index != 0;
-    }
-  };
-
-  bool useArray = false;
-  std::vector<FieldInfo> fields;
-
-  HeapType operator*() const {
-    if (useArray) {
-      return Array(fields[0]);
-    }
-    return Struct(std::vector<Field>(fields.begin(), fields.end()));
-  }
-
-  BrandTypeIterator& operator++() {
-    for (Index i = fields.size(); i > 0; --i) {
-      if (fields[i - 1].advance()) {
-        return *this;
-      }
-    }
-    if (useArray) {
-      useArray = false;
-      return *this;
-    }
-    fields.emplace_back();
-    useArray = fields.size() == 1;
-    return *this;
-  }
-};
-
 // Create an adjacency list with edges from supertype to subtype and from
 // described type to descriptor.
 std::vector<std::vector<Index>>
diff --git a/src/wasm-type-shape.h b/src/wasm-type-shape.h
index e72f28dd530..b649e1b72db 100644
--- a/src/wasm-type-shape.h
+++ b/src/wasm-type-shape.h
@@ -18,6 +18,8 @@
 #define wasm_wasm_type_shape_h
 
 #include <functional>
+#include <list>
+#include <unordered_set>
 #include <vector>
 
 #include "wasm-features.h"
@@ -79,4 +81,100 @@ template<> class hash<wasm::RecGroupShape> {
 
 } // namespace std
 
+namespace wasm {
+
+// Provides an infinite sequence of possible brand types, prioritizing those
+// with the most compact encoding.
+struct BrandTypeIterator {
+  static constexpr Index optionCount = 18;
+  static constexpr std::array<Field, optionCount> fieldOptions = {{
+    Field(Field::i8, Mutable),
+    Field(Field::i16, Mutable),
+    Field(Type::i32, Mutable),
+    Field(Type::i64, Mutable),
+    Field(Type::f32, Mutable),
+    Field(Type::f64, Mutable),
+    Field(Type(HeapType::any, Nullable), Mutable),
+    Field(Type(HeapType::func, Nullable), Mutable),
+    Field(Type(HeapType::ext, Nullable), Mutable),
+    Field(Type(HeapType::none, Nullable), Mutable),
+    Field(Type(HeapType::nofunc, Nullable), Mutable),
+    Field(Type(HeapType::noext, Nullable), Mutable),
+    Field(Type(HeapType::any, NonNullable), Mutable),
+    Field(Type(HeapType::func, NonNullable), Mutable),
+    Field(Type(HeapType::ext, NonNullable), Mutable),
+    Field(Type(HeapType::none, NonNullable), Mutable),
+    Field(Type(HeapType::nofunc, NonNullable), Mutable),
+    Field(Type(HeapType::noext, NonNullable), Mutable),
+  }};
+
+  struct FieldInfo {
+    uint8_t index = 0;
+    bool immutable = false;
+
+    operator Field() const {
+      auto field = fieldOptions[index];
+      if (immutable) {
+        field.mutable_ = Immutable;
+      }
+      return field;
+    }
+
+    bool advance() {
+      if (!immutable) {
+        immutable = true;
+        return true;
+      }
+      immutable = false;
+      index = (index + 1) % optionCount;
+      return index != 0;
+    }
+  };
+
+  bool useArray = false;
+  std::vector<FieldInfo> fields;
+
+  HeapType operator*() const {
+    if (useArray) {
+      return Array(fields[0]);
+    }
+    return Struct(std::vector<Field>(fields.begin(), fields.end()));
+  }
+
+  BrandTypeIterator& operator++() {
+    for (Index i = fields.size(); i > 0; --i) {
+      if (fields[i - 1].advance()) {
+        return *this;
+      }
+    }
+    if (useArray) {
+      useArray = false;
+      return *this;
+    }
+    fields.emplace_back();
+    useArray = fields.size() == 1;
+    return *this;
+  }
+};
+
+// A set of unique rec group shapes. Upon inserting a new group of types, if it
+// has the same shape as a previously inserted group, the types will be rebuilt
+// with an extra brand type at the end of the group that differentiates it from
+// previous groups.
+struct UniqueRecGroups {
+  std::list<std::vector<HeapType>> groups;
+  std::unordered_set<RecGroupShape> shapes;
+
+  FeatureSet features;
+
+  UniqueRecGroups(FeatureSet features) : features(features) {}
+
+  // Insert a rec group. If it is already unique, return the original types.
+  // Otherwise rebuild the group to make it unique and return the rebuilt types,
+  // including the brand.
+  const std::vector<HeapType>& insert(std::vector<HeapType> group);
+};
+
+} // namespace wasm
+
 #endif // wasm_wasm_type_shape_h
diff --git a/src/wasm/wasm-type-shape.cpp b/src/wasm/wasm-type-shape.cpp
index 5541d8b72a1..d2de6505d44 100644
--- a/src/wasm/wasm-type-shape.cpp
+++ b/src/wasm/wasm-type-shape.cpp
@@ -370,6 +370,41 @@ bool ComparableRecGroupShape::operator>(const RecGroupShape& other) const {
   return GT == compareComparable(*this, other);
 }
 
+const std::vector<HeapType>&
+UniqueRecGroups::insert(std::vector<HeapType> types) {
+  auto& group = *groups.emplace(groups.end(), std::move(types));
+  if (shapes.emplace(RecGroupShape(group, features)).second) {
+    // The types are already unique.
+    return group;
+  }
+  // There is a conflict. Find a brand that makes the group unique.
+  BrandTypeIterator brand;
+  group.push_back(*brand);
+  while (!shapes.emplace(RecGroupShape(group, features)).second) {
+    group.back() = *++brand;
+  }
+  // Rebuild the rec group to include the brand. Map the old types (excluding
+  // the brand) to their corresponding new types to preserve recursions within
+  // the group.
+  Index size = group.size();
+  TypeBuilder builder(size);
+  std::unordered_map<HeapType, HeapType> newTypes;
+  for (Index i = 0; i < size - 1; ++i) {
+    newTypes[group[i]] = builder[i];
+  }
+  for (Index i = 0; i < size; ++i) {
+    builder[i].copy(group[i], [&](HeapType type) {
+      if (auto newType = newTypes.find(type); newType != newTypes.end()) {
+        return newType->second;
+      }
+      return type;
+    });
+  }
+  builder.createRecGroup(0, size);
+  group = *builder.build();
+  return group;
+}
+
 } // namespace wasm
 
 namespace std {
diff --git a/test/lit/passes/signature-pruning-public-collision.wast b/test/lit/passes/signature-pruning-public-collision.wast
new file mode 100644
index 00000000000..4b23ed900db
--- /dev/null
+++ b/test/lit/passes/signature-pruning-public-collision.wast
@@ -0,0 +1,59 @@
+;; NOTE: Assertions have been generated by update_lit_checks.py --all-items and should not be edited.
+;; RUN: wasm-opt %s -all --closed-world --signature-pruning --fuzz-exec -S -o - | filecheck %s
+
+(module
+  ;; CHECK:      (type $public (func))
+
+  ;; CHECK:      (rec
+  ;; CHECK-NEXT:  (type $private (func))
+
+  ;; CHECK:       (type $2 (struct))
+
+  ;; CHECK:      (type $test (func (result i32)))
+  (type $test (func (result i32)))
+
+  (type $public (func))
+
+  ;; After signature pruning this will be (func), which is the same as $public.
+  ;; We must make sure we keep $private a distinct type.
+  (type $private (func (param i32)))
+
+  ;; CHECK:      (import "" "" (func $public (type $public)))
+  (import "" "" (func $public (type $public)))
+
+  ;; CHECK:      (elem declare func $public)
+
+  ;; CHECK:      (export "test" (func $test))
+
+  ;; CHECK:      (func $private (type $private)
+  ;; CHECK-NEXT:  (local $0 i32)
+  ;; CHECK-NEXT:  (nop)
+  ;; CHECK-NEXT: )
+  (func $private (type $private) (param $unused i32)
+    (nop)
+  )
+
+  ;; CHECK:      (func $test (type $test) (result i32)
+  ;; CHECK-NEXT:  (local $0 funcref)
+  ;; CHECK-NEXT:  (ref.test (ref $private)
+  ;; CHECK-NEXT:   (select (result funcref)
+  ;; CHECK-NEXT:    (ref.func $public)
+  ;; CHECK-NEXT:    (local.get $0)
+  ;; CHECK-NEXT:    (i32.const 1)
+  ;; CHECK-NEXT:   )
+  ;; CHECK-NEXT:  )
+  ;; CHECK-NEXT: )
+  (func $test (export "test") (type $test) (result i32)
+    (local funcref)
+    ;; Test that $private and $public are separate types. This should return 0.
+    (ref.test (ref $private)
+      ;; Use select to prevent the ref.test from being optimized in
+      ;; finalization.
+      (select (result funcref)
+        (ref.func $public)
+        (local.get 0)
+        (i32.const 1)
+      )
+    )
+  )
+)