@@ -120,33 +120,6 @@ static std::optional<T> findProducerOfType(Value val) {
120120 return findProducerOfType<T>(producerOp->getOperand (0 ));
121121}
122122
123- // / Find layout attribute in producer chain.
124- // / Traces producer ops until a layout attribute is found. Only traces through
125- // / ops with a single operand, in other cases the op's result layout attribute
126- // / must be set. Returns std::nullopt if no layout attribute is found.
127- xegpu::LayoutAttr findProducerLayout (Value val) {
128- // Get layout attr from value or producer's attribute or operand.
129- if (auto layoutAttr = dyn_cast_if_present<xegpu::LayoutAttr>(
130- xegpu::getDistributeLayoutAttr (val)))
131- return layoutAttr;
132-
133- // Recurse up the producer chain.
134- Operation *producerOp = val.getDefiningOp ();
135- if (!producerOp) {
136- LDBG () << " Failed to find producer op." ;
137- return nullptr ;
138- }
139- if (producerOp->getNumOperands () == 0 ) {
140- LDBG () << " Producer has no operands." ;
141- return nullptr ;
142- }
143- if (producerOp->getNumOperands () > 1 ) {
144- LDBG () << " Producer has multiple operands." ;
145- return nullptr ;
146- }
147- return findProducerLayout (producerOp->getOperand (0 ));
148- }
149-
150123// / Create a layout attribute from the given parameters.
151124static xegpu::LayoutAttr
152125createLayoutAttr (MLIRContext *ctx, ArrayRef<int32_t > sgLayout,
@@ -564,24 +537,48 @@ void transform::InsertPrefetchOp::getEffects(
564537 modifiesPayload (effects);
565538}
566539
567- void transform::ConvertLayoutOp::build (OpBuilder &builder,
568- OperationState &ostate, Value target,
569- ArrayRef<OpFoldResult> mixedSgLayout,
570- ArrayRef<OpFoldResult> mixedSgData,
571- ArrayRef<OpFoldResult> mixedInstData) {
572- SmallVector<int64_t > staticSgLayout, staticSgData, staticInstData;
573- SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
574- dispatchIndexOpFoldResults (mixedSgLayout, dynamicSgLayout, staticSgLayout);
575- dispatchIndexOpFoldResults (mixedSgData, dynamicSgData, staticSgData);
576- dispatchIndexOpFoldResults (mixedInstData, dynamicInstData, staticInstData);
540+ void transform::ConvertLayoutOp::build (
541+ OpBuilder &builder, OperationState &ostate, Value target,
542+ ArrayRef<OpFoldResult> mixedInputSgLayout,
543+ ArrayRef<OpFoldResult> mixedInputSgData,
544+ ArrayRef<OpFoldResult> mixedInputInstData,
545+ ArrayRef<OpFoldResult> mixedTargetSgLayout,
546+ ArrayRef<OpFoldResult> mixedTargetSgData,
547+ ArrayRef<OpFoldResult> mixedTargetInstData) {
548+ SmallVector<int64_t > staticInputSgLayout, staticInputSgData,
549+ staticInputInstData;
550+ SmallVector<Value> dynamicInputSgLayout, dynamicInputSgData,
551+ dynamicInputInstData;
552+ dispatchIndexOpFoldResults (mixedInputSgLayout, dynamicInputSgLayout,
553+ staticInputSgLayout);
554+ dispatchIndexOpFoldResults (mixedInputSgData, dynamicInputSgData,
555+ staticInputSgData);
556+ dispatchIndexOpFoldResults (mixedInputInstData, dynamicInputInstData,
557+ staticInputInstData);
558+ SmallVector<int64_t > staticTargetSgLayout, staticTargetSgData,
559+ staticTargetInstData;
560+ SmallVector<Value> dynamicTargetSgLayout, dynamicTargetSgData,
561+ dynamicTargetInstData;
562+ dispatchIndexOpFoldResults (mixedTargetSgLayout, dynamicTargetSgLayout,
563+ staticTargetSgLayout);
564+ dispatchIndexOpFoldResults (mixedTargetSgData, dynamicTargetSgData,
565+ staticTargetSgData);
566+ dispatchIndexOpFoldResults (mixedTargetInstData, dynamicTargetInstData,
567+ staticTargetInstData);
577568 build (builder, ostate, target.getType (),
578569 /* target=*/ target,
579- /* sg_layout=*/ dynamicSgLayout,
580- /* sg_data=*/ dynamicSgData,
581- /* inst_data=*/ dynamicInstData,
582- /* static_sg_layout=*/ staticSgLayout,
583- /* static_sg_data=*/ staticSgData,
584- /* static_inst_data=*/ staticInstData);
570+ /* input_sg_layout=*/ dynamicInputSgLayout,
571+ /* input_sg_data=*/ dynamicInputSgData,
572+ /* input_inst_data=*/ dynamicInputInstData,
573+ /* target_sg_layout=*/ dynamicTargetSgLayout,
574+ /* target_sg_data=*/ dynamicTargetSgData,
575+ /* target_inst_data=*/ dynamicTargetInstData,
576+ /* static_input_sg_layout=*/ staticInputSgLayout,
577+ /* static_input_sg_data=*/ staticInputSgData,
578+ /* static_input_inst_data=*/ staticInputInstData,
579+ /* static_target_sg_layout=*/ staticTargetSgLayout,
580+ /* static_target_sg_data=*/ staticTargetSgData,
581+ /* static_target_inst_data=*/ staticTargetInstData);
585582}
586583
587584DiagnosedSilenceableFailure
@@ -595,18 +592,20 @@ transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
595592 << llvm::range_size (targetValues) << " )" ;
596593 auto value = *targetValues.begin ();
597594
598- xegpu::LayoutAttr targetLayoutAttr = nullptr ;
599- auto status = getLayoutAttrFromOperands (getContext (), state, (*this ),
600- getMixedSgLayout (), getMixedSgData (),
601- getMixedInstData (), targetLayoutAttr);
595+ // Construct layout attributes.
596+ xegpu::LayoutAttr inputLayoutAttr = nullptr ;
597+ auto status = getLayoutAttrFromOperands (
598+ getContext (), state, (*this ), getMixedInputSgLayout (),
599+ getMixedInputSgData (), getMixedInputInstData (), inputLayoutAttr);
602600 if (!status.succeeded ())
603601 return status;
604602
605- // Find source layout attribute from the producer chain.
606- auto producerLayoutAttr = findProducerLayout (value);
607- if (!producerLayoutAttr)
608- return emitSilenceableFailure (getLoc ())
609- << " Could not find a layout attribute in the producer chain." ;
603+ xegpu::LayoutAttr targetLayoutAttr = nullptr ;
604+ status = getLayoutAttrFromOperands (
605+ getContext (), state, (*this ), getMixedTargetSgLayout (),
606+ getMixedTargetSgData (), getMixedTargetInstData (), targetLayoutAttr);
607+ if (!status.succeeded ())
608+ return status;
610609
611610 // Find first user op to define insertion point for layout conversion.
612611 if (value.use_empty ())
@@ -616,9 +615,9 @@ transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
616615
617616 // Emit convert_layout op.
618617 rewriter.setInsertionPoint (userOp);
619- auto convLayoutOp = xegpu::ConvertLayoutOp::create (
620- rewriter, value.getLoc (), value.getType (), value, producerLayoutAttr ,
621- targetLayoutAttr);
618+ auto convLayoutOp =
619+ xegpu::ConvertLayoutOp::create ( rewriter, value.getLoc (), value.getType (),
620+ value, inputLayoutAttr, targetLayoutAttr);
622621 // Replace load op result with the converted layout.
623622 rewriter.replaceUsesWithIf (
624623 value, convLayoutOp.getResult (), [&](OpOperand &use) {
@@ -632,9 +631,12 @@ transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
632631void transform::ConvertLayoutOp::getEffects (
633632 ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
634633 onlyReadsHandle (getTargetMutable (), effects);
635- onlyReadsHandle (getSgLayoutMutable (), effects);
636- onlyReadsHandle (getSgDataMutable (), effects);
637- onlyReadsHandle (getInstDataMutable (), effects);
634+ onlyReadsHandle (getInputSgLayoutMutable (), effects);
635+ onlyReadsHandle (getInputSgDataMutable (), effects);
636+ onlyReadsHandle (getInputInstDataMutable (), effects);
637+ onlyReadsHandle (getTargetSgLayoutMutable (), effects);
638+ onlyReadsHandle (getTargetSgDataMutable (), effects);
639+ onlyReadsHandle (getTargetInstDataMutable (), effects);
638640 producesHandle (getOperation ()->getOpResults (), effects);
639641 modifiesPayload (effects);
640642}
0 commit comments