Skip to content

Commit 9457b54

Browse files
committed
Add check
1 parent 1161e28 commit 9457b54

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -970,6 +970,30 @@ struct WgToSgVectorShapeCastOp
970970
VectorType newResultType =
971971
VectorType::get(sgShape, resultType.getElementType());
972972

973+
// TODO: Add check for compatible layouts in layout attr.
974+
// Only support ShapeCast ops that expand or reduce unit dims.
975+
// That is, only allow shape casts where the non-unit dimensions are
976+
// preserved, and any added or removed dimensions must be of size 1.
977+
auto srcType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
978+
if (!srcType)
979+
return failure();
980+
981+
auto isUnitOrPreserved = [](ArrayRef<int64_t> src, ArrayRef<int64_t> dst) {
982+
// Remove all 1s from both shapes and compare the rest.
983+
SmallVector<int64_t> srcNonUnit, dstNonUnit;
984+
for (int64_t d : src)
985+
if (d != 1)
986+
srcNonUnit.push_back(d);
987+
for (int64_t d : dst)
988+
if (d != 1)
989+
dstNonUnit.push_back(d);
990+
return srcNonUnit == dstNonUnit;
991+
};
992+
993+
if (!isUnitOrPreserved(srcType.getShape(), sgShape) ||
994+
!isUnitOrPreserved(sgShape, srcType.getShape()))
995+
return failure();
996+
973997
SmallVector<Value> newShapeCastOps;
974998
for (auto src : adaptor.getSource()) {
975999
auto newShapeCast =

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -414,8 +414,8 @@ gpu.module @test_distribution {
414414
%load = xegpu.load_nd %tdesc[0, 0]
415415
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
416416
-> vector<256x128xf32>
417-
//CHECK: vector.shape_cast {{.*}} : vector<32x32xf32> to vector<2x16x4x8xf32>
418-
%cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout<sg_layout = [8, 1, 4, 1], sg_data = [2, 16, 4, 8]>} : vector<256x128xf32> to vector<16x16x16x8xf32>
417+
//CHECK: vector.shape_cast {{.*}} : vector<32x32xf32> to vector<32x1x32x1xf32>
418+
%cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout<sg_layout = [8, 1, 4, 1], sg_data = [32, 1, 32, 1]>} : vector<256x128xf32> to vector<256x1x128x1xf32>
419419
gpu.return
420420
}
421421
}

0 commit comments

Comments
 (0)