@@ -216,6 +216,18 @@ void XeGPUBlockingPass::runOnOperation() {
216216 // operation is replaced.
217217 xegpu::setLayoutAttrs (mod, [&](Value v) { return xegpu::getLayoutAttr (v); });
218218
219+ auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t > shape,
220+ xegpu::LayoutAttr layout) {
221+ int count = 1 ;
222+ SmallVector<int64_t > tileShape (shape);
223+ if (layout && layout.getInstData ()) {
224+ DenseI32ArrayAttr instData = layout.getInstData ();
225+ tileShape = llvm::to_vector_of<int64_t >(instData.asArrayRef ());
226+ count = computeProduct (shape) / computeProduct (tileShape);
227+ }
228+ return std::make_pair (tileShape, count);
229+ };
230+
219231 // Perform type conversion for SCF control folow ops
220232 TypeConverter converter;
221233 converter.addConversion ([&](Type type) -> Type { return type; });
@@ -225,56 +237,41 @@ void XeGPUBlockingPass::runOnOperation() {
225237 Type elemTy = type.getElementType ();
226238 ArrayRef<int64_t > shape = type.getShape ();
227239
228- // init count and subShape to the default value. If the LayoutAttr
229- // is not present, it will return a VectorType with original shape.
230- int count = 1 ;
231- SmallVector<int64_t > subShape (shape);
232- if (auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
233- type.getEncoding ())) {
234- if (layout.isWgLayout ())
235- return failure ();
236- if (DenseI32ArrayAttr instData = layout.getInstData ()) {
237- // for unrolling, the subShape is determined by inst_data
238- subShape = llvm::to_vector_of<int64_t >(instData.asArrayRef ());
239- count = computeProduct (shape) / computeProduct (subShape);
240- }
241- }
240+ auto layout =
241+ llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding ());
242+ if (layout && layout.isWgLayout ())
243+ return failure ();
244+
245+ int count;
246+ SmallVector<int64_t > subShape;
247+ std::tie (subShape, count) = getTileShapeAndCount (shape, layout);
242248 auto newTy = VectorType::get (subShape, elemTy);
243249 result.append (count, newTy);
244250 return success ();
245251 });
246-
247252 converter.addConversion (
248253 [&](xegpu::TensorDescType type,
249254 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
250- MLIRContext *ctx = type.getContext ();
251255 Type elemTy = type.getElementType ();
252- Attribute encoding = type.getEncoding ();
253256 ArrayRef<int64_t > shape = type.getShape ();
254257
255- // init count and newTy to the default value. If the layout
256- // attribute is not present, it will return the original type.
257- int count = 1 ;
258- SmallVector<int64_t > subShape (shape);
259-
260258 xegpu::LayoutAttr layout = type.getLayoutAttr ();
259+ if (layout && layout.isWgLayout ())
260+ return failure ();
261+
262+ int count;
263+ SmallVector<int64_t > subShape;
264+ std::tie (subShape, count) = getTileShapeAndCount (shape, layout);
261265
262- if (layout) {
263- if (layout.isWgLayout ())
264- return failure ();
265-
266- if (DenseI32ArrayAttr instData = layout.getInstData ()) {
267- // for unrolling, the subShape is determined by inst_data
268- subShape = llvm::to_vector_of<int64_t >(instData.asArrayRef ());
269- count = computeProduct (shape) / computeProduct (subShape);
270- layout = layout.dropInstData ();
271- }
272- }
273- auto newTy =
274- xegpu::TensorDescType::get (ctx, subShape, elemTy, encoding, layout);
266+ if (layout)
267+ layout = layout.dropInstData ();
268+
269+ auto newTy = xegpu::TensorDescType::get (
270+ type.getContext (), subShape, elemTy, type.getEncoding (), layout);
275271 result.append (count, newTy);
276272 return success ();
277273 });
274+
278275 xegpu::doSCFStructuralTypeConversionWithTensorType (mod, converter);
279276
280277 xegpu::UnrollOptions options;
0 commit comments