@@ -211,7 +211,8 @@ struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
211211 mixedOffsets[x] = addi (oldX, subOffX);
212212 mixedOffsets[y] = addi (oldY, subOffY);
213213 auto newOp = rewriter.create <xegpu::CreateNdDescOp>(
214- loc, newTdescTy, op.getSource (), mixedOffsets, op.getMixedSizes (), op.getMixedStrides ());
214+ loc, newTdescTy, op.getSource (), mixedOffsets, op.getMixedSizes (),
215+ op.getMixedStrides ());
215216 newOps.push_back (newOp);
216217 }
217218 }
@@ -304,20 +305,21 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
304305
305306 auto elemTy = tdescTy.getElementType ();
306307 auto newValueTy = valueTy.cloneWith (targetShape, elemTy);
307- auto newTdescTy = xegpu::TensorDescType::get (ctx, targetShape, elemTy, tdescTy.getEncoding (),
308- getLaneLayoutAttr (layout));
308+ auto newTdescTy = xegpu::TensorDescType::get (ctx, targetShape, elemTy,
309+ tdescTy.getEncoding (),
310+ getLaneLayoutAttr (layout));
309311
310312 auto numNewOps = computeProduct (grids);
311313 llvm::SmallVector<Type> convertedValTypes (numNewOps, newValueTy);
312314 llvm::SmallVector<Type> convertedTdescTypes (numNewOps, newTdescTy);
313- auto convertedValues = addPackOp (op.getValue (), convertedValTypes, targetShape, loc, rewriter);
315+ auto convertedValues =
316+ addPackOp (op.getValue (), convertedValTypes, targetShape, loc, rewriter);
314317 auto convertedTdescs = addPackOp (op.getTensorDesc (), convertedTdescTypes,
315318 targetShape, loc, rewriter);
316319
317320 for (auto [v, t] : llvm::zip (convertedValues, convertedTdescs)) {
318321 rewriter.create <xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr (),
319- op.getL2HintAttr (),
320- op.getL3HintAttr ());
322+ op.getL2HintAttr (), op.getL3HintAttr ());
321323 }
322324 rewriter.eraseOp (op);
323325 return success ();
@@ -395,27 +397,27 @@ struct XeGPUUnrollPass final
395397
396398 void runOnOperation () override {
397399 vector::UnrollVectorOptions options;
398- options.setNativeShapeFn (
399- [&](Operation *op) -> std::optional<SmallVector<int64_t >> {
400- if (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
401- xegpu::TensorDescType tdescTy;
402- if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
403- tdescTy = createNdOp.getType ();
404- } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
405- tdescTy = loadNdOp.getTensorDescType ();
406- } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
407- tdescTy = storeNdOp.getTensorDescType ();
408- }
409-
410- if (auto layout = tdescTy.getLayoutAttr ()) {
411- if (auto inst_data = layout.getInstData ())
412- return SmallVector<int64_t >(inst_data.asArrayRef ().begin (),
413- inst_data.asArrayRef ().end ());
414- }
415- }
416-
417- return std::nullopt ;
418- });
400+ options.setNativeShapeFn ([&](Operation *op)
401+ -> std::optional<SmallVector<int64_t >> {
402+ if (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
403+ xegpu::TensorDescType tdescTy;
404+ if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
405+ tdescTy = createNdOp.getType ();
406+ } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
407+ tdescTy = loadNdOp.getTensorDescType ();
408+ } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
409+ tdescTy = storeNdOp.getTensorDescType ();
410+ }
411+
412+ if (auto layout = tdescTy.getLayoutAttr ()) {
413+ if (auto inst_data = layout.getInstData ())
414+ return SmallVector<int64_t >(inst_data.asArrayRef ().begin (),
415+ inst_data.asArrayRef ().end ());
416+ }
417+ }
418+
419+ return std::nullopt ;
420+ });
419421
420422 auto funcOp = getOperation ();
421423 RewritePatternSet patterns (&getContext ());
@@ -432,7 +434,8 @@ struct XeGPUUnrollPass final
432434} // namespace
433435
434436void mlir::xegpu::populateXeGPUUnrollPatterns (
435- RewritePatternSet &patterns, const mlir::vector::UnrollVectorOptions &options) {
437+ RewritePatternSet &patterns,
438+ const mlir::vector::UnrollVectorOptions &options) {
436439 patterns.add <UnrollCreateNdOp, UnrollLoadNdOp, UnrollStoreNdOp>(
437- patterns.getContext (), options);
440+ patterns.getContext (), options);
438441}
0 commit comments