Skip to content

Commit 5705d74

Browse files
committed
address comments
1 parent 2c66eac commit 5705d74

File tree

1 file changed

+45
-84
lines changed

1 file changed

+45
-84
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp

Lines changed: 45 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,10 @@ using GetLayoutFnTy = function_ref<xegpu::LayoutAttr(Value)>;
681681
/// attribute.
682682
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
683683
GetLayoutFnTy getLayoutOfValue) {
684+
// Region ops (like scf.for) are already handled by updateControlFlowOps.
685+
if (mlir::isa<mlir::RegionBranchOpInterface>(op))
686+
return success();
687+
684688
// Iterate over all the results.
685689
for (OpResult result : op->getResults()) {
686690
Type resultType = result.getType();
@@ -709,12 +713,27 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
709713
return success();
710714
}
711715

712-
/// Update the types of successor regions of a branch terminator op (scf.yield)
713-
/// with assigned layouts.
714-
static LogicalResult updateBranchTerminatorOpInterface(
715-
mlir::OpBuilder &builder,
716-
mlir::RegionBranchTerminatorOpInterface terminator,
717-
GetLayoutFnTy getLayoutOfValue) {
716+
/// Update the types of successor regions at control-flow transfer points. If
717+
/// the control flow transfers to a new block, the block arguments are updated.
718+
/// If the control flow transfers out of the region op, the result types of the
719+
/// region op are updated.
720+
/// Example:
721+
/// clang-format off
722+
/// scf.for ... iter_args(...) -> (out types) {
723+
/// ^bb0(block types):
724+
/// ...
725+
/// scf.yield ... : (yield types)
726+
/// }
727+
/// clang-format on
728+
/// In this example, at scf.yield, control-flow can transfer to successor
729+
/// regions. One is the ^bb0 (for loop body) and the other is the scf.for op
730+
/// itself (yield the results). So we update both the block arguments of the
731+
/// successor region (i.e. block types) and the result types of the scf.for op
732+
/// (i.e. out types). Note that yield types are updated by respective producers.
733+
static LogicalResult
734+
updateControlFlowOps(mlir::OpBuilder &builder,
735+
mlir::RegionBranchTerminatorOpInterface terminator,
736+
GetLayoutFnTy getLayoutOfValue) {
718737
// Only process if the terminator is inside a region branch op.
719738
if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
720739
return success();
@@ -725,101 +744,48 @@ static LogicalResult updateBranchTerminatorOpInterface(
725744
terminator.getSuccessorRegions(operands, successors);
726745

727746
for (mlir::RegionSuccessor &successor : successors) {
728-
mlir::OperandRange forwardedOperands =
747+
mlir::OperandRange successorOperands =
729748
terminator.getSuccessorOperands(successor);
730-
mlir::ValueRange regionArgs = successor.getSuccessorInputs();
731-
for (auto [forwardedOperand, regionArg] :
732-
llvm::zip(forwardedOperands, regionArgs)) {
733-
Type inputType = regionArg.getType();
749+
mlir::ValueRange successorInputs = successor.getSuccessorInputs();
750+
for (auto [successorOperand, successorInput] :
751+
llvm::zip(successorOperands, successorInputs)) {
752+
Type inputType = successorInput.getType();
734753
// We only need to operate on tensor descriptor or vector types.
735754
if (!isa<xegpu::TensorDescType, VectorType>(inputType))
736755
continue;
737-
xegpu::LayoutAttr argLayout = getLayoutOfValue(regionArg);
738-
xegpu::LayoutAttr operandLayout = getLayoutOfValue(forwardedOperand);
756+
xegpu::LayoutAttr successorInputLayout = getLayoutOfValue(successorInput);
757+
xegpu::LayoutAttr successorOperandLayout =
758+
getLayoutOfValue(successorOperand);
739759

740760
// If no layout is assigned to the forwarded operand, we cannot proceed.
741-
if (!operandLayout) {
761+
if (!successorOperandLayout) {
742762
LLVM_DEBUG(
743763
DBGS()
744764
<< "No layout assigned for forwarded operand in branch terminator: "
745-
<< forwardedOperand << "\n");
765+
<< successorOperand << "\n");
746766
return failure();
747767
}
748768
// We expect the layouts to match.
749-
if (argLayout && argLayout != operandLayout) {
769+
if (successorInputLayout &&
770+
successorInputLayout != successorOperandLayout) {
750771
LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
751772
"operand forwarded as the argument: "
752-
<< argLayout << " vs " << operandLayout << "\n");
773+
<< successorInputLayout << " vs "
774+
<< successorOperandLayout << "\n");
753775
return failure();
754776
}
755777
// Get tensor descriptor type with the layout.
756778
if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
757779
auto newTdescTy = xegpu::TensorDescType::get(
758780
tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
759-
tdescTy.getEncoding(), operandLayout);
760-
regionArg.setType(newTdescTy);
781+
tdescTy.getEncoding(), successorOperandLayout);
782+
successorInput.setType(newTdescTy);
761783
continue;
762784
}
763785
// If the type is a vector type and this region argument is an OpResult,
764786
// set the layout attribute on the OpResult.
765-
if (auto result = dyn_cast<OpResult>(regionArg))
766-
xegpu::setLayoutAttr(result, operandLayout);
767-
}
768-
}
769-
return success();
770-
}
771-
772-
/// Some operations contain multiple regions (like scf.for) each of which have
773-
/// block arguments. This function updates the block arguments types of such
774-
/// regions with the assigned layouts. Note that results of the region op is
775-
/// updated by the branch terminator op interface.
776-
static LogicalResult
777-
updateBranchOpInterface(mlir::OpBuilder &builder,
778-
mlir::RegionBranchOpInterface branch,
779-
GetLayoutFnTy getLayoutOfValue) {
780-
mlir::Operation *op = branch.getOperation();
781-
llvm::SmallVector<mlir::RegionSuccessor> entrySuccessors;
782-
llvm::SmallVector<mlir::Attribute> operands(op->getNumOperands(), nullptr);
783-
branch.getEntrySuccessorRegions(operands, entrySuccessors);
784-
785-
for (mlir::RegionSuccessor &successor : entrySuccessors) {
786-
// Only interested in successor regions that are contained within the op.
787-
if (successor.isParent())
788-
continue;
789-
790-
mlir::OperandRange forwardedOperands =
791-
branch.getEntrySuccessorOperands(successor);
792-
mlir::ValueRange regionArgs = successor.getSuccessorInputs();
793-
794-
for (auto [forwardedOperand, regionArg] :
795-
llvm::zip(forwardedOperands, regionArgs)) {
796-
Type inputType = regionArg.getType();
797-
// Only update tensor descriptor types in region args.
798-
if (!isa<xegpu::TensorDescType>(inputType))
799-
continue;
800-
xegpu::LayoutAttr argLayout = getLayoutOfValue(regionArg);
801-
xegpu::LayoutAttr operandLayout = getLayoutOfValue(forwardedOperand);
802-
803-
if (!argLayout || !operandLayout) {
804-
LLVM_DEBUG(DBGS() << "No layout assigned for region arg: " << regionArg
805-
<< " or forwarded operand to that arg: "
806-
<< forwardedOperand << "\n");
807-
return failure();
808-
}
809-
810-
// We expect the layouts to match.
811-
if (argLayout != operandLayout) {
812-
LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
813-
"operand forwarded as the argument: "
814-
<< argLayout << " vs " << operandLayout << "\n");
815-
return failure();
816-
}
817-
// Get tensor descriptor type with the layout.
818-
auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType);
819-
auto newTdescTy = xegpu::TensorDescType::get(
820-
tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
821-
tdescTy.getEncoding(), argLayout);
822-
regionArg.setType(newTdescTy);
787+
if (auto result = dyn_cast<OpResult>(successorInput))
788+
xegpu::setLayoutAttr(result, successorOperandLayout);
823789
}
824790
}
825791
return success();
@@ -885,13 +851,8 @@ void XeGPULayoutPropagatePass::runOnOperation() {
885851
TypeSwitch<Operation *>(&op)
886852
.Case<mlir::RegionBranchTerminatorOpInterface>(
887853
[&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
888-
r = updateBranchTerminatorOpInterface(builder, branchTermOp,
889-
getXeGPULayoutForValue);
890-
})
891-
.Case<mlir::RegionBranchOpInterface>(
892-
[&](mlir::RegionBranchOpInterface regionBrOp) {
893-
r = updateBranchOpInterface(builder, regionBrOp,
894-
getXeGPULayoutForValue);
854+
r = updateControlFlowOps(builder, branchTermOp,
855+
getXeGPULayoutForValue);
895856
})
896857
.Case<mlir::FunctionOpInterface>(
897858
[&](mlir::FunctionOpInterface funcOp) {

0 commit comments

Comments
 (0)