Skip to content

Commit 981cf69

Browse files
authored
[Custom Descriptors] Fix AbstractTypeRefining for exact casts (#7768)
AbstractTypeRefining would previously optimize a cast to `(ref (exact $uninstantiated))` to `(ref (exact $instantiated))` when traps never happen and `$instantiated <: $uninstantiated`. This is not correct, however, because it transforms a cast that will never succeed to a cast that might succeed. Generalize our existing logic for fixing up descriptor casts to also fix up exact casts. Move the fixups to occur before the types are rewritten so we don't lose track of the original types of the casts.
1 parent b7b18a8 commit 981cf69

File tree

4 files changed

+904
-75
lines changed

4 files changed

+904
-75
lines changed

scripts/test/fuzzing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@
129129
'gto-desc.wast',
130130
'type-ssa-desc.wast',
131131
'abstract-type-refining-desc.wast',
132+
'abstract-type-refining-tnh-exact-casts.wast',
132133
'remove-unused-brs-desc.wast',
133134
'vacuum-desc.wast',
134135
'j2cl-merge-itables-desc.wast',

src/passes/AbstractTypeRefining.cpp

Lines changed: 110 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -240,12 +240,6 @@ struct AbstractTypeRefining : public Pass {
240240

241241
TypeMapper::TypeUpdates mapping;
242242

243-
// Track whether we optimize any described types to bottom. If we do, then
244-
// we could end up with descriptor casts to nullref, which need to be fixed
245-
// up before ReFinalize reintroduces the cast result type that was supposed
246-
// to be optimized out.
247-
bool optimizedDescribedType = false;
248-
249243
for (auto type : subTypes.types) {
250244
if (!type.isStruct()) {
251245
// TODO: support arrays and funcs
@@ -262,9 +256,6 @@ struct AbstractTypeRefining : public Pass {
262256
// We check this first as it is the most powerful change.
263257
if (createdTypesOrSubTypes.count(type) == 0) {
264258
mapping[type] = type.getBottom();
265-
if (type.getDescriptorType()) {
266-
optimizedDescribedType = true;
267-
}
268259
continue;
269260
}
270261

@@ -278,6 +269,8 @@ struct AbstractTypeRefining : public Pass {
278269
return;
279270
}
280271

272+
fixupCasts(*module, mapping);
273+
281274
// Rewriting types can usually rewrite subtype relationships. For example,
282275
// if we have this:
283276
//
@@ -303,77 +296,137 @@ struct AbstractTypeRefining : public Pass {
303296

304297
AbstractTypeRefiningTypeMapper(*module, mapping).map();
305298

306-
if (optimizedDescribedType) {
307-
// At this point we may have casts like this:
308-
//
309-
// (ref.cast_desc nullref
310-
// (some struct...)
311-
// (some desc...)
312-
// )
313-
//
314-
// ReFinalize would fix up the cast target to be the struct type described
315-
// by the descriptor, but that struct type was supposed to have been
316-
// optimized out. Optimize out the cast (which we know must either be a
317-
// null check or unconditional trap) before ReFinalize can get to it.
318-
fixupDescriptorCasts(*module);
319-
}
320-
321299
// Refinalize to propagate the type changes we made. For example, a refined
322300
// cast may lead to a struct.get reading a more refined type using that
323301
// type.
324302
ReFinalize().run(getPassRunner(), module);
325303
}
326304

327-
void fixupDescriptorCasts(Module& module) {
305+
void fixupCasts(Module& module, const TypeMapper::TypeUpdates& mapping) {
306+
if (!module.features.hasCustomDescriptors()) {
307+
// No descriptor or exact casts to fix up.
308+
return;
309+
}
310+
311+
// We may have casts like this:
312+
//
313+
// (ref.cast_desc (ref null $optimized-to-bottom)
314+
// (some struct...)
315+
// (some desc...)
316+
// )
317+
//
318+
// We will optimize the cast target to nullref, but then ReFinalize would
319+
// fix up the cast target back to $optimized-to-bottom. Optimize out the
320+
// cast (which we know must either be a null check or unconditional trap)
321+
// to avoid this reintroduction of the optimized type.
322+
//
323+
// Separately, we may have exact casts like this:
324+
//
325+
// (br_on_cast anyref $l (ref (exact $uninstantiated)) ... )
326+
//
327+
// We know such casts will fail (or will pass only for null values), but
328+
// with traps-never-happen, we might optimize them to this:
329+
//
330+
// (br_on_cast anyref $l (ref (exact $instantiated-subtype)) ... )
331+
//
332+
// This might cause the casts to incorrectly start succeeding. To avoid
333+
// that, optimize them out first.
328334
struct CastFixer : WalkerPass<PostWalker<CastFixer>> {
335+
const TypeMapper::TypeUpdates& mapping;
336+
337+
CastFixer(const TypeMapper::TypeUpdates& mapping) : mapping(mapping) {}
338+
329339
bool isFunctionParallel() override { return true; }
340+
330341
std::unique_ptr<Pass> create() override {
331-
return std::make_unique<CastFixer>();
342+
return std::make_unique<CastFixer>(mapping);
343+
}
344+
345+
Block* localizeChildren(Expression* curr) {
346+
return ChildLocalizer(
347+
curr, getFunction(), *getModule(), getPassOptions())
348+
.getChildrenReplacement();
349+
}
350+
351+
std::optional<HeapType> getOptimized(Type type) {
352+
if (!type.isRef()) {
353+
return std::nullopt;
354+
}
355+
auto heapType = type.getHeapType();
356+
auto it = mapping.find(heapType);
357+
if (it == mapping.end()) {
358+
return std::nullopt;
359+
}
360+
assert(it->second != heapType);
361+
return it->second;
332362
}
363+
333364
void visitRefCast(RefCast* curr) {
334-
if (!curr->desc || !curr->type.isNull()) {
365+
auto optimized = getOptimized(curr->type);
366+
if (!optimized) {
335367
return;
336368
}
337-
// Preserve the trap on a null descriptor.
338-
if (curr->desc->type.isNullable()) {
339-
curr->desc =
340-
Builder(*getModule()).makeRefAs(RefAsNonNull, curr->desc);
369+
// Exact casts to any optimized type and descriptor casts whose types
370+
// will be optimized to bottom either admit null or fail
371+
// unconditionally. Optimize to a cast to bottom, reusing curr and
372+
// preserving nullability. We may need to move the ref value past the
373+
// descriptor value, if any.
374+
Builder builder(*getModule());
375+
if (curr->type.isExact() || (curr->desc && optimized->isBottom())) {
376+
if (curr->desc) {
377+
if (curr->desc->type.isNullable() &&
378+
!getPassOptions().trapsNeverHappen) {
379+
curr->desc = builder.makeRefAs(RefAsNonNull, curr->desc);
380+
}
381+
Block* replacement = localizeChildren(curr);
382+
curr->desc = nullptr;
383+
curr->type = curr->type.with(optimized->getBottom());
384+
replacement->list.push_back(curr);
385+
replacement->type = curr->type;
386+
replaceCurrent(replacement);
387+
} else {
388+
curr->type = curr->type.with(optimized->getBottom());
389+
}
341390
}
342-
Block* replacement =
343-
ChildLocalizer(curr, getFunction(), *getModule(), getPassOptions())
344-
.getChildrenReplacement();
345-
// Reuse `curr` as the cast to nullref. Leave further optimization of
346-
// the cast to OptimizeInstructions.
347-
curr->desc = nullptr;
348-
replacement->list.push_back(curr);
349-
replacement->type = curr->type;
350-
replaceCurrent(replacement);
351391
}
392+
352393
void visitBrOn(BrOn* curr) {
353-
if (curr->op != BrOnCastDesc && curr->op != BrOnCastDescFail) {
394+
if (curr->op == BrOnNull || curr->op == BrOnNonNull) {
354395
return;
355396
}
356-
if (!curr->castType.isNull()) {
397+
auto optimized = getOptimized(curr->castType);
398+
if (!optimized) {
357399
return;
358400
}
401+
// Optimize the same way we optimize ref.cast*.
402+
Builder builder(*getModule());
359403
bool isFail = curr->op == BrOnCastDescFail;
360-
// Preserve the trap on a null descriptor.
361-
if (curr->desc->type.isNullable()) {
362-
curr->desc =
363-
Builder(*getModule()).makeRefAs(RefAsNonNull, curr->desc);
404+
if (curr->castType.isExact() || (curr->desc && optimized->isBottom())) {
405+
if (curr->desc) {
406+
if (curr->desc->type.isNullable() &&
407+
!getPassOptions().trapsNeverHappen) {
408+
curr->desc = builder.makeRefAs(RefAsNonNull, curr->desc);
409+
}
410+
Block* replacement = localizeChildren(curr);
411+
// Reuse `curr` as a br_on_cast to nullref. Leave further
412+
// optimization of the branch to RemoveUnusedBrs.
413+
curr->desc = nullptr;
414+
curr->castType = curr->castType.with(optimized->getBottom());
415+
if (isFail) {
416+
curr->op = BrOnCastFail;
417+
curr->type = curr->castType;
418+
} else {
419+
curr->op = BrOnCast;
420+
}
421+
replacement->list.push_back(curr);
422+
replacement->type = curr->type;
423+
replaceCurrent(replacement);
424+
} else {
425+
curr->castType = curr->castType.with(optimized->getBottom());
426+
}
364427
}
365-
Block* replacement =
366-
ChildLocalizer(curr, getFunction(), *getModule(), getPassOptions())
367-
.getChildrenReplacement();
368-
// Reuse `curr` as a br_on_cast to nullref. Leave further optimization
369-
// of the branch to RemoveUnusedBrs.
370-
curr->desc = nullptr;
371-
curr->op = isFail ? BrOnCastFail : BrOnCast;
372-
replacement->list.push_back(curr);
373-
replacement->type = curr->type;
374-
replaceCurrent(replacement);
375428
}
376-
} fixer;
429+
} fixer(mapping);
377430
fixer.run(getPassRunner(), &module);
378431
}
379432
};

test/lit/passes/abstract-type-refining-desc.wast

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -134,19 +134,17 @@
134134

135135
;; YESTNH: (func $cast-nullable-effect (type $2) (param $ref anyref) (param $desc (ref null $desc)) (result nullref)
136136
;; YESTNH-NEXT: (local $2 anyref)
137-
;; YESTNH-NEXT: (local $3 (ref $desc))
137+
;; YESTNH-NEXT: (local $3 (ref null $desc))
138138
;; YESTNH-NEXT: (local.set $2
139139
;; YESTNH-NEXT: (block (result anyref)
140140
;; YESTNH-NEXT: (call $effect)
141141
;; YESTNH-NEXT: (local.get $ref)
142142
;; YESTNH-NEXT: )
143143
;; YESTNH-NEXT: )
144144
;; YESTNH-NEXT: (local.set $3
145-
;; YESTNH-NEXT: (ref.as_non_null
146-
;; YESTNH-NEXT: (block (result (ref null $desc))
147-
;; YESTNH-NEXT: (call $effect)
148-
;; YESTNH-NEXT: (local.get $desc)
149-
;; YESTNH-NEXT: )
145+
;; YESTNH-NEXT: (block (result (ref null $desc))
146+
;; YESTNH-NEXT: (call $effect)
147+
;; YESTNH-NEXT: (local.get $desc)
150148
;; YESTNH-NEXT: )
151149
;; YESTNH-NEXT: )
152150
;; YESTNH-NEXT: (ref.cast nullref
@@ -327,7 +325,7 @@
327325

328326
;; YESTNH: (func $branch-nullable-effect (type $2) (param $ref anyref) (param $desc (ref null $desc)) (result nullref)
329327
;; YESTNH-NEXT: (local $2 anyref)
330-
;; YESTNH-NEXT: (local $3 (ref $desc))
328+
;; YESTNH-NEXT: (local $3 (ref null $desc))
331329
;; YESTNH-NEXT: (block $block (result nullref)
332330
;; YESTNH-NEXT: (drop
333331
;; YESTNH-NEXT: (block (result (ref any))
@@ -338,11 +336,9 @@
338336
;; YESTNH-NEXT: )
339337
;; YESTNH-NEXT: )
340338
;; YESTNH-NEXT: (local.set $3
341-
;; YESTNH-NEXT: (ref.as_non_null
342-
;; YESTNH-NEXT: (block (result (ref null $desc))
343-
;; YESTNH-NEXT: (call $effect)
344-
;; YESTNH-NEXT: (local.get $desc)
345-
;; YESTNH-NEXT: )
339+
;; YESTNH-NEXT: (block (result (ref null $desc))
340+
;; YESTNH-NEXT: (call $effect)
341+
;; YESTNH-NEXT: (local.get $desc)
346342
;; YESTNH-NEXT: )
347343
;; YESTNH-NEXT: )
348344
;; YESTNH-NEXT: (br_on_cast $block anyref nullref
@@ -602,7 +598,7 @@
602598

603599
;; YESTNH: (func $branch-fail-nullable-effect (type $2) (param $ref anyref) (param $desc (ref null $desc)) (result nullref)
604600
;; YESTNH-NEXT: (local $2 anyref)
605-
;; YESTNH-NEXT: (local $3 (ref $desc))
601+
;; YESTNH-NEXT: (local $3 (ref null $desc))
606602
;; YESTNH-NEXT: (drop
607603
;; YESTNH-NEXT: (block $block (result (ref any))
608604
;; YESTNH-NEXT: (return
@@ -614,11 +610,9 @@
614610
;; YESTNH-NEXT: )
615611
;; YESTNH-NEXT: )
616612
;; YESTNH-NEXT: (local.set $3
617-
;; YESTNH-NEXT: (ref.as_non_null
618-
;; YESTNH-NEXT: (block (result (ref null $desc))
619-
;; YESTNH-NEXT: (call $effect)
620-
;; YESTNH-NEXT: (local.get $desc)
621-
;; YESTNH-NEXT: )
613+
;; YESTNH-NEXT: (block (result (ref null $desc))
614+
;; YESTNH-NEXT: (call $effect)
615+
;; YESTNH-NEXT: (local.get $desc)
622616
;; YESTNH-NEXT: )
623617
;; YESTNH-NEXT: )
624618
;; YESTNH-NEXT: (br_on_cast_fail $block anyref nullref

0 commit comments

Comments
 (0)