33
33
// must fail unless it allows null.
34
34
//
35
35
36
+ #include < memory>
37
+
38
+ #include " ir/localize.h"
36
39
#include " ir/module-utils.h"
37
40
#include " ir/subtypes.h"
38
41
#include " ir/type-updating.h"
@@ -112,14 +115,8 @@ struct AbstractTypeRefining : public Pass {
112
115
// module, given closed world, but we'd also need to make sure that
113
116
// we don't need to make any changes to public types that refer to
114
117
// them.
115
- auto heapTypes = ModuleUtils::collectHeapTypeInfo (
116
- *module ,
117
- ModuleUtils::TypeInclusion::AllTypes,
118
- ModuleUtils::VisibilityHandling::FindVisibility);
119
- for (auto & [type, info] : heapTypes) {
120
- if (info.visibility == ModuleUtils::Visibility::Public) {
121
- createdTypes.insert (type);
122
- }
118
+ for (auto type : ModuleUtils::getPublicHeapTypes (*module )) {
119
+ createdTypes.insert (type);
123
120
}
124
121
125
122
SubTypes subTypes (*module );
@@ -243,6 +240,12 @@ struct AbstractTypeRefining : public Pass {
243
240
244
241
TypeMapper::TypeUpdates mapping;
245
242
243
+ // Track whether we optimize any described types to bottom. If we do, then
244
+ // we could end up with descriptor casts to nullref, which need to be fixed
245
+ // up before ReFinalize reintroduces the cast result type that was supposed
246
+ // to be optimized out.
247
+ bool optimizedDescribedType = false ;
248
+
246
249
for (auto type : subTypes.types ) {
247
250
if (!type.isStruct ()) {
248
251
// TODO: support arrays and funcs
@@ -259,6 +262,9 @@ struct AbstractTypeRefining : public Pass {
259
262
// We check this first as it is the most powerful change.
260
263
if (createdTypesOrSubTypes.count (type) == 0 ) {
261
264
mapping[type] = type.getBottom ();
265
+ if (type.getDescriptorType ()) {
266
+ optimizedDescribedType = true ;
267
+ }
262
268
continue ;
263
269
}
264
270
@@ -297,11 +303,69 @@ struct AbstractTypeRefining : public Pass {
297
303
298
304
AbstractTypeRefiningTypeMapper (*module , mapping).map ();
299
305
306
+ if (optimizedDescribedType) {
307
+ // At this point we may have casts like this:
308
+ //
309
+ // (ref.cast_desc nullref
310
+ // (some struct...)
311
+ // (some desc...)
312
+ // )
313
+ //
314
+ // ReFinalize would fix up the cast target to be the struct type described
315
+ // by the descriptor, but that struct type was supposed to have been
316
+ // optimized out. Optimize out the cast (which we know must either be a
317
+ // null check or unconditional trap) before ReFinalize can get to it.
318
+ fixupDescriptorCasts (*module );
319
+ }
320
+
300
321
// Refinalize to propagate the type changes we made. For example, a refined
301
322
// cast may lead to a struct.get reading a more refined type using that
302
323
// type.
303
324
ReFinalize ().run (getPassRunner (), module );
304
325
}
326
+
327
+ void fixupDescriptorCasts (Module& module ) {
328
+ struct CastFixer : WalkerPass<PostWalker<CastFixer>> {
329
+ bool isFunctionParallel () override { return true ; }
330
+ std::unique_ptr<Pass> create () override {
331
+ return std::make_unique<CastFixer>();
332
+ }
333
+ void visitRefCast (RefCast* curr) {
334
+ if (!curr->desc || !curr->type .isNull ()) {
335
+ return ;
336
+ }
337
+ Block* replacement =
338
+ ChildLocalizer (curr, getFunction (), *getModule (), getPassOptions ())
339
+ .getChildrenReplacement ();
340
+ // Reuse `curr` as the cast to nullref. Leave further optimization of
341
+ // the cast to OptimizeInstructions.
342
+ curr->desc = nullptr ;
343
+ replacement->list .push_back (curr);
344
+ replacement->type = curr->type ;
345
+ replaceCurrent (replacement);
346
+ }
347
+ void visitBrOn (BrOn* curr) {
348
+ if (curr->op != BrOnCastDesc && curr->op != BrOnCastDescFail) {
349
+ return ;
350
+ }
351
+ if (!curr->castType .isNull ()) {
352
+ return ;
353
+ }
354
+ bool isFail = curr->op == BrOnCastDescFail;
355
+ Block* replacement =
356
+ ChildLocalizer (curr, getFunction (), *getModule (), getPassOptions ())
357
+ .getChildrenReplacement ();
358
+ // Reuse `curr` as a br_on_cast to nullref. Leave further optimization
359
+ // of the branch to RemoveUnusedBrs.
360
+ curr->desc = nullptr ;
361
+ curr->op = isFail ? BrOnCastFail : BrOnCast;
362
+ replacement->list .push_back (curr);
363
+ replacement->type = curr->type ;
364
+ replaceCurrent (replacement);
365
+ }
366
+ } fixer;
367
+ fixer.run (getPassRunner (), &module );
368
+ }
305
369
};
306
370
307
371
} // anonymous namespace
0 commit comments