Skip to content

Commit 6a35e33

Browse files
authored
[GC] Fix parsing/printing of ref types using i31 (#3469)
This lets us parse (ref null i31) and (ref i31) and not just i31ref. It also fixes the parsing of i31ref, making it nullable for now, which we need to do until we support non-nullability. Fix some internal handling of i31 where we had just i31ref (which meant we just handled the non-nullable type). After fixing a bug in printing (where we didn't print out (ref null i31) properly), I found some a simplification, to remove TypeName.
1 parent 5693bc8 commit 6a35e33

13 files changed

+105
-74
lines changed

src/ir/module-utils.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,10 @@ inline void collectHeapTypes(Module& wasm,
405405
std::unordered_map<HeapType, Index>& typeIndices) {
406406
struct Counts : public std::unordered_map<HeapType, size_t> {
407407
bool isRelevant(Type type) {
408-
return !type.isBasic() && (type.isRef() || type.isRtt());
408+
if (type.isRef()) {
409+
return !type.getHeapType().isBasic();
410+
}
411+
return type.isRtt();
409412
}
410413
void note(HeapType type) { (*this)[type]++; }
411414
void maybeNote(Type type) {

src/passes/Print.cpp

Lines changed: 20 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,10 @@ static std::ostream& printLocal(Index index, Function* func, std::ostream& o) {
6767
return printName(name, o);
6868
}
6969

70-
// Wrapper for printing a type when we try to print the type name as much as
71-
// possible. For example, for a signature we will print the signature's name,
72-
// not its contents.
73-
struct TypeName {
74-
Type type;
75-
TypeName(Type type) : type(type) {}
76-
};
77-
7870
static void
7971
printHeapTypeName(std::ostream& os, HeapType type, bool first = true);
8072

73+
// Prints the name of a type. This output is guaranteed to not contain spaces.
8174
static void printTypeName(std::ostream& os, Type type) {
8275
if (type.isBasic()) {
8376
os << type;
@@ -131,6 +124,8 @@ static void printFieldName(std::ostream& os, const Field& field) {
131124
}
132125
}
133126

127+
// Prints the name of a heap type. As with printTypeName, this output is
128+
// guaranteed to not contain spaces.
134129
static void printHeapTypeName(std::ostream& os, HeapType type, bool first) {
135130
if (type.isBasic()) {
136131
os << type;
@@ -174,8 +169,8 @@ struct SExprType {
174169
SExprType(Type type) : type(type){};
175170
};
176171

177-
static std::ostream& operator<<(std::ostream& o, const SExprType& localType) {
178-
Type type = localType.type;
172+
static std::ostream& operator<<(std::ostream& o, const SExprType& sType) {
173+
Type type = sType.type;
179174
if (type.isTuple()) {
180175
o << '(';
181176
auto sep = "";
@@ -192,24 +187,17 @@ static std::ostream& operator<<(std::ostream& o, const SExprType& localType) {
192187
}
193188
printHeapTypeName(o, rtt.heapType);
194189
o << ')';
195-
} else {
196-
printTypeName(o, localType.type);
197-
}
198-
return o;
199-
}
200-
201-
std::ostream& operator<<(std::ostream& os, TypeName typeName) {
202-
auto type = typeName.type;
203-
if (type.isRef() && !type.isBasic()) {
204-
os << "(ref ";
190+
} else if (type.isRef() && !type.isBasic()) {
191+
o << "(ref ";
205192
if (type.isNullable()) {
206-
os << "null ";
193+
o << "null ";
207194
}
208-
printHeapTypeName(os, type.getHeapType());
209-
os << ')';
210-
return os;
195+
printHeapTypeName(o, type.getHeapType());
196+
o << ')';
197+
} else {
198+
printTypeName(o, sType.type);
211199
}
212-
return os << SExprType(typeName.type);
200+
return o;
213201
}
214202

215203
// TODO: try to simplify or even remove this, as we may be able to do the same
@@ -229,10 +217,10 @@ std::ostream& operator<<(std::ostream& os, ResultTypeName typeName) {
229217
for (auto t : type) {
230218
os << sep;
231219
sep = " ";
232-
os << TypeName(t);
220+
os << SExprType(t);
233221
}
234222
} else {
235-
os << TypeName(type);
223+
os << SExprType(type);
236224
}
237225
os << ')';
238226
return os;
@@ -2571,7 +2559,7 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> {
25712559
o << "(param ";
25722560
auto sep = "";
25732561
for (auto type : curr.params) {
2574-
o << sep << TypeName(type);
2562+
o << sep << SExprType(type);
25752563
sep = " ";
25762564
}
25772565
o << ')';
@@ -2581,7 +2569,7 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> {
25812569
o << "(result ";
25822570
auto sep = "";
25832571
for (auto type : curr.results) {
2584-
o << sep << TypeName(type);
2572+
o << sep << SExprType(type);
25852573
sep = " ";
25862574
}
25872575
o << ')';
@@ -2601,7 +2589,7 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> {
26012589
WASM_UNREACHABLE("invalid packed type");
26022590
}
26032591
} else {
2604-
o << TypeName(field.type);
2592+
o << SExprType(field.type);
26052593
}
26062594
if (field.mutable_) {
26072595
o << ')';
@@ -2736,7 +2724,7 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> {
27362724
o << '(';
27372725
printMinor(o, "param ");
27382726
printLocal(i, currFunction, o);
2739-
o << ' ' << TypeName(param) << ')';
2727+
o << ' ' << SExprType(param) << ')';
27402728
++i;
27412729
}
27422730
}
@@ -2750,7 +2738,7 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> {
27502738
o << '(';
27512739
printMinor(o, "local ");
27522740
printLocal(i, currFunction, o)
2753-
<< ' ' << TypeName(curr->getLocalType(i)) << ')';
2741+
<< ' ' << SExprType(curr->getLocalType(i)) << ')';
27542742
o << maybeNewLine;
27552743
}
27562744
// Print the body.

src/support/string.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ inline std::string trim(const std::string& input) {
115115
return input.substr(0, size);
116116
}
117117

118+
inline bool isNumber(const std::string& str) {
119+
return !str.empty() && std::all_of(str.begin(), str.end(), ::isdigit);
120+
}
121+
118122
} // namespace String
119123

120124
} // namespace wasm

src/wasm/wasm-binary.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1416,7 +1416,8 @@ Type WasmBinaryBuilder::getType(int initial) {
14161416
// FIXME: for now, force all inputs to be nullable
14171417
return Type(getHeapType(), Nullable);
14181418
case BinaryConsts::EncodedType::i31ref:
1419-
return Type::i31ref;
1419+
// FIXME: for now, force all inputs to be nullable
1420+
return Type(HeapType::BasicHeapType::i31, Nullable);
14201421
case BinaryConsts::EncodedType::rtt_n: {
14211422
auto depth = getU32LEB();
14221423
auto heapType = getHeapType();

src/wasm/wasm-s-parser.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
#include "ir/branch-utils.h"
2424
#include "shared-constants.h"
25+
#include "support/string.h"
2526
#include "wasm-binary.h"
2627
#include "wasm-builder.h"
2728

@@ -867,7 +868,8 @@ Type SExpressionWasmBuilder::stringToType(const char* str,
867868
return Type::eqref;
868869
}
869870
if (strncmp(str, "i31ref", 6) == 0 && (prefix || str[6] == 0)) {
870-
return Type::i31ref;
871+
// FIXME: for now, force all inputs to be nullable
872+
return Type(HeapType::BasicHeapType::i31, Nullable);
871873
}
872874
if (allowError) {
873875
return Type::none;
@@ -2802,12 +2804,17 @@ HeapType SExpressionWasmBuilder::parseHeapType(Element& s) {
28022804
}
28032805
return types[it->second];
28042806
} else {
2805-
// index
2806-
size_t offset = atoi(s.str().c_str());
2807-
if (offset >= types.size()) {
2808-
throw ParseException("unknown indexed function type", s.line, s.col);
2807+
// It may be a numerical index, or it may be a built-in type name like
2808+
// "i31".
2809+
auto* str = s.str().c_str();
2810+
if (String::isNumber(str)) {
2811+
size_t offset = atoi(str);
2812+
if (offset >= types.size()) {
2813+
throw ParseException("unknown indexed function type", s.line, s.col);
2814+
}
2815+
return types[offset];
28092816
}
2810-
return types[offset];
2817+
return stringToHeapType(str, /* prefix = */ false);
28112818
}
28122819
}
28132820
// It's a list.

src/wasm/wasm-type.cpp

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -481,34 +481,36 @@ Type Type::reinterpret() const {
481481
FeatureSet Type::getFeatures() const {
482482
auto getSingleFeatures = [](Type t) -> FeatureSet {
483483
if (t.isRef()) {
484-
if (t != Type::funcref && t.isFunction()) {
485-
// Strictly speaking, typed function references require the typed
486-
// function references feature, however, we use these types internally
487-
// regardless of the presence of features (in particular, since during
488-
// load of the wasm we don't know the features yet, so we apply the more
489-
// refined types).
490-
return FeatureSet::ReferenceTypes;
491-
}
484+
// A reference type implies we need that feature. Some also require more,
485+
// such as GC or exceptions.
492486
auto heapType = t.getHeapType();
493487
if (heapType.isStruct() || heapType.isArray()) {
494488
return FeatureSet::ReferenceTypes | FeatureSet::GC;
495489
}
490+
if (heapType.isBasic()) {
491+
switch (heapType.getBasic()) {
492+
case HeapType::BasicHeapType::exn:
493+
return FeatureSet::ReferenceTypes | FeatureSet::ExceptionHandling;
494+
case HeapType::BasicHeapType::any:
495+
case HeapType::BasicHeapType::eq:
496+
case HeapType::BasicHeapType::i31:
497+
return FeatureSet::ReferenceTypes | FeatureSet::GC;
498+
default: {}
499+
}
500+
}
501+
// Note: Technically typed function references also require the typed
502+
// function references feature, however, we use these types internally
503+
// regardless of the presence of features (in particular, since during
504+
// load of the wasm we don't know the features yet, so we apply the more
505+
// refined types), so we don't add that in any case here.
506+
return FeatureSet::ReferenceTypes;
496507
} else if (t.isRtt()) {
497508
return FeatureSet::ReferenceTypes | FeatureSet::GC;
498509
}
499510
TODO_SINGLE_COMPOUND(t);
500511
switch (t.getBasic()) {
501512
case Type::v128:
502513
return FeatureSet::SIMD;
503-
case Type::funcref:
504-
case Type::externref:
505-
return FeatureSet::ReferenceTypes;
506-
case Type::exnref:
507-
return FeatureSet::ReferenceTypes | FeatureSet::ExceptionHandling;
508-
case Type::anyref:
509-
case Type::eqref:
510-
case Type::i31ref:
511-
return FeatureSet::ReferenceTypes | FeatureSet::GC;
512514
default:
513515
return FeatureSet::MVP;
514516
}
@@ -594,17 +596,21 @@ bool Type::isSubType(Type left, Type right) {
594596
return true;
595597
}
596598
// Various things are subtypes of eqref.
597-
if ((left == Type::i31ref || left.getHeapType().isArray() ||
598-
left.getHeapType().isStruct()) &&
599-
right == Type::eqref) {
599+
auto leftHeap = left.getHeapType();
600+
auto rightHeap = right.getHeapType();
601+
if ((leftHeap == HeapType::i31 || leftHeap.isArray() ||
602+
leftHeap.isStruct()) &&
603+
rightHeap == HeapType::eq &&
604+
(!left.isNullable() || right.isNullable())) {
600605
return true;
601606
}
602607
// All typed function signatures are subtypes of funcref.
603-
if (left.getHeapType().isSignature() && right == Type::funcref) {
608+
if (leftHeap.isSignature() && rightHeap == HeapType::func &&
609+
(!left.isNullable() || right.isNullable())) {
604610
return true;
605611
}
606612
// A non-nullable type is a supertype of a nullable one
607-
if (left.getHeapType() == right.getHeapType() && !left.isNullable()) {
613+
if (leftHeap == rightHeap && !left.isNullable()) {
608614
// The only difference is the nullability.
609615
assert(right.isNullable());
610616
return true;

src/wasm/wasm-validator.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2191,9 +2191,11 @@ void FunctionValidator::visitI31Get(I31Get* curr) {
21912191
shouldBeTrue(getModule()->features.hasGC(),
21922192
curr,
21932193
"i31.get_s/u requires gc to be enabled");
2194+
// FIXME: use i31ref here, which is non-nullable, when we support non-
2195+
// nullability.
21942196
shouldBeSubTypeOrFirstIsUnreachable(
21952197
curr->i31->type,
2196-
Type::i31ref,
2198+
Type(HeapType::i31, Nullable),
21972199
curr->i31,
21982200
"i31.get_s/u's argument should be i31ref");
21992201
}

test/gc.wast

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,9 @@
6868
(local.set $local_i32 (i31.get_s (local.get $local_i31ref)))
6969
(local.set $local_i32 (i31.get_u (local.get $local_i31ref)))
7070
)
71+
72+
(func $test-variants
73+
(local $local_i31refnull (ref null i31))
74+
(local $local_i31refnonnull (ref i31))
75+
)
7176
)

test/gc.wast.from-wast

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
(type $none_=>_none (func))
33
(global $global_anyref (mut anyref) (ref.null any))
44
(global $global_eqref (mut eqref) (ref.null eq))
5-
(global $global_i31ref (mut i31ref) (i31.new
5+
(global $global_i31ref (mut (ref null i31)) (i31.new
66
(i32.const 0)
77
))
88
(global $global_anyref2 (mut anyref) (ref.null eq))
@@ -16,7 +16,7 @@
1616
(local $local_i32 i32)
1717
(local $local_anyref anyref)
1818
(local $local_eqref eqref)
19-
(local $local_i31ref i31ref)
19+
(local $local_i31ref (ref null i31))
2020
(local.set $local_anyref
2121
(local.get $local_anyref)
2222
)
@@ -148,4 +148,9 @@
148148
)
149149
)
150150
)
151+
(func $test-variants
152+
(local $local_i31refnull (ref null i31))
153+
(local $local_i31refnonnull (ref null i31))
154+
(nop)
155+
)
151156
)

test/gc.wast.fromBinary

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
(type $none_=>_none (func))
33
(global $global_anyref (mut anyref) (ref.null any))
44
(global $global_eqref (mut eqref) (ref.null eq))
5-
(global $global_i31ref (mut i31ref) (i31.new
5+
(global $global_i31ref (mut (ref null i31)) (i31.new
66
(i32.const 0)
77
))
88
(global $global_anyref2 (mut anyref) (ref.null eq))
@@ -16,7 +16,7 @@
1616
(local $local_i32 i32)
1717
(local $local_anyref anyref)
1818
(local $local_eqref eqref)
19-
(local $local_i31ref i31ref)
19+
(local $local_i31ref (ref null i31))
2020
(local.set $local_anyref
2121
(local.get $local_anyref)
2222
)
@@ -148,5 +148,10 @@
148148
)
149149
)
150150
)
151+
(func $test-variants
152+
(local $local_i31refnull (ref null i31))
153+
(local $local_i31refnonnull (ref null i31))
154+
(nop)
155+
)
151156
)
152157

0 commit comments

Comments
 (0)