Skip to content

Commit 3c147c7

Browse files
committed
Refactor
1 parent fabb419 commit 3c147c7

File tree

2 files changed

+38
-53
lines changed

2 files changed

+38
-53
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 36 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -773,54 +773,40 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
773773
op, "Only 1D & 2D vector constant supported");
774774

775775
SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
776-
int64_t stride = 0;
777776
int64_t rowStride = 0, colStride = 0;
778-
if (wgShape.size() == 1) {
779-
// 1D case: single stride
780-
if (values.size() > 1) {
781-
stride = cast<IntegerAttr>(values[1]).getInt() -
782-
cast<IntegerAttr>(values[0]).getInt();
783-
for (size_t i = 2; i < values.size(); ++i) {
784-
int64_t diff = cast<IntegerAttr>(values[i]).getInt() -
785-
cast<IntegerAttr>(values[i - 1]).getInt();
786-
if (diff != stride)
777+
int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
778+
int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];
779+
780+
// Compute colStride and rowStride, and check for constant strides.
781+
if (cols > 1) {
782+
colStride = cast<IntegerAttr>(values[1]).getInt() -
783+
cast<IntegerAttr>(values[0]).getInt();
784+
}
785+
if (rows > 1) {
786+
rowStride = cast<IntegerAttr>(values[cols]).getInt() -
787+
cast<IntegerAttr>(values[0]).getInt();
788+
}
789+
790+
for (int64_t r = 0; r < rows; ++r) {
791+
for (int64_t c = 0; c < cols; ++c) {
792+
int64_t idx = r * cols + c;
793+
// Check column stride (skip first column)
794+
if (c > 0 && cols > 1) {
795+
int64_t prevIdx = r * cols + (c - 1);
796+
int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
797+
cast<IntegerAttr>(values[prevIdx]).getInt();
798+
if (diff != colStride)
787799
return rewriter.notifyMatchFailure(
788-
op, "Non-constant stride in non-splat constant op.");
800+
op, "Non-constant column stride in constant op.");
789801
}
790-
}
791-
} else if (wgShape.size() == 2) {
792-
// 2D case: row stride and column stride
793-
int64_t rows = wgShape[0], cols = wgShape[1];
794-
// Compute col stride (stride between elements in a column)
795-
if (cols > 1) {
796-
colStride = cast<IntegerAttr>(values[1]).getInt() -
797-
cast<IntegerAttr>(values[0]).getInt();
798-
for (int64_t r = 0; r < rows; ++r) {
799-
for (int64_t c = 1; c < cols; ++c) {
800-
int64_t idx = r * cols + c;
801-
int64_t prevIdx = r * cols + (c - 1);
802-
int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
803-
cast<IntegerAttr>(values[prevIdx]).getInt();
804-
if (diff != colStride)
805-
return rewriter.notifyMatchFailure(
806-
op, "Non-constant column stride in 2D constant op.");
807-
}
808-
}
809-
}
810-
// Compute row stride (stride between elements in a row)
811-
if (rows > 1) {
812-
rowStride = cast<IntegerAttr>(values[cols]).getInt() -
813-
cast<IntegerAttr>(values[0]).getInt();
814-
for (int64_t c = 0; c < cols; ++c) {
815-
for (int64_t r = 1; r < rows; ++r) {
816-
int64_t idx = r * cols + c;
817-
int64_t prevIdx = (r - 1) * cols + c;
818-
int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
819-
cast<IntegerAttr>(values[prevIdx]).getInt();
820-
if (diff != rowStride)
821-
return rewriter.notifyMatchFailure(
822-
op, "Non-constant row stride in 2D constant op.");
823-
}
802+
// Check row stride (skip first row)
803+
if (r > 0 && rows > 1) {
804+
int64_t prevIdx = (r - 1) * cols + c;
805+
int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
806+
cast<IntegerAttr>(values[prevIdx]).getInt();
807+
if (diff != rowStride)
808+
return rewriter.notifyMatchFailure(
809+
op, "Non-constant row stride in constant op.");
824810
}
825811
}
826812
}
@@ -829,12 +815,11 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
829815
// For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
830816
// For 1D case, extract the first sgShape[0] elements.
831817
SmallVector<Attribute> baseTileValues;
832-
int cols = sgShape[sgShape.size() - 1];
833-
int64_t wgCols = wgShape[sgShape.size() - 1];
834-
int64_t rows = sgShape.size() == 1 ? 1 : sgShape[0];
835-
for (int64_t r = 0; r < rows; ++r) {
836-
for (int64_t c = 0; c < cols; ++c) {
837-
baseTileValues.push_back(values[r * wgCols + c]);
818+
int baseTileCols = sgShape[sgShape.size() - 1];
819+
int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
820+
for (int64_t r = 0; r < baseTileRows; ++r) {
821+
for (int64_t c = 0; c < baseTileCols; ++c) {
822+
baseTileValues.push_back(values[r * cols + c]);
838823
}
839824
}
840825

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,8 +459,8 @@ gpu.module @test_distribution {
459459
gpu.return
460460
}
461461

462-
// CHECK-LABEL: non_splat_constant
463-
gpu.func @non_splat_constant() {
462+
// CHECK-LABEL: non_splat_constant_2D
463+
gpu.func @non_splat_constant_2D() {
464464
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1x1xindex>
465465
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
466466
// CHECK-DAG: affine.apply #map4()[%[[SGID]]]

0 commit comments

Comments
 (0)