@@ -52,19 +52,28 @@ OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
5252 // %val = ptr.from_ptr %ptr (metadata %mda)? : ptr -> type
5353 // To:
5454 // %val -> %v
55- auto toPtr = dyn_cast_or_null<ToPtrOp>(getPtr ().getDefiningOp ());
56- // Cannot fold if it's not a `to_ptr` op or the initial and final types are
57- // different.
58- if (!toPtr || toPtr.getPtr ().getType () != getType ())
59- return nullptr ;
60- Value md = getMetadata ();
61- if (!md)
62- return toPtr.getPtr ();
63- // Fold if the metadata can be verified to be equal.
64- if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp ());
65- mdOp && mdOp.getPtr () == toPtr.getPtr ())
66- return toPtr.getPtr ();
67- return nullptr ;
55+ Value ptrLike;
56+ FromPtrOp fromPtr = *this ;
57+ while (fromPtr != nullptr ) {
58+ auto toPtr = dyn_cast_or_null<ToPtrOp>(fromPtr.getPtr ().getDefiningOp ());
59+ // Cannot fold if it's not a `to_ptr` op or the initial and final types are
60+ // different.
61+ if (!toPtr || toPtr.getPtr ().getType () != fromPtr.getType ())
62+ return ptrLike;
63+ Value md = fromPtr.getMetadata ();
64+ // If there's no metadata in the op, either the cast never requires metadata
65+ // or the op has the trivial metadata flag set, therefore fold.
66+ if (!md)
67+ ptrLike = toPtr.getPtr ();
68+ // Fold if the metadata can be verified to be equal.
69+ else if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp ());
70+ mdOp && mdOp.getPtr () == toPtr.getPtr ())
71+ ptrLike = toPtr.getPtr ();
72+ // Check for a sequence of casts.
73+ fromPtr = dyn_cast_or_null<FromPtrOp>(ptrLike ? ptrLike.getDefiningOp ()
74+ : nullptr );
75+ }
76+ return ptrLike;
6877}
6978
7079LogicalResult FromPtrOp::verify () {
@@ -113,11 +122,18 @@ OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) {
113122 // %ptr = ptr.to_ptr %val : type -> ptr
114123 // To:
115124 // %ptr -> %p
116- auto fromPtr = dyn_cast_or_null<FromPtrOp>(getPtr ().getDefiningOp ());
117- // Cannot fold if it's not a `from_ptr` op.
118- if (!fromPtr)
119- return nullptr ;
120- return fromPtr.getPtr ();
125+ Value ptr;
126+ ToPtrOp toPtr = *this ;
127+ while (toPtr != nullptr ) {
128+ auto fromPtr = dyn_cast_or_null<FromPtrOp>(toPtr.getPtr ().getDefiningOp ());
129+ // Cannot fold if it's not a `from_ptr` op.
130+ if (!fromPtr)
131+ return ptr;
132+ ptr = fromPtr.getPtr ();
133+ // Check for chains of casts.
134+ toPtr = dyn_cast_or_null<ToPtrOp>(ptr.getDefiningOp ());
135+ }
136+ return ptr;
121137}
122138
123139LogicalResult ToPtrOp::verify () {
0 commit comments