Skip to content

Commit 6cffa44

Browse files
committed
refactor
1 parent 9776850 commit 6cffa44

File tree

1 file changed

+32
-35
lines changed

1 file changed

+32
-35
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp

Lines changed: 32 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,18 @@ void XeGPUBlockingPass::runOnOperation() {
216216
// operation is replaced.
217217
xegpu::setLayoutAttrs(mod, [&](Value v) { return xegpu::getLayoutAttr(v); });
218218

219+
auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
220+
xegpu::LayoutAttr layout) {
221+
int count = 1;
222+
SmallVector<int64_t> tileShape(shape);
223+
if (layout && layout.getInstData()) {
224+
DenseI32ArrayAttr instData = layout.getInstData();
225+
tileShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
226+
count = computeProduct(shape) / computeProduct(tileShape);
227+
}
228+
return std::make_pair(tileShape, count);
229+
};
230+
219231
// Perform type conversion for SCF control folow ops
220232
TypeConverter converter;
221233
converter.addConversion([&](Type type) -> Type { return type; });
@@ -225,56 +237,41 @@ void XeGPUBlockingPass::runOnOperation() {
225237
Type elemTy = type.getElementType();
226238
ArrayRef<int64_t> shape = type.getShape();
227239

228-
// init count and subShape to the default value. If the LayoutAttr
229-
// is not present, it will return a VectorType with original shape.
230-
int count = 1;
231-
SmallVector<int64_t> subShape(shape);
232-
if (auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
233-
type.getEncoding())) {
234-
if (layout.isWgLayout())
235-
return failure();
236-
if (DenseI32ArrayAttr instData = layout.getInstData()) {
237-
// for unrolling, the subShape is determined by inst_data
238-
subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
239-
count = computeProduct(shape) / computeProduct(subShape);
240-
}
241-
}
240+
auto layout =
241+
llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
242+
if (layout && layout.isWgLayout())
243+
return failure();
244+
245+
int count;
246+
SmallVector<int64_t> subShape;
247+
std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
242248
auto newTy = VectorType::get(subShape, elemTy);
243249
result.append(count, newTy);
244250
return success();
245251
});
246-
247252
converter.addConversion(
248253
[&](xegpu::TensorDescType type,
249254
SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
250-
MLIRContext *ctx = type.getContext();
251255
Type elemTy = type.getElementType();
252-
Attribute encoding = type.getEncoding();
253256
ArrayRef<int64_t> shape = type.getShape();
254257

255-
// init count and newTy to the default value. If the layout
256-
// attribute is not present, it will return the original type.
257-
int count = 1;
258-
SmallVector<int64_t> subShape(shape);
259-
260258
xegpu::LayoutAttr layout = type.getLayoutAttr();
259+
if (layout && layout.isWgLayout())
260+
return failure();
261+
262+
int count;
263+
SmallVector<int64_t> subShape;
264+
std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
261265

262-
if (layout) {
263-
if (layout.isWgLayout())
264-
return failure();
265-
266-
if (DenseI32ArrayAttr instData = layout.getInstData()) {
267-
// for unrolling, the subShape is determined by inst_data
268-
subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
269-
count = computeProduct(shape) / computeProduct(subShape);
270-
layout = layout.dropInstData();
271-
}
272-
}
273-
auto newTy =
274-
xegpu::TensorDescType::get(ctx, subShape, elemTy, encoding, layout);
266+
if (layout)
267+
layout = layout.dropInstData();
268+
269+
auto newTy = xegpu::TensorDescType::get(
270+
type.getContext(), subShape, elemTy, type.getEncoding(), layout);
275271
result.append(count, newTy);
276272
return success();
277273
});
274+
278275
xegpu::doSCFStructuralTypeConversionWithTensorType(mod, converter);
279276

280277
xegpu::UnrollOptions options;

0 commit comments

Comments
 (0)