
Commit 32f8c79

address comments
1 parent 76671e2 commit 32f8c79

2 files changed: +22 −20 lines changed


mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp

Lines changed: 17 additions & 15 deletions
@@ -19,6 +19,7 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/Value.h"
@@ -341,9 +341,6 @@ LogicalResult LayoutInfoPropagation::visitOperation(
       .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
         visitPrefetchNdOp(prefetchNdOp, operands, results);
       })
-      // No need to propagate the layout to operands in CreateNdDescOp because
-      // they are scalars (offsets, sizes, etc.).
-      .Case<xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
       .Case<vector::TransposeOp>([&](auto transposeOp) {
         visitTransposeOp(transposeOp, operands, results);
       })
@@ -355,12 +353,18 @@ LogicalResult LayoutInfoPropagation::visitOperation(
       })
       // All other ops.
       .Default([&](Operation *op) {
-        for (const LayoutInfoLattice *r : results) {
-          if (r->getValue().isAssigned()) {
-            for (LayoutInfoLattice *operand : operands) {
-              // Propagate the layout of the result to the operand.
-              meet(operand, *r);
-            }
+        for (const LayoutInfoLattice *resultInfo : results) {
+          if (!resultInfo->getValue().isAssigned())
+            continue;
+          for (auto [operandInfo, operand] :
+               llvm::zip(operands, op->getOpOperands())) {
+            // If the operand type is not a vector or tensor descriptor, skip
+            // it.
+            if (!isa<xegpu::TensorDescType, VectorType>(
+                    operand.get().getType()))
+              continue;
+            // Propagate the result layout to the operand.
+            meet(operandInfo, *resultInfo);
           }
         }
       });
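Note: the new Default case gates propagation on the operand's type, which is also why the explicit empty CreateNdDescOp case in the earlier hunk could be dropped: its scalar offset/size operands are now filtered out generically. A minimal sketch of that filter, where `carriesLayout` is an illustrative helper, not a function in this patch:

#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/BuiltinTypes.h"

// Only values whose type can actually carry a layout participate in
// propagation: XeGPU tensor descriptors and vectors. Scalars are skipped.
static bool carriesLayout(mlir::Type type) {
  return mlir::isa<mlir::xegpu::TensorDescType, mlir::VectorType>(type);
}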
@@ -456,7 +460,8 @@ void LayoutInfoPropagation::visitLoadNdOp(
     return;
   LayoutInfo tensorDescLayout = valueLayout;
   // LoadNdOp has the transpose effect. However, at the stage of this analysis
-  // this effect is not expected and should be abstracted away. Emit a warning.
+  // this effect is not expected and should be abstracted away. Emit a
+  // warning.
   if (auto transpose = load.getTranspose()) {
     load.emitWarning("Transpose effect is not expected for LoadNdOp at "
                      "LayoutInfoPropagation stage.");
@@ -495,8 +500,8 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
   int outElemTyBitWidth =
       bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
 
-  // NOTE: We do not expect widening or narrowing bitcasts at this stage. Emit a
-  // warning and return.
+  // NOTE: We do not expect widening or narrowing bitcasts at this stage. Emit
+  // a warning and return.
   if (inElemTyBitWidth != outElemTyBitWidth) {
     bitcast.emitWarning("Widening or narrowing bitcasts are not expected at "
                         "layout propagation stage.");
@@ -583,7 +588,6 @@ void LayoutInfoPropagation::visitStoreScatterOp(
 }
 
 namespace {
-
 //===----------------------------------------------------------------------===//
 // RunLayoutInfoPropagation
 //===----------------------------------------------------------------------===//
@@ -679,7 +683,6 @@ using GetLayoutFnTy = function_ref<xegpu::LayoutAttr(Value)>;
 /// attribute.
 static void updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
                      GetLayoutFnTy getLayoutOfValue) {
-
   // Iterate over all the results.
   for (OpResult result : op->getResults()) {
     Type resultType = result.getType();
@@ -872,7 +875,6 @@ static void updateFunctionOpInterface(mlir::OpBuilder &builder,
 }
 
 namespace {
-
 struct XeGPULayoutPropagatePass final
     : public xegpu::impl::XeGPULayoutPropagateBase<XeGPULayoutPropagatePass> {
   void runOnOperation() override;

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 5 additions & 5 deletions
@@ -812,7 +812,7 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {
-  // Step 1: Attach layout to op operands.
+  // Step 1: Attach layouts to op operands.
   // TODO: Following assumptions are made:
   // 1) It is assumed that there are no layout conflicts.
   // 2) Any existing layout attributes attached to the operands are ignored.
@@ -853,7 +853,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       }
     });
   }
-  // Step 3: Finally, Apply subgroup to workitem distribution patterns.
+  // Step 3: Apply subgroup to workitem distribution patterns.
   RewritePatternSet patterns(&getContext());
   xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
   // TODO: distributionFn and shuffleFn are not used at this point.
@@ -874,9 +874,9 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     return;
   }
 
-  // Step 4: Clean up UnrealizedConversionCastOps that were inserted due to
-  // tensor desc type mismatches created by using upstream distribution patterns
-  // (scf.for)
+  // Step 4: Finally, clean up UnrealizedConversionCastOps that were inserted
+  // due to tensor desc type mismatches created by using upstream distribution
+  // patterns (scf.for).
   getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
     // We are only interested in UnrealizedConversionCastOps that were added
     // for resolving SIMT type mismatches.
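For orientation, a hedged sketch of the shape of this Step 4 cleanup. The predicate `resolvesSIMTMismatch` and the helper `cleanupCasts` are hypothetical stand-ins, not the pass's actual filter, and the 1:1 input-to-result replacement is illustrative only:

#include "mlir/IR/BuiltinOps.h"

// Hypothetical predicate standing in for the pass's real test that the
// cast was inserted to resolve a SIMT tensor-desc type mismatch.
static bool resolvesSIMTMismatch(mlir::UnrealizedConversionCastOp castOp) {
  return castOp->getNumOperands() == 1 && castOp->getNumResults() == 1;
}

static void cleanupCasts(mlir::Operation *root) {
  root->walk([&](mlir::UnrealizedConversionCastOp castOp) {
    if (!resolvesSIMTMismatch(castOp))
      return;
    // Forward the cast's inputs to its users, then erase the cast. Erasing
    // the visited op is safe in MLIR's default post-order walk.
    castOp->replaceAllUsesWith(castOp.getInputs());
    castOp->erase();
  });
}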
