@@ -119,6 +119,33 @@ static std::optional<T> findProducerOfType(Value val) {
119119 return findProducerOfType<T>(producerOp->getOperand (0 ));
120120}
121121
122+ // / Find layout attribute in producer chain.
123+ // / Traces producer ops until a layout attribute is found. Only traces through
124+ // / ops with a single operand, in other cases the op's result layout attribute
125+ // / must be set. Returns std::nullopt if no layout attribute is found.
126+ xegpu::LayoutAttr findProducerLayout (Value val) {
127+ // Get layout attr from value or producer's attribute or operand.
128+ if (auto layoutAttr = dyn_cast_if_present<xegpu::LayoutAttr>(
129+ xegpu::getDistributeLayoutAttr (val)))
130+ return layoutAttr;
131+
132+ // Recurse up the producer chain.
133+ Operation *producerOp = val.getDefiningOp ();
134+ if (!producerOp) {
135+ LDBG () << " Failed to find producer op." ;
136+ return nullptr ;
137+ }
138+ if (producerOp->getNumOperands () == 0 ) {
139+ LDBG () << " Producer has no operands." ;
140+ return nullptr ;
141+ }
142+ if (producerOp->getNumOperands () > 1 ) {
143+ LDBG () << " Producer has multiple operands." ;
144+ return nullptr ;
145+ }
146+ return findProducerLayout (producerOp->getOperand (0 ));
147+ }
148+
122149// / Create a layout attribute from the given parameters.
123150static xegpu::LayoutAttr
124151createLayoutAttr (MLIRContext *ctx, ArrayRef<int32_t > sgLayout,
@@ -436,34 +463,33 @@ transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
436463 << llvm::range_size (targetValues) << " )" ;
437464 auto value = *targetValues.begin ();
438465
439- xegpu::LayoutAttr layoutAttr = nullptr ;
466+ xegpu::LayoutAttr targetLayoutAttr = nullptr ;
440467 auto status = getLayoutAttrFromOperands (getContext (), state, (*this ),
441468 getMixedSgLayout (), getMixedSgData (),
442- getMixedInstData (), layoutAttr );
469+ getMixedInstData (), targetLayoutAttr );
443470 if (!status.succeeded ())
444471 return status;
445472
446- // Get load op.
447- auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
448- if (!maybeLoadOp)
449- return emitSilenceableFailure (getLoc ()) << " Could not find load op." ;
450- auto loadOp = *maybeLoadOp;
451- // Get load op operand value layout
452- auto producerLayoutAttr =
453- xegpu::getDistributeLayoutAttr (loadOp.getOperand (0 ));
473+ // Find source layout attribute from the producer chain.
474+ auto producerLayoutAttr = findProducerLayout (value);
454475 if (!producerLayoutAttr)
455476 return emitSilenceableFailure (getLoc ())
456- << " Operand producer op does not have a layout attr." ;
477+ << " Could not find a layout attribute in the producer chain." ;
478+
479+ // Find first user op to define insertion point for layout conversion.
480+ if (value.use_empty ())
481+ return emitSilenceableFailure (getLoc ())
482+ << " Value has no users to insert layout conversion." ;
483+ Operation *userOp = *value.getUsers ().begin ();
457484
458- if (producerLayoutAttr != layoutAttr) {
459- rewriter.setInsertionPointAfter (loadOp.getOperation ());
460- auto source = loadOp.getResult ();
485+ if (producerLayoutAttr != targetLayoutAttr) {
486+ rewriter.setInsertionPoint (userOp);
461487 auto convLayoutOp = xegpu::ConvertLayoutOp::create (
462- rewriter, loadOp .getLoc (), source .getType (), source , producerLayoutAttr,
463- layoutAttr );
488+ rewriter, value .getLoc (), value .getType (), value , producerLayoutAttr,
489+ targetLayoutAttr );
464490 // Replace load op result with the converted layout.
465491 rewriter.replaceUsesWithIf (
466- source , convLayoutOp.getResult (), [&](OpOperand &use) {
492+ value , convLayoutOp.getResult (), [&](OpOperand &use) {
467493 return use.getOwner () != convLayoutOp.getOperation ();
468494 });
469495 }
0 commit comments