@@ -90,6 +90,38 @@ createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
9090 /* order=*/ nullptr );
9191}
9292
93+ // / Generate `xegpu::LayoutAttr` from op mixed layout values.
94+ DiagnosedSilenceableFailure
95+ getLayoutAttrFromOperands (transform::TransformRewriter &rewriter,
96+ transform::TransformState &state,
97+ TransformOpInterface transformOp,
98+ ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
99+ ArrayRef<::mlir::OpFoldResult> mixedSgData,
100+ ArrayRef<::mlir::OpFoldResult> mixedInstData,
101+ xegpu::LayoutAttr &layoutAttr) {
102+ SmallVector<int32_t > sgLayout, sgData, instData;
103+ auto status =
104+ convertMixedValuesToInt (state, transformOp, sgLayout, mixedSgLayout);
105+ if (!status.succeeded ())
106+ return status;
107+
108+ status = convertMixedValuesToInt (state, transformOp, sgData, mixedSgData);
109+ if (!status.succeeded ())
110+ return status;
111+
112+ status = convertMixedValuesToInt (state, transformOp, instData, mixedInstData);
113+ if (!status.succeeded ())
114+ return status;
115+ auto maybeInstData = instData.empty ()
116+ ? std::nullopt
117+ : std::optional<ArrayRef<int32_t >>(instData);
118+
119+ layoutAttr =
120+ createLayoutAttr (rewriter.getContext (), sgLayout, sgData, maybeInstData);
121+
122+ return DiagnosedSilenceableFailure::success ();
123+ }
124+
93125// / Replace xegpu.create_nd_desc op with a new one with the given layout.
94126static xegpu::CreateNdDescOp
95127setDescLayout (transform::TransformRewriter &rewriter,
@@ -142,26 +174,13 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
142174 }
143175 Operation *target = *targetOps.begin ();
144176
145- SmallVector<int32_t > sgLayout;
146- DiagnosedSilenceableFailure status =
147- convertMixedValuesToInt (state, (*this ), sgLayout, getMixedSgLayout ());
177+ xegpu::LayoutAttr layoutAttr = nullptr ;
178+ auto status = getLayoutAttrFromOperands (rewriter, state, (*this ),
179+ getMixedSgLayout (), getMixedSgData (),
180+ getMixedInstData (), layoutAttr);
148181 if (!status.succeeded ())
149182 return status;
150183
151- SmallVector<int32_t > sgData;
152- status = convertMixedValuesToInt (state, (*this ), sgData, getMixedSgData ());
153- if (!status.succeeded ())
154- return status;
155-
156- SmallVector<int32_t > instData;
157- status =
158- convertMixedValuesToInt (state, (*this ), instData, getMixedInstData ());
159- if (!status.succeeded ())
160- return status;
161- auto maybeInstData = instData.empty ()
162- ? std::nullopt
163- : std::optional<ArrayRef<int32_t >>(instData);
164-
165184 // For now only create_nd_desc op is supported.
166185 auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
167186 if (!descOp) {
@@ -173,8 +192,6 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
173192 }
174193
175194 // Set layout attr in desc op's return type. Replaces old desc op.
176- auto layoutAttr =
177- createLayoutAttr (rewriter.getContext (), sgLayout, sgData, maybeInstData);
178195 auto newdescOp = setDescLayout (rewriter, descOp, layoutAttr);
179196
180197 // Map result handles.
@@ -193,6 +210,76 @@ void transform::SetDescLayoutOp::getEffects(
193210 modifiesPayload (effects);
194211}
195212
213+ void transform::SetOpLayoutAttrOp::build (
214+ OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
215+ ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
216+ ArrayRef<OpFoldResult> mixedInstData, bool result) {
217+ SmallVector<int64_t > staticSgLayout, staticSgData, staticInstData;
218+ SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
219+ dispatchIndexOpFoldResults (mixedSgLayout, dynamicSgLayout, staticSgLayout);
220+ dispatchIndexOpFoldResults (mixedSgData, dynamicSgData, staticSgData);
221+ dispatchIndexOpFoldResults (mixedInstData, dynamicInstData, staticInstData);
222+ build (builder, ostate, target.getType (),
223+ /* target=*/ target,
224+ /* index=*/ index,
225+ /* sg_layout=*/ dynamicSgLayout,
226+ /* sg_data=*/ dynamicSgData,
227+ /* inst_data=*/ dynamicInstData,
228+ /* static_sg_layout=*/ staticSgLayout,
229+ /* static_sg_data=*/ staticSgData,
230+ /* static_inst_data=*/ staticInstData,
231+ /* result=*/ result);
232+ }
233+
234+ DiagnosedSilenceableFailure
235+ transform::SetOpLayoutAttrOp::apply (transform::TransformRewriter &rewriter,
236+ transform::TransformResults &results,
237+ transform::TransformState &state) {
238+
239+ auto targetOps = state.getPayloadOps (getTarget ());
240+ if (!llvm::hasSingleElement (targetOps)) {
241+ return emitDefiniteFailure () << " Requires exactly one targetOp handle (got "
242+ << llvm::range_size (targetOps) << " )" ;
243+ }
244+ Operation *target = *targetOps.begin ();
245+
246+ bool resultTarget = getResult ();
247+
248+ int64_t index = getIndex ();
249+ if (resultTarget && index >= target->getNumResults ()) {
250+ return emitSilenceableFailure (getLoc ())
251+ << " Index exceeds the number of op results" ;
252+ }
253+ if (!resultTarget && index >= target->getNumOperands ()) {
254+ return emitSilenceableFailure (getLoc ())
255+ << " Index exceeds the number of op operands" ;
256+ }
257+
258+ xegpu::LayoutAttr layoutAttr = nullptr ;
259+ auto status = getLayoutAttrFromOperands (rewriter, state, (*this ),
260+ getMixedSgLayout (), getMixedSgData (),
261+ getMixedInstData (), layoutAttr);
262+ if (!status.succeeded ())
263+ return status;
264+
265+ // Set layout attribute for the op result or operand
266+ if (resultTarget) {
267+ xegpu::setDistributeLayoutAttr (target->getResult (index), layoutAttr);
268+ } else {
269+ xegpu::setDistributeLayoutAttr (target->getOpOperand (index), layoutAttr);
270+ }
271+ return DiagnosedSilenceableFailure::success ();
272+ }
273+
274+ void transform::SetOpLayoutAttrOp::getEffects (
275+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
276+ onlyReadsHandle (getTargetMutable (), effects);
277+ onlyReadsHandle (getSgLayoutMutable (), effects);
278+ onlyReadsHandle (getSgDataMutable (), effects);
279+ onlyReadsHandle (getInstDataMutable (), effects);
280+ modifiesPayload (effects);
281+ }
282+
196283namespace {
197284class XeGPUTransformDialectExtension
198285 : public transform::TransformDialectExtension<
0 commit comments