Skip to content

Commit 29d3f45

Browse files
committed
Add test
1 parent 2c81dee commit 29d3f45

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
790790
for (int64_t r = 0; r < rows; ++r) {
791791
for (int64_t c = 0; c < cols; ++c) {
792792
int64_t idx = r * cols + c;
793-
// Check column stride (skip first column)
793+
// Check column stride
794794
if (c > 0 && cols > 1) {
795795
int64_t prevIdx = r * cols + (c - 1);
796796
int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
@@ -799,7 +799,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
799799
return rewriter.notifyMatchFailure(
800800
op, "Non-constant column stride in constant op.");
801801
}
802-
// Check row stride (skip first row)
802+
// Check row stride
803803
if (r > 0 && rows > 1) {
804804
int64_t prevIdx = (r - 1) * cols + c;
805805
int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,4 +505,18 @@ gpu.module @test_distribution {
505505
]> : vector<8x8xindex>
506506
gpu.return
507507
}
508+
509+
// CHECK-LABEL: non_splat_constant
510+
gpu.func @non_splat_constant() {
511+
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
512+
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
513+
// CHECK-DAG: %[[REMU:.*]] = index.remu %[[SGID]], %[[C32:.*]]
514+
// CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU]], %[[C16:.*]] : index
515+
// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[MUL]] : index to vector<1xindex>
516+
// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[CST]], %[[BCAST]] : vector<1xindex>
517+
%cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [1]>} dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 272, 288, 304, 320, 336, 352, 368, 384, 400, 416, 432, 448, 464, 480, 496]> : vector<32xindex>
518+
// CHECK: arith.constant dense<{{\[}}[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]{{\]}}> : vector<1x16xindex>
519+
%cst_1 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 16]>} dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>
520+
gpu.return
521+
}
508522
}

0 commit comments

Comments
 (0)