@@ -120,6 +120,33 @@ static std::optional<T> findProducerOfType(Value val) {
120120 return findProducerOfType<T>(producerOp->getOperand (0 ));
121121}
122122
123+ // / Find layout attribute in producer chain.
124+ // / Traces producer ops until a layout attribute is found. Only traces through
125+ // / ops with a single operand, in other cases the op's result layout attribute
126+ // / must be set. Returns std::nullopt if no layout attribute is found.
127+ xegpu::LayoutAttr findProducerLayout (Value val) {
128+ // Get layout attr from value or producer's attribute or operand.
129+ if (auto layoutAttr = dyn_cast_if_present<xegpu::LayoutAttr>(
130+ xegpu::getDistributeLayoutAttr (val)))
131+ return layoutAttr;
132+
133+ // Recurse up the producer chain.
134+ Operation *producerOp = val.getDefiningOp ();
135+ if (!producerOp) {
136+ LDBG () << " Failed to find producer op." ;
137+ return nullptr ;
138+ }
139+ if (producerOp->getNumOperands () == 0 ) {
140+ LDBG () << " Producer has no operands." ;
141+ return nullptr ;
142+ }
143+ if (producerOp->getNumOperands () > 1 ) {
144+ LDBG () << " Producer has multiple operands." ;
145+ return nullptr ;
146+ }
147+ return findProducerLayout (producerOp->getOperand (0 ));
148+ }
149+
123150// / Create a layout attribute from the given parameters.
124151static xegpu::LayoutAttr
125152createLayoutAttr (MLIRContext *ctx, ArrayRef<int32_t > sgLayout,
@@ -568,34 +595,33 @@ transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
568595 << llvm::range_size (targetValues) << " )" ;
569596 auto value = *targetValues.begin ();
570597
571- xegpu::LayoutAttr layoutAttr = nullptr ;
598+ xegpu::LayoutAttr targetLayoutAttr = nullptr ;
572599 auto status = getLayoutAttrFromOperands (getContext (), state, (*this ),
573600 getMixedSgLayout (), getMixedSgData (),
574- getMixedInstData (), layoutAttr );
601+ getMixedInstData (), targetLayoutAttr );
575602 if (!status.succeeded ())
576603 return status;
577604
578- // Get load op.
579- auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
580- if (!maybeLoadOp)
581- return emitSilenceableFailure (getLoc ()) << " Could not find load op." ;
582- auto loadOp = *maybeLoadOp;
583- // Get load op operand value layout
584- auto producerLayoutAttr =
585- xegpu::getDistributeLayoutAttr (loadOp.getOperand (0 ));
605+ // Find source layout attribute from the producer chain.
606+ auto producerLayoutAttr = findProducerLayout (value);
586607 if (!producerLayoutAttr)
587608 return emitSilenceableFailure (getLoc ())
588- << " Operand producer op does not have a layout attr." ;
609+ << " Could not find a layout attribute in the producer chain." ;
610+
611+ // Find first user op to define insertion point for layout conversion.
612+ if (value.use_empty ())
613+ return emitSilenceableFailure (getLoc ())
614+ << " Value has no users to insert layout conversion." ;
615+ Operation *userOp = *value.getUsers ().begin ();
589616
590- if (producerLayoutAttr != layoutAttr) {
591- rewriter.setInsertionPointAfter (loadOp.getOperation ());
592- auto source = loadOp.getResult ();
617+ if (producerLayoutAttr != targetLayoutAttr) {
618+ rewriter.setInsertionPoint (userOp);
593619 auto convLayoutOp = xegpu::ConvertLayoutOp::create (
594- rewriter, loadOp .getLoc (), source .getType (), source , producerLayoutAttr,
595- layoutAttr );
620+ rewriter, value .getLoc (), value .getType (), value , producerLayoutAttr,
621+ targetLayoutAttr );
596622 // Replace load op result with the converted layout.
597623 rewriter.replaceUsesWithIf (
598- source , convLayoutOp.getResult (), [&](OpOperand &use) {
624+ value , convLayoutOp.getResult (), [&](OpOperand &use) {
599625 return use.getOwner () != convLayoutOp.getOperation ();
600626 });
601627 }
0 commit comments