@@ -681,6 +681,10 @@ using GetLayoutFnTy = function_ref<xegpu::LayoutAttr(Value)>;
681681// / attribute.
682682static LogicalResult updateOp (mlir::OpBuilder &builder, mlir::Operation *op,
683683 GetLayoutFnTy getLayoutOfValue) {
684+ // Region ops (like scf.for) are already handled by the updateControlFlowOps.
685+ if (mlir::isa<mlir::RegionBranchOpInterface>(op))
686+ return success ();
687+
684688 // Iterate over all the results.
685689 for (OpResult result : op->getResults ()) {
686690 Type resultType = result.getType ();
@@ -709,12 +713,27 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
709713 return success ();
710714}
711715
712- // / Update the types of successor regions of a branch terminator op (scf.yield)
713- // / with assigned layouts.
714- static LogicalResult updateBranchTerminatorOpInterface (
715- mlir::OpBuilder &builder,
716- mlir::RegionBranchTerminatorOpInterface terminator,
717- GetLayoutFnTy getLayoutOfValue) {
716+ // / Update the types of successor regions at control-flow transfer points. If
717+ // / the control flow transfers to a new block the block arguments are updated.
718+ // / If the control flow transfers out of the region op, the result types of the
719+ // / region op are updated.
720+ // / Example:
721+ // / clang-format off
722+ // / scf.for ... iter_args(...) -> (out types) {
723+ // / ^bb0(block types):
724+ // / ...
725+ // / scf.yield ... : (yield types)
726+ // / }
727+ // / clang-format on
728+ // / In this example, at scf.yield, control-flow can transfer to successor
729+ // / regions. One is the ^bb0 (for loop body) and the other is the scf.for op
730+ // / itself (yield the results). So we update both the block arguments of the
731+ // / successor region (i.e. block types) and the result types of the scf.for op
732+ // / (i.e. out types). Note that yield types are updated by respective producers.
733+ static LogicalResult
734+ updateControlFlowOps (mlir::OpBuilder &builder,
735+ mlir::RegionBranchTerminatorOpInterface terminator,
736+ GetLayoutFnTy getLayoutOfValue) {
718737 // Only process if the terminator is inside a region branch op.
719738 if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp ()))
720739 return success ();
@@ -725,101 +744,48 @@ static LogicalResult updateBranchTerminatorOpInterface(
725744 terminator.getSuccessorRegions (operands, successors);
726745
727746 for (mlir::RegionSuccessor &successor : successors) {
728- mlir::OperandRange forwardedOperands =
747+ mlir::OperandRange successorOperands =
729748 terminator.getSuccessorOperands (successor);
730- mlir::ValueRange regionArgs = successor.getSuccessorInputs ();
731- for (auto [forwardedOperand, regionArg ] :
732- llvm::zip (forwardedOperands, regionArgs )) {
733- Type inputType = regionArg .getType ();
749+ mlir::ValueRange successorInputs = successor.getSuccessorInputs ();
750+ for (auto [successorOperand, successorInput ] :
751+ llvm::zip (successorOperands, successorInputs )) {
752+ Type inputType = successorInput .getType ();
734753 // We only need to operate on tensor descriptor or vector types.
735754 if (!isa<xegpu::TensorDescType, VectorType>(inputType))
736755 continue ;
737- xegpu::LayoutAttr argLayout = getLayoutOfValue (regionArg);
738- xegpu::LayoutAttr operandLayout = getLayoutOfValue (forwardedOperand);
756+ xegpu::LayoutAttr successorInputLayout = getLayoutOfValue (successorInput);
757+ xegpu::LayoutAttr successorOperandLayout =
758+ getLayoutOfValue (successorOperand);
739759
740760 // If either of the layouts is not assigned, we cannot proceed.
741- if (!operandLayout ) {
761+ if (!successorOperandLayout ) {
742762 LLVM_DEBUG (
743763 DBGS ()
744764 << " No layout assigned for forwarded operand in branch terminator: "
745- << forwardedOperand << " \n " );
765+ << successorOperand << " \n " );
746766 return failure ();
747767 }
748768 // We expect the layouts to match.
749- if (argLayout && argLayout != operandLayout) {
769+ if (successorInputLayout &&
770+ successorInputLayout != successorOperandLayout) {
750771 LLVM_DEBUG (DBGS () << " Conflicting layouts for region argument and "
751772 " operand forwarded as the argument: "
752- << argLayout << " vs " << operandLayout << " \n " );
773+ << successorInputLayout << " vs "
774+ << successorOperandLayout << " \n " );
753775 return failure ();
754776 }
755777 // Get tensor descriptor type with the layout.
756778 if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
757779 auto newTdescTy = xegpu::TensorDescType::get (
758780 tdescTy.getContext (), tdescTy.getShape (), tdescTy.getElementType (),
759- tdescTy.getEncoding (), operandLayout );
760- regionArg .setType (newTdescTy);
781+ tdescTy.getEncoding (), successorOperandLayout );
782+ successorInput .setType (newTdescTy);
761783 continue ;
762784 }
763785 // If the type is a vector type and this region argument is an OpResult,
764786 // set the layout attribute on the OpResult.
765- if (auto result = dyn_cast<OpResult>(regionArg))
766- xegpu::setLayoutAttr (result, operandLayout);
767- }
768- }
769- return success ();
770- }
771-
772- // / Some operations contain multiple regions (like scf.for) each of which have
773- // / block arguments. This function updates the block arguments types of such
774- // / regions with the assigned layouts. Note that results of the region op is
775- // / updated by the branch terminator op interface.
776- static LogicalResult
777- updateBranchOpInterface (mlir::OpBuilder &builder,
778- mlir::RegionBranchOpInterface branch,
779- GetLayoutFnTy getLayoutOfValue) {
780- mlir::Operation *op = branch.getOperation ();
781- llvm::SmallVector<mlir::RegionSuccessor> entrySuccessors;
782- llvm::SmallVector<mlir::Attribute> operands (op->getNumOperands (), nullptr );
783- branch.getEntrySuccessorRegions (operands, entrySuccessors);
784-
785- for (mlir::RegionSuccessor &successor : entrySuccessors) {
786- // Only interested in successor regions that are contained within the op.
787- if (successor.isParent ())
788- continue ;
789-
790- mlir::OperandRange forwardedOperands =
791- branch.getEntrySuccessorOperands (successor);
792- mlir::ValueRange regionArgs = successor.getSuccessorInputs ();
793-
794- for (auto [forwardedOperand, regionArg] :
795- llvm::zip (forwardedOperands, regionArgs)) {
796- Type inputType = regionArg.getType ();
797- // Only update tensor descriptor types in region args.
798- if (!isa<xegpu::TensorDescType>(inputType))
799- continue ;
800- xegpu::LayoutAttr argLayout = getLayoutOfValue (regionArg);
801- xegpu::LayoutAttr operandLayout = getLayoutOfValue (forwardedOperand);
802-
803- if (!argLayout || !operandLayout) {
804- LLVM_DEBUG (DBGS () << " No layout assigned for region arg: " << regionArg
805- << " or forwarded operand to that arg: "
806- << forwardedOperand << " \n " );
807- return failure ();
808- }
809-
810- // We expect the layouts to match.
811- if (argLayout != operandLayout) {
812- LLVM_DEBUG (DBGS () << " Conflicting layouts for region argument and "
813- " operand forwarded as the argument: "
814- << argLayout << " vs " << operandLayout << " \n " );
815- return failure ();
816- }
817- // Get tensor descriptor type with the layout.
818- auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType);
819- auto newTdescTy = xegpu::TensorDescType::get (
820- tdescTy.getContext (), tdescTy.getShape (), tdescTy.getElementType (),
821- tdescTy.getEncoding (), argLayout);
822- regionArg.setType (newTdescTy);
787+ if (auto result = dyn_cast<OpResult>(successorInput))
788+ xegpu::setLayoutAttr (result, successorOperandLayout);
823789 }
824790 }
825791 return success ();
@@ -885,13 +851,8 @@ void XeGPULayoutPropagatePass::runOnOperation() {
885851 TypeSwitch<Operation *>(&op)
886852 .Case <mlir::RegionBranchTerminatorOpInterface>(
887853 [&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
888- r = updateBranchTerminatorOpInterface (builder, branchTermOp,
889- getXeGPULayoutForValue);
890- })
891- .Case <mlir::RegionBranchOpInterface>(
892- [&](mlir::RegionBranchOpInterface regionBrOp) {
893- r = updateBranchOpInterface (builder, regionBrOp,
894- getXeGPULayoutForValue);
854+ r = updateControlFlowOps (builder, branchTermOp,
855+ getXeGPULayoutForValue);
895856 })
896857 .Case <mlir::FunctionOpInterface>(
897858 [&](mlir::FunctionOpInterface funcOp) {
0 commit comments