1919#include " mlir/IR/Attributes.h"
2020#include " mlir/IR/Builders.h"
2121#include " mlir/IR/BuiltinAttributes.h"
22+ #include " mlir/IR/BuiltinTypeInterfaces.h"
2223#include " mlir/IR/BuiltinTypes.h"
2324#include " mlir/IR/Operation.h"
2425#include " mlir/IR/Value.h"
@@ -341,9 +342,6 @@ LogicalResult LayoutInfoPropagation::visitOperation(
341342 .Case <xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
342343 visitPrefetchNdOp (prefetchNdOp, operands, results);
343344 })
344- // No need to propagate the layout to operands in CreateNdDescOp because
345- // they are scalars (offsets, sizes, etc.).
346- .Case <xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
347345 .Case <vector::TransposeOp>([&](auto transposeOp) {
348346 visitTransposeOp (transposeOp, operands, results);
349347 })
@@ -355,12 +353,18 @@ LogicalResult LayoutInfoPropagation::visitOperation(
355353 })
356354 // All other ops.
357355 .Default ([&](Operation *op) {
358- for (const LayoutInfoLattice *r : results) {
359- if (r->getValue ().isAssigned ()) {
360- for (LayoutInfoLattice *operand : operands) {
361- // Propagate the layout of the result to the operand.
362- meet (operand, *r);
363- }
356+ for (const LayoutInfoLattice *resultInfo : results) {
357+ if (!resultInfo->getValue ().isAssigned ())
358+ continue ;
359+ for (auto [operandInfo, operand] :
360+ llvm::zip (operands, op->getOpOperands ())) {
361+ // If the operand type is not a vector or tensor descriptor, skip
362+ // it.
363+ if (!isa<xegpu::TensorDescType, VectorType>(
364+ operand.get ().getType ()))
365+ continue ;
366+ // Propagate the result layout to the operand.
367+ meet (operandInfo, *resultInfo);
364368 }
365369 }
366370 });
@@ -456,7 +460,8 @@ void LayoutInfoPropagation::visitLoadNdOp(
456460 return ;
457461 LayoutInfo tensorDescLayout = valueLayout;
458462 // LoadNdOp has the transpose effect. However, at the stage of this analysis
459- // this effect is not expected and should be abstracted away. Emit a warning.
463+ // this effect is not expected and should be abstracted away. Emit a
464+ // warning.
460465 if (auto transpose = load.getTranspose ()) {
461466 load.emitWarning (" Transpose effect is not expected for LoadNdOp at "
462467 " LayoutInfoPropagation stage." );
@@ -495,8 +500,8 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
495500 int outElemTyBitWidth =
496501 bitcast.getResultVectorType ().getElementType ().getIntOrFloatBitWidth ();
497502
498- // NOTE: We do not expect widening or narrowing bitcasts at this stage. Emit a
499- // warning and return.
503+ // NOTE: We do not expect widening or narrowing bitcasts at this stage. Emit
504+ // a warning and return.
500505 if (inElemTyBitWidth != outElemTyBitWidth) {
501506 bitcast.emitWarning (" Widening or narrowing bitcasts are not expected at "
502507 " layout propagation stage." );
@@ -583,7 +588,6 @@ void LayoutInfoPropagation::visitStoreScatterOp(
583588}
584589
585590namespace {
586-
587591// ===----------------------------------------------------------------------===//
588592// RunLayoutInfoPropagation
589593// ===----------------------------------------------------------------------===//
@@ -679,7 +683,6 @@ using GetLayoutFnTy = function_ref<xegpu::LayoutAttr(Value)>;
679683// / attribute.
680684static void updateOp (mlir::OpBuilder &builder, mlir::Operation *op,
681685 GetLayoutFnTy getLayoutOfValue) {
682-
683686 // Iterate over all the results.
684687 for (OpResult result : op->getResults ()) {
685688 Type resultType = result.getType ();
@@ -872,7 +875,6 @@ static void updateFunctionOpInterface(mlir::OpBuilder &builder,
872875}
873876
874877namespace {
875-
876878struct XeGPULayoutPropagatePass final
877879 : public xegpu::impl::XeGPULayoutPropagateBase<XeGPULayoutPropagatePass> {
878880 void runOnOperation () override ;
0 commit comments