@@ -240,12 +240,6 @@ struct AbstractTypeRefining : public Pass {
     TypeMapper::TypeUpdates mapping;
-    // Track whether we optimize any described types to bottom. If we do, then
-    // we could end up with descriptor casts to nullref, which need to be fixed
-    // up before ReFinalize reintroduces the cast result type that was supposed
-    // to be optimized out.
-    bool optimizedDescribedType = false;
-
     for (auto type : subTypes.types) {
       if (!type.isStruct()) {
         // TODO: support arrays and funcs
@@ -262,9 +256,6 @@ struct AbstractTypeRefining : public Pass {
       // We check this first as it is the most powerful change.
       if (createdTypesOrSubTypes.count(type) == 0) {
         mapping[type] = type.getBottom();
-        if (type.getDescriptorType()) {
-          optimizedDescribedType = true;
-        }
         continue;
       }
@@ -278,6 +269,8 @@ struct AbstractTypeRefining : public Pass {
       return;
     }
+    fixupCasts(*module, mapping);
+
     // Rewriting types can usually rewrite subtype relationships. For example,
     // if we have this:
     //
@@ -303,77 +296,137 @@ struct AbstractTypeRefining : public Pass {
     AbstractTypeRefiningTypeMapper(*module, mapping).map();
-    if (optimizedDescribedType) {
-      // At this point we may have casts like this:
-      //
-      //   (ref.cast_desc nullref
-      //     (some struct...)
-      //     (some desc...)
-      //   )
-      //
-      // ReFinalize would fix up the cast target to be the struct type described
-      // by the descriptor, but that struct type was supposed to have been
-      // optimized out. Optimize out the cast (which we know must either be a
-      // null check or unconditional trap) before ReFinalize can get to it.
-      fixupDescriptorCasts(*module);
-    }
-
     // Refinalize to propagate the type changes we made. For example, a refined
     // cast may lead to a struct.get reading a more refined type using that
    // type.
     ReFinalize().run(getPassRunner(), module);
   }
-  void fixupDescriptorCasts(Module& module) {
+  void fixupCasts(Module& module, const TypeMapper::TypeUpdates& mapping) {
+    if (!module.features.hasCustomDescriptors()) {
+      // No descriptor or exact casts to fix up.
+      return;
+    }
+
+    // We may have casts like this:
+    //
+    //   (ref.cast_desc (ref null $optimized-to-bottom)
+    //     (some struct...)
+    //     (some desc...)
+    //   )
+    //
+    // We will optimize the cast target to nullref, but then ReFinalize would
+    // fix up the cast target back to $optimized-to-bottom. Optimize out the
+    // cast (which we know must either be a null check or unconditional trap)
+    // to avoid this reintroduction of the optimized type.
+    //
+    // Separately, we may have exact casts like this:
+    //
+    //   (br_on_cast anyref $l (ref (exact $uninstantiated)) ...)
+    //
+    // We know such casts will fail (or will pass only for null values), but
+    // with traps-never-happen, we might optimize them to this:
+    //
+    //   (br_on_cast anyref $l (ref (exact $instantiated-subtype)) ...)
+    //
+    // This might cause the casts to incorrectly start succeeding. To avoid
+    // that, optimize them out first.
     struct CastFixer : WalkerPass<PostWalker<CastFixer>> {
+      const TypeMapper::TypeUpdates& mapping;
+
+      CastFixer(const TypeMapper::TypeUpdates& mapping) : mapping(mapping) {}
+
       bool isFunctionParallel() override { return true; }
+
       std::unique_ptr<Pass> create() override {
-        return std::make_unique<CastFixer>();
+        return std::make_unique<CastFixer>(mapping);
+      }
+
+      Block* localizeChildren(Expression* curr) {
+        return ChildLocalizer(
+                 curr, getFunction(), *getModule(), getPassOptions())
+          .getChildrenReplacement();
+      }
+
+      std::optional<HeapType> getOptimized(Type type) {
+        if (!type.isRef()) {
+          return std::nullopt;
+        }
+        auto heapType = type.getHeapType();
+        auto it = mapping.find(heapType);
+        if (it == mapping.end()) {
+          return std::nullopt;
+        }
+        assert(it->second != heapType);
+        return it->second;
       }
+
       void visitRefCast(RefCast* curr) {
-        if (!curr->desc || !curr->type.isNull()) {
+        auto optimized = getOptimized(curr->type);
+        if (!optimized) {
           return;
         }
-        // Preserve the trap on a null descriptor.
-        if (curr->desc->type.isNullable()) {
-          curr->desc =
-            Builder(*getModule()).makeRefAs(RefAsNonNull, curr->desc);
+        // Exact casts to any optimized type and descriptor casts whose types
+        // will be optimized to bottom either admit null or fail
+        // unconditionally. Optimize to a cast to bottom, reusing curr and
+        // preserving nullability. We may need to move the ref value past the
+        // descriptor value, if any.
+        Builder builder(*getModule());
+        if (curr->type.isExact() || (curr->desc && optimized->isBottom())) {
+          if (curr->desc) {
+            if (curr->desc->type.isNullable() &&
+                !getPassOptions().trapsNeverHappen) {
+              curr->desc = builder.makeRefAs(RefAsNonNull, curr->desc);
+            }
+            Block* replacement = localizeChildren(curr);
+            curr->desc = nullptr;
+            curr->type = curr->type.with(optimized->getBottom());
+            replacement->list.push_back(curr);
+            replacement->type = curr->type;
+            replaceCurrent(replacement);
+          } else {
+            curr->type = curr->type.with(optimized->getBottom());
+          }
         }
-        Block* replacement =
-          ChildLocalizer(curr, getFunction(), *getModule(), getPassOptions())
-            .getChildrenReplacement();
-        // Reuse `curr` as the cast to nullref. Leave further optimization of
-        // the cast to OptimizeInstructions.
-        curr->desc = nullptr;
-        replacement->list.push_back(curr);
-        replacement->type = curr->type;
-        replaceCurrent(replacement);
       }
+
       void visitBrOn(BrOn* curr) {
-        if (curr->op != BrOnCastDesc && curr->op != BrOnCastDescFail) {
+        if (curr->op == BrOnNull || curr->op == BrOnNonNull) {
           return;
         }
-        if (!curr->castType.isNull()) {
+        auto optimized = getOptimized(curr->castType);
+        if (!optimized) {
           return;
         }
+        // Optimize the same way we optimize ref.cast*.
+        Builder builder(*getModule());
         bool isFail = curr->op == BrOnCastDescFail;
-        // Preserve the trap on a null descriptor.
-        if (curr->desc->type.isNullable()) {
-          curr->desc =
-            Builder(*getModule()).makeRefAs(RefAsNonNull, curr->desc);
+        if (curr->castType.isExact() || (curr->desc && optimized->isBottom())) {
+          if (curr->desc) {
+            if (curr->desc->type.isNullable() &&
+                !getPassOptions().trapsNeverHappen) {
+              curr->desc = builder.makeRefAs(RefAsNonNull, curr->desc);
+            }
+            Block* replacement = localizeChildren(curr);
+            // Reuse `curr` as a br_on_cast to nullref. Leave further
+            // optimization of the branch to RemoveUnusedBrs.
+            curr->desc = nullptr;
+            curr->castType = curr->castType.with(optimized->getBottom());
+            if (isFail) {
+              curr->op = BrOnCastFail;
+              curr->type = curr->castType;
+            } else {
+              curr->op = BrOnCast;
+            }
+            replacement->list.push_back(curr);
+            replacement->type = curr->type;
+            replaceCurrent(replacement);
+          } else {
+            curr->castType = curr->castType.with(optimized->getBottom());
+          }
         }
-        Block* replacement =
-          ChildLocalizer(curr, getFunction(), *getModule(), getPassOptions())
-            .getChildrenReplacement();
-        // Reuse `curr` as a br_on_cast to nullref. Leave further optimization
-        // of the branch to RemoveUnusedBrs.
-        curr->desc = nullptr;
-        curr->op = isFail ? BrOnCastFail : BrOnCast;
-        replacement->list.push_back(curr);
-        replacement->type = curr->type;
-        replaceCurrent(replacement);
       }
-    } fixer;
+    } fixer(mapping);
     fixer.run(getPassRunner(), &module);
   }
 };
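
A rough illustration of the new fixup (not taken from the PR's tests; the type and
local names below are made up): suppose $A has a descriptor but is never created,
so the mapping sends $A to its bottom type. A descriptor cast such as

  (ref.cast_desc (ref null $A)
    (local.get $ref)
    (local.get $desc)
  )

is rewritten by fixupCasts into roughly

  (block (result (ref null none))
    (local.set $t0 (local.get $ref))                     ;; children moved out by ChildLocalizer
    (local.set $t1 (ref.as_non_null (local.get $desc)))  ;; keep the trap on a null descriptor
    (ref.cast (ref null none)                            ;; curr reused as a plain cast to bottom
      (local.get $t0)))

so that ReFinalize cannot reintroduce (ref null $A) as the cast's result type. Under
traps-never-happen the ref.as_non_null is skipped, and br_on_cast_desc is handled the
same way, becoming a br_on_cast to the bottom type whose cleanup is left to
RemoveUnusedBrs.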