Commit 9cefe6f

address comments
1 parent 32f8c79 commit 9cefe6f

3 files changed: +34 −34 lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp

Lines changed: 28 additions & 27 deletions
@@ -444,9 +444,8 @@ void LayoutInfoPropagation::visitStoreNdOp(
     ArrayRef<const LayoutInfoLattice *> results) {
   LayoutInfo storeLayout = getDefaultSIMTLayoutInfo(store.getValueType());
   // Both operands should have the same layout
-  for (LayoutInfoLattice *operand : operands) {
+  for (LayoutInfoLattice *operand : operands)
     propagateIfChanged(operand, operand->meet(storeLayout));
-  }
 }
 
 /// Propagate the layout of the value to the tensor descriptor operand in
@@ -659,20 +658,18 @@ RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
 
   SmallVector<FunctionOpInterface> funcOps;
   if (auto modOp = dyn_cast<ModuleOp>(target)) {
-    for (auto funcOp : modOp.getOps<FunctionOpInterface>()) {
+    for (auto funcOp : modOp.getOps<FunctionOpInterface>())
       funcOps.push_back(funcOp);
-    }
+
     // Collect all GpuFuncOps in the module.
     for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
-      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>()) {
+      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
         funcOps.push_back(gpuFuncOp);
-      }
     }
   }
   // Print the analysis result for each function.
-  for (FunctionOpInterface funcOp : funcOps) {
+  for (FunctionOpInterface funcOp : funcOps)
     printFunctionResult(funcOp);
-  }
 }
 
 using GetLayoutFnTy = function_ref<xegpu::LayoutAttr(Value)>;
@@ -706,7 +703,6 @@ static void updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
     }
     // If the result is a vector type, add a temporary layout attribute to the
    // op.
-    std::string resultLayoutName = xegpu::getLayoutName(result);
    xegpu::setLayoutAttr(result, layout);
  }
 }
@@ -717,6 +713,7 @@ static void updateBranchTerminatorOpInterface(
     mlir::OpBuilder &builder,
     mlir::RegionBranchTerminatorOpInterface terminator,
     GetLayoutFnTy getLayoutOfValue) {
+  // Only process if the terminator is inside a region branch op.
   if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
     return;
 
@@ -729,9 +726,10 @@ static void updateBranchTerminatorOpInterface(
     if (!successor.isParent())
       continue;
 
-    mlir::OperandRange operands = terminator.getSuccessorOperands(successor);
-    mlir::ValueRange inputs = successor.getSuccessorInputs();
-    for (auto [operand, input] : llvm::zip(operands, inputs)) {
+    mlir::OperandRange forwardedOperands =
+        terminator.getSuccessorOperands(successor);
+    mlir::ValueRange regionArgs = successor.getSuccessorInputs();
+    for (auto [operand, input] : llvm::zip(forwardedOperands, regionArgs)) {
       // print arg and inp
       // llvm::errs() << "arg: " << operand << ", inp: " << input << "\n";
       Type inputType = input.getType();
@@ -773,38 +771,43 @@ static void updateBranchOpInterface(mlir::OpBuilder &builder,
   llvm::SmallVector<mlir::RegionSuccessor> successors;
   llvm::SmallVector<mlir::Attribute> operands(op->getNumOperands(), nullptr);
   branch.getEntrySuccessorRegions(operands, successors);
-  DenseMap<Value, xegpu::LayoutAttr> resultToLayouts;
+  DenseMap<Value, xegpu::LayoutAttr>
+      resultToLayouts; // This map keeps track of layouts of any unused results
+                       // of the branch op.
   mlir::ValueRange results = op->getResults();
 
   for (mlir::RegionSuccessor &successor : successors) {
+    // Only interested in successor regions that are contained within the op.
     if (successor.isParent())
       continue;
 
-    mlir::OperandRange operands = branch.getEntrySuccessorOperands(successor);
-    mlir::ValueRange inputs = successor.getSuccessorInputs();
+    mlir::OperandRange forwardedOperands =
+        branch.getEntrySuccessorOperands(successor);
+    mlir::ValueRange regionArgs = successor.getSuccessorInputs();
 
-    for (auto [operand, input, result] : llvm::zip(operands, inputs, results)) {
-      Type inputType = input.getType();
+    for (auto [forwardedOperand, regionArg, result] :
+         llvm::zip(forwardedOperands, regionArgs, results)) {
+      Type inputType = regionArg.getType();
       if (!isa<xegpu::TensorDescType>(inputType))
         continue;
-      xegpu::LayoutAttr inputLayout = getLayoutOfValue(input);
-      xegpu::LayoutAttr operandLayout = getLayoutOfValue(operand);
+      xegpu::LayoutAttr inputLayout = getLayoutOfValue(regionArg);
+      xegpu::LayoutAttr operandLayout = getLayoutOfValue(forwardedOperand);
 
       if (!inputLayout || !operandLayout) {
-        LLVM_DEBUG(DBGS() << "No layout assigned for block arg: " << input
-                          << " or init arg: " << operand << "\n");
+        LLVM_DEBUG(DBGS() << "No layout assigned for block arg: " << regionArg
+                          << " or init arg: " << forwardedOperand << "\n");
         continue;
       }
 
       // TODO: We expect these two to match.
       assert(inputLayout == operandLayout &&
-             "Expexing block arg and init arg to have the same layout.");
+             "Expecting block arg and init arg to have the same layout.");
       // Get tensor descriptor type with the layout.
       auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType);
       auto newTdescTy = xegpu::TensorDescType::get(
           tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
           tdescTy.getEncoding(), inputLayout);
-      input.setType(newTdescTy);
+      regionArg.setType(newTdescTy);
       // Store the layout for the result.
       if (resultToLayouts.count(result) != 0 &&
           resultToLayouts[result] != inputLayout) {
@@ -837,7 +840,6 @@ static void updateBranchOpInterface(mlir::OpBuilder &builder,
     }
     // If the result is a vector type, add a temporary layout attribute to
     // the op.
-    std::string resultLayoutName = xegpu::getLayoutName(r);
     xegpu::setLayoutAttr(r, layout);
   }
 }
@@ -865,7 +867,6 @@ static void updateFunctionOpInterface(mlir::OpBuilder &builder,
           tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
       arg.setType(newTdescTy);
       newArgTypes.back() = newTdescTy;
-      continue;
     }
   }
   // Update the function type with the new argument types.
@@ -887,9 +888,9 @@ void XeGPULayoutPropagatePass::runOnOperation() {
   // Helper to convert LayoutInfo to xegpu::LayoutAttr.
   auto getXeGPULayoutForValue = [&](Value val) -> xegpu::LayoutAttr {
     LayoutInfo layout = analyis.getLayoutInfo(val);
-    if (!layout.isAssigned()) {
+    if (!layout.isAssigned())
      return {};
-    }
+
     SmallVector<int, 2> laneLayout, laneData;
     for (auto [layout, data] : llvm::zip_equal(layout.getLayoutAsArrayRef(),
                                                layout.getDataAsArrayRef())) {
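
For context, the renames in updateBranchOpInterface follow the standard MLIR RegionBranchOpInterface traversal. A minimal sketch of that pattern, assuming a standalone helper (the function name and the unused-variable casts are illustrative; only interface calls that already appear in the diff are used):

#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

// Sketch only: walk the entry successor regions of a region-branch op (e.g.
// scf.for) and pair each forwarded operand (init arg) with the region
// argument it feeds; layout propagation expects the two to carry the same
// layout.
static void walkEntrySuccessors(mlir::RegionBranchOpInterface branch) {
  llvm::SmallVector<mlir::RegionSuccessor> successors;
  llvm::SmallVector<mlir::Attribute> operands(branch->getNumOperands(), nullptr);
  branch.getEntrySuccessorRegions(operands, successors);
  for (mlir::RegionSuccessor &successor : successors) {
    // Skip the parent successor; only regions nested inside the op matter here.
    if (successor.isParent())
      continue;
    mlir::OperandRange forwardedOperands =
        branch.getEntrySuccessorOperands(successor);
    mlir::ValueRange regionArgs = successor.getSuccessorInputs();
    for (auto [forwardedOperand, regionArg] :
         llvm::zip(forwardedOperands, regionArgs)) {
      // forwardedOperand: the value passed into the region (e.g. an init arg).
      // regionArg: the block argument inside the region that receives it.
      (void)forwardedOperand;
      (void)regionArg;
    }
  }
}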

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 3 additions & 4 deletions
@@ -97,9 +97,9 @@ getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout,
   // dimensions are not distributed.
   unsigned distributionStart = originalType.getRank() - laneLayout.size();
   for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
-    if (i < distributionStart) {
+    if (i < distributionStart)
       continue;
-    }
+
     // Check if the dimension can be distributed evenly.
     if (dim % laneLayout[i - distributionStart] != 0)
       return failure();
@@ -848,9 +848,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     // GPU index ops, scalar constants, etc.). This will simplify the
     // later lowering and avoid custom patterns for these ops.
     getOperation()->walk([&](Operation *op) {
-      if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
+      if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op))
        vector::moveScalarUniformCode(warpOp);
-      }
     });
   }
   // Step 3: Apply subgroup to workitem distribution patterns.

mlir/test/Dialect/XeGPU/subgroup-distribute.mlir

Lines changed: 3 additions & 3 deletions
@@ -166,8 +166,8 @@ gpu.module @test {
 }
 
 // -----
-// TODO: gemm does not use update_nd_offset because of an issue in vector distribution. PR141853 tracks this issue.
-// CHECK-LABEL: gpu.func @gemm_loop
+// TODO: gemm does not use update_nd_offset because of an issue in scf-for distribution.
+// CHECK-LABEL: gpu.func @gemm
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<1024x1024xf32>) {
 // CHECK-DAG: %[[BLOCK_ID_X:.*]] = gpu.block_id x
 // CHECK-DAG: %[[BLOCK_ID_Y:.*]] = gpu.block_id y
@@ -189,7 +189,7 @@ gpu.module @test {
 // CHECK-NEXT: %[[T9:.*]] = vector.shape_cast %[[T5]] : vector<8x1xf32> to vector<8xf32>
 // CHECK-NEXT: xegpu.store_nd %[[T9]], %[[T2]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
 gpu.module @test {
-  gpu.func @gemm_loop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>){
+  gpu.func @gemm(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>){
    %c0 = arith.constant 0 : index
    %c16 = arith.constant 16 : index
    %c8 = arith.constant 8 : index
