@@ -1402,30 +1402,35 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
14021402 // Save the operand to replace / delete later (avoid iterator invalidation).
14031403 // TODO: can we use an early_inc iterator?
14041404 for (OpOperand &use : oldUse->getUses ()) {
1405+ // Propagate through `ttg.warp_specialize`.
1406+ if (auto wsOp = dyn_cast<ttg::WarpSpecializeOp>(use.getOwner ())) {
1407+ for (Region *region : wsOp.getPartitionRegions ())
1408+ region->getArgument (use.getOperandNumber ()).setType (val.getType ());
1409+ }
1410+
14051411 // Non-subview/trans ops will be replaced by `val`.
1406- if (!isa<triton::gpu::MemDescTransOp, triton::gpu::MemDescSubviewOp>(
1407- use.getOwner ())) {
1412+ if (!isa<ttg::MemDescTransOp, ttg::MemDescSubviewOp>(use.getOwner ())) {
14081413 operandsToReplace.push_back (&use);
14091414 continue ;
14101415 }
1416+
14111417 Operation *user = use.getOwner ();
14121418 // `subview(old_op)` is replaced by a new `subview(val)`.
14131419 OpBuilder::InsertionGuard g (builder);
14141420 builder.setInsertionPoint (user);
14151421 Value newVal;
1416- if (auto subview = dyn_cast<triton::gpu::MemDescSubviewOp>(user)) {
1417- triton::gpu::MemDescType oldType = subview.getType ();
1418- bool isMutable =
1419- cast<triton::gpu::MemDescType>(val.getType ()).getMutableMemory ();
1420- Type newDstType = triton::gpu::MemDescType::get (
1422+ if (auto subview = dyn_cast<ttg::MemDescSubviewOp>(user)) {
1423+ ttg::MemDescType oldType = subview.getType ();
1424+ bool isMutable = cast<ttg::MemDescType>(val.getType ()).getMutableMemory ();
1425+ Type newDstType = ttg::MemDescType::get (
14211426 oldType.getShape (), oldType.getElementType (), oldType.getEncoding (),
14221427 oldType.getMemorySpace (), isMutable);
1423- newVal = builder.create <triton::gpu ::MemDescSubviewOp>(
1428+ newVal = builder.create <ttg ::MemDescSubviewOp>(
14241429 subview.getLoc (), newDstType, val, subview.getOffsets ());
14251430 newVal.getDefiningOp ()->setAttrs (user->getAttrs ());
1426- } else if (auto trans = dyn_cast<triton::gpu ::MemDescTransOp>(user)) {
1427- newVal = builder.create <triton::gpu ::MemDescTransOp>(trans.getLoc (), val,
1428- trans.getOrder ());
1431+ } else if (auto trans = dyn_cast<ttg ::MemDescTransOp>(user)) {
1432+ newVal = builder.create <ttg ::MemDescTransOp>(trans.getLoc (), val,
1433+ trans.getOrder ());
14291434 newVal.getDefiningOp ()->setAttrs (user->getAttrs ());
14301435 }
14311436 assert (newVal);
0 commit comments