Skip to content

Commit b3a14c2

Browse files
authored
[custom-descriptors] ref.cast_desc (#7630)
Add an optional `desc` child to `RefCast` for use in the new `ref.cast_desc` instruction. Add support for parsing, printing, and validating the new instruction. Fix a few minor issues left over from adding the branching descriptor casts along the way.
1 parent bb17db3 commit b3a14c2

20 files changed

+591
-74
lines changed

scripts/gen-s-parser.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,8 @@
609609
("i31.get_s", "makeI31Get(true)"),
610610
("i31.get_u", "makeI31Get(false)"),
611611
("ref.test", "makeRefTest()"),
612-
("ref.cast", "makeRefCast()"),
612+
("ref.cast", "makeRefCast(false)"),
613+
("ref.cast_desc", "makeRefCast(true)"),
613614
("ref.get_desc", "makeRefGetDesc()"),
614615
("br_on_null", "makeBrOnNull()"),
615616
("br_on_non_null", "makeBrOnNull(true)"),

scripts/test/fuzzing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@
115115
'custom-descriptors.wast',
116116
'br_on_cast_desc.wast',
117117
'ref.get_cast.wast',
118+
'ref.cast_desc.wast',
118119
# TODO: fix split_wast() on tricky escaping situations like a string ending
119120
# in \\" (the " is not escaped - there is an escaped \ before it)
120121
'string-lifting-section.wast',

src/gen-s-parser.inc

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4807,12 +4807,23 @@ switch (buf[0]) {
48074807
return Ok{};
48084808
}
48094809
goto parse_error;
4810-
case 'c':
4811-
if (op == "ref.cast"sv) {
4812-
CHECK_ERR(makeRefCast(ctx, pos, annotations));
4813-
return Ok{};
4810+
case 'c': {
4811+
switch (buf[8]) {
4812+
case '\0':
4813+
if (op == "ref.cast"sv) {
4814+
CHECK_ERR(makeRefCast(ctx, pos, annotations, false));
4815+
return Ok{};
4816+
}
4817+
goto parse_error;
4818+
case '_':
4819+
if (op == "ref.cast_desc"sv) {
4820+
CHECK_ERR(makeRefCast(ctx, pos, annotations, true));
4821+
return Ok{};
4822+
}
4823+
goto parse_error;
4824+
default: goto parse_error;
48144825
}
4815-
goto parse_error;
4826+
}
48164827
case 'e':
48174828
if (op == "ref.eq"sv) {
48184829
CHECK_ERR(makeRefEq(ctx, pos, annotations));

src/ir/child-typer.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -852,9 +852,17 @@ template<typename Subtype> struct ChildTyper : OverriddenVisitor<Subtype> {
852852
note(&curr->ref, Type(top, Nullable));
853853
}
854854

855-
void visitRefCast(RefCast* curr) {
855+
void visitRefCast(RefCast* curr, std::optional<Type> target = std::nullopt) {
856856
auto top = curr->type.getHeapType().getTop();
857857
note(&curr->ref, Type(top, Nullable));
858+
if (curr->desc) {
859+
if (!target) {
860+
target = curr->type;
861+
}
862+
auto desc = target->getHeapType().getDescriptorType();
863+
assert(desc);
864+
note(&curr->desc, Type(*desc, Nullable, curr->type.getExactness()));
865+
}
858866
}
859867

860868
void visitRefGetDesc(RefGetDesc* curr,
@@ -881,8 +889,9 @@ template<typename Subtype> struct ChildTyper : OverriddenVisitor<Subtype> {
881889
auto top = target->getHeapType().getTop();
882890
note(&curr->ref, Type(top, Nullable));
883891
if (curr->op == BrOnCastDesc || curr->op == BrOnCastDescFail) {
884-
auto descriptor = *target->getHeapType().getDescriptorType();
885-
note(&curr->desc, Type(descriptor, Nullable));
892+
auto desc = target->getHeapType().getDescriptorType();
893+
assert(desc);
894+
note(&curr->desc, Type(*desc, Nullable));
886895
}
887896
return;
888897
}

src/parser/contexts.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -733,7 +733,7 @@ struct NullInstrParserCtx {
733733
return Ok{};
734734
}
735735
template<typename TypeT>
736-
Result<> makeRefCast(Index, const std::vector<Annotation>&, TypeT) {
736+
Result<> makeRefCast(Index, const std::vector<Annotation>&, TypeT, bool) {
737737
return Ok{};
738738
}
739739
template<typename HeapTypeT>
@@ -2592,8 +2592,9 @@ struct ParseDefsCtx : TypeParserCtx<ParseDefsCtx>, AnnotationParserCtx {
25922592

25932593
Result<> makeRefCast(Index pos,
25942594
const std::vector<Annotation>& annotations,
2595-
Type type) {
2596-
return withLoc(pos, irBuilder.makeRefCast(type));
2595+
Type type,
2596+
bool isDesc) {
2597+
return withLoc(pos, irBuilder.makeRefCast(type, isDesc));
25972598
}
25982599

25992600
Result<> makeRefGetDesc(Index pos,

src/parser/parsers.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ Result<> makeI31Get(Ctx&, Index, const std::vector<Annotation>&, bool signed_);
231231
template<typename Ctx>
232232
Result<> makeRefTest(Ctx&, Index, const std::vector<Annotation>&);
233233
template<typename Ctx>
234-
Result<> makeRefCast(Ctx&, Index, const std::vector<Annotation>&);
234+
Result<> makeRefCast(Ctx&, Index, const std::vector<Annotation>&, bool isDesc);
235235
template<typename Ctx>
236236
Result<> makeRefGetDesc(Ctx&, Index, const std::vector<Annotation>&);
237237
template<typename Ctx>
@@ -2219,11 +2219,13 @@ makeRefTest(Ctx& ctx, Index pos, const std::vector<Annotation>& annotations) {
22192219
}
22202220

22212221
template<typename Ctx>
2222-
Result<>
2223-
makeRefCast(Ctx& ctx, Index pos, const std::vector<Annotation>& annotations) {
2222+
Result<> makeRefCast(Ctx& ctx,
2223+
Index pos,
2224+
const std::vector<Annotation>& annotations,
2225+
bool isDesc) {
22242226
auto type = reftype(ctx);
22252227
CHECK_ERR(type);
2226-
return ctx.makeRefCast(pos, annotations, *type);
2228+
return ctx.makeRefCast(pos, annotations, *type, isDesc);
22272229
}
22282230

22292231
template<typename Ctx>

src/passes/Print.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2208,7 +2208,11 @@ struct PrintExpressionContents
22082208
printType(curr->castType);
22092209
}
22102210
void visitRefCast(RefCast* curr) {
2211-
printMedium(o, "ref.cast ");
2211+
if (curr->desc) {
2212+
printMedium(o, "ref.cast_desc ");
2213+
} else {
2214+
printMedium(o, "ref.cast ");
2215+
}
22122216
printType(curr->type);
22132217
}
22142218
void visitRefGetDesc(RefGetDesc* curr) {

src/wasm-binary.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,6 +1172,8 @@ enum ASTNodes {
11721172
RefTestNull = 0x15,
11731173
RefCast = 0x16,
11741174
RefCastNull = 0x17,
1175+
RefCastDesc = 0x23,
1176+
RefCastDescNull = 0x24,
11751177
BrOnCast = 0x18,
11761178
BrOnCastFail = 0x19,
11771179
BrOnCastDesc = 0x25,

src/wasm-builder.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -888,8 +888,12 @@ class Builder {
888888
return ret;
889889
}
890890
RefCast* makeRefCast(Expression* ref, Type type) {
891+
return makeRefCast(ref, nullptr, type);
892+
}
893+
RefCast* makeRefCast(Expression* ref, Expression* desc, Type type) {
891894
auto* ret = wasm.allocator.alloc<RefCast>();
892895
ret->ref = ref;
896+
ret->desc = desc;
893897
ret->type = type;
894898
ret->finalize();
895899
return ret;

src/wasm-delegations-fields.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,7 @@ DELEGATE_FIELD_CHILD(RefTest, ref)
650650
DELEGATE_FIELD_CASE_END(RefTest)
651651

652652
DELEGATE_FIELD_CASE_START(RefCast)
653+
DELEGATE_FIELD_OPTIONAL_IMMEDIATE_TYPED_CHILD(RefCast, desc)
653654
DELEGATE_FIELD_CHILD(RefCast, ref)
654655
DELEGATE_FIELD_CASE_END(RefCast)
655656

0 commit comments

Comments
 (0)