@@ -444,9 +444,8 @@ void LayoutInfoPropagation::visitStoreNdOp(
     ArrayRef<const LayoutInfoLattice *> results) {
   LayoutInfo storeLayout = getDefaultSIMTLayoutInfo(store.getValueType());
   // Both operands should have the same layout
-  for (LayoutInfoLattice *operand : operands) {
+  for (LayoutInfoLattice *operand : operands)
     propagateIfChanged(operand, operand->meet(storeLayout));
-  }
 }
 
 /// Propagate the layout of the value to the tensor descriptor operand in
@@ -659,20 +658,18 @@ RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
 
   SmallVector<FunctionOpInterface> funcOps;
   if (auto modOp = dyn_cast<ModuleOp>(target)) {
-    for (auto funcOp : modOp.getOps<FunctionOpInterface>()) {
+    for (auto funcOp : modOp.getOps<FunctionOpInterface>())
       funcOps.push_back(funcOp);
-    }
+
     // Collect all GpuFuncOps in the module.
     for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
-      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>()) {
+      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
         funcOps.push_back(gpuFuncOp);
-      }
     }
   }
   // Print the analysis result for each function.
-  for (FunctionOpInterface funcOp : funcOps) {
+  for (FunctionOpInterface funcOp : funcOps)
     printFunctionResult(funcOp);
-  }
 }
 
 using GetLayoutFnTy = function_ref<xegpu::LayoutAttr(Value)>;
@@ -706,7 +703,6 @@ static void updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
     }
     // If the result is a vector type, add a temporary layout attribute to the
     // op.
-    std::string resultLayoutName = xegpu::getLayoutName(result);
     xegpu::setLayoutAttr(result, layout);
   }
 }
@@ -717,6 +713,7 @@ static void updateBranchTerminatorOpInterface(
     mlir::OpBuilder &builder,
     mlir::RegionBranchTerminatorOpInterface terminator,
     GetLayoutFnTy getLayoutOfValue) {
+  // Only process if the terminator is inside a region branch op.
   if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
     return;
 
@@ -729,9 +726,10 @@ static void updateBranchTerminatorOpInterface(
     if (!successor.isParent())
       continue;
 
-    mlir::OperandRange operands = terminator.getSuccessorOperands(successor);
-    mlir::ValueRange inputs = successor.getSuccessorInputs();
-    for (auto [operand, input] : llvm::zip(operands, inputs)) {
+    mlir::OperandRange forwardedOperands =
+        terminator.getSuccessorOperands(successor);
+    mlir::ValueRange regionArgs = successor.getSuccessorInputs();
+    for (auto [operand, input] : llvm::zip(forwardedOperands, regionArgs)) {
       // print arg and inp
       // llvm::errs() << "arg: " << operand << ", inp: " << input << "\n";
       Type inputType = input.getType();
@@ -773,38 +771,43 @@ static void updateBranchOpInterface(mlir::OpBuilder &builder,
   llvm::SmallVector<mlir::RegionSuccessor> successors;
   llvm::SmallVector<mlir::Attribute> operands(op->getNumOperands(), nullptr);
   branch.getEntrySuccessorRegions(operands, successors);
-  DenseMap<Value, xegpu::LayoutAttr> resultToLayouts;
+  DenseMap<Value, xegpu::LayoutAttr>
+      resultToLayouts; // This map keeps track of layouts of any unused results
+                       // of the branch op.
   mlir::ValueRange results = op->getResults();
 
   for (mlir::RegionSuccessor &successor : successors) {
+    // Only interested in successor regions that are contained within the op.
     if (successor.isParent())
      continue;
 
-    mlir::OperandRange operands = branch.getEntrySuccessorOperands(successor);
-    mlir::ValueRange inputs = successor.getSuccessorInputs();
+    mlir::OperandRange forwardedOperands =
+        branch.getEntrySuccessorOperands(successor);
+    mlir::ValueRange regionArgs = successor.getSuccessorInputs();
 
-    for (auto [operand, input, result] : llvm::zip(operands, inputs, results)) {
-      Type inputType = input.getType();
+    for (auto [forwardedOperand, regionArg, result] :
+         llvm::zip(forwardedOperands, regionArgs, results)) {
+      Type inputType = regionArg.getType();
       if (!isa<xegpu::TensorDescType>(inputType))
         continue;
-      xegpu::LayoutAttr inputLayout = getLayoutOfValue(input);
-      xegpu::LayoutAttr operandLayout = getLayoutOfValue(operand);
+      xegpu::LayoutAttr inputLayout = getLayoutOfValue(regionArg);
+      xegpu::LayoutAttr operandLayout = getLayoutOfValue(forwardedOperand);
 
       if (!inputLayout || !operandLayout) {
-        LLVM_DEBUG(DBGS() << "No layout assigned for block arg: " << input
-                          << " or init arg: " << operand << "\n");
+        LLVM_DEBUG(DBGS() << "No layout assigned for block arg: " << regionArg
+                          << " or init arg: " << forwardedOperand << "\n");
         continue;
       }
 
       // TODO: We expect these two to match.
       assert(inputLayout == operandLayout &&
-             "Expexing block arg and init arg to have the same layout.");
+             "Expecting block arg and init arg to have the same layout.");
       // Get tensor descriptor type with the layout.
       auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType);
       auto newTdescTy = xegpu::TensorDescType::get(
           tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
           tdescTy.getEncoding(), inputLayout);
-      input.setType(newTdescTy);
+      regionArg.setType(newTdescTy);
       // Store the layout for the result.
       if (resultToLayouts.count(result) != 0 &&
           resultToLayouts[result] != inputLayout) {
@@ -837,7 +840,6 @@ static void updateBranchOpInterface(mlir::OpBuilder &builder,
     }
     // If the result is a vector type, add a temporary layout attribute to
     // the op.
-    std::string resultLayoutName = xegpu::getLayoutName(r);
     xegpu::setLayoutAttr(r, layout);
   }
 }
@@ -865,7 +867,6 @@ static void updateFunctionOpInterface(mlir::OpBuilder &builder,
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      arg.setType(newTdescTy);
      newArgTypes.back() = newTdescTy;
-      continue;
    }
  }
  // Update the function type with the new argument types.
@@ -887,9 +888,9 @@ void XeGPULayoutPropagatePass::runOnOperation() {
   // Helper to convert LayoutInfo to xegpu::LayoutAttr.
   auto getXeGPULayoutForValue = [&](Value val) -> xegpu::LayoutAttr {
     LayoutInfo layout = analyis.getLayoutInfo(val);
-    if (!layout.isAssigned()) {
+    if (!layout.isAssigned())
       return {};
-    }
+
     SmallVector<int, 2> laneLayout, laneData;
     for (auto [layout, data] : llvm::zip_equal(layout.getLayoutAsArrayRef(),
                                                layout.getDataAsArrayRef())) {