@@ -44,37 +44,38 @@ struct TestXeGPUUnrollingPatterns
4444
4545 void runOnOperation () override {
4646 vector::UnrollVectorOptions options;
47- options.setNativeShapeFn ([&](Operation *op)
48- -> std::optional<SmallVector<int64_t >> {
49- if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
50- xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
51- xegpu::TensorDescType tdescTy;
52- if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
53- tdescTy = createNdOp.getType ();
54- } else if (auto updateNdOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
55- tdescTy = updateNdOp.getTensorDescType ();
56- } else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(op)) {
57- tdescTy = prefetchNdOp.getTensorDescType ();
58- } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
59- tdescTy = loadNdOp.getTensorDescType ();
60- } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
61- tdescTy = storeNdOp.getTensorDescType ();
62- }
63-
64- if (auto layout = tdescTy.getLayoutAttr ()) {
65- auto inst_data = layout.getInstData ();
66- if (inst_data && layout.isSgLayout ())
67- return SmallVector<int64_t >(inst_data.asArrayRef ().begin (),
68- inst_data.asArrayRef ().end ());
69- }
70- }
71-
72- if (isa<xegpu::DpasOp>(op)) {
73- return SmallVector<int64_t >{8 , 16 , 16 };
74- }
75-
76- return std::nullopt ;
77- });
47+ options.setNativeShapeFn (
48+ [&](Operation *op) -> std::optional<SmallVector<int64_t >> {
49+ if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
50+ xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
51+ xegpu::TensorDescType tdescTy;
52+ if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
53+ tdescTy = createNdOp.getType ();
54+ } else if (auto updateNdOp =
55+ dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
56+ tdescTy = updateNdOp.getTensorDescType ();
57+ } else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(op)) {
58+ tdescTy = prefetchNdOp.getTensorDescType ();
59+ } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
60+ tdescTy = loadNdOp.getTensorDescType ();
61+ } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
62+ tdescTy = storeNdOp.getTensorDescType ();
63+ }
64+
65+ if (auto layout = tdescTy.getLayoutAttr ()) {
66+ auto inst_data = layout.getInstData ();
67+ if (inst_data && layout.isSgLayout ())
68+ return SmallVector<int64_t >(inst_data.asArrayRef ().begin (),
69+ inst_data.asArrayRef ().end ());
70+ }
71+ }
72+
73+ if (isa<xegpu::DpasOp>(op)) {
74+ return SmallVector<int64_t >{8 , 16 , 16 };
75+ }
76+
77+ return std::nullopt ;
78+ });
7879
7980 MLIRContext *ctx = &getContext ();
8081 RewritePatternSet patterns (ctx);
0 commit comments