Skip to content

Commit 2c97ee7

Browse files
committed
Add CHECKS
1 parent 803a565 commit 2c97ee7

File tree

3 files changed

+11
-12
lines changed

3 files changed

+11
-12
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -338,8 +338,6 @@ struct WgToSgVectorBroadcastOp
338338
ConversionPatternRewriter &rewriter) const override {
339339
VectorType resultType = op.getResult().getType();
340340
ArrayRef<int64_t> wgShape = resultType.getShape();
341-
if (!resultType)
342-
return failure();
343341

344342
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
345343
if (!layout || !layout.getSgLayout())
@@ -348,17 +346,17 @@ struct WgToSgVectorBroadcastOp
348346
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
349347
VectorType newResultType =
350348
VectorType::get(sgShape, resultType.getElementType());
351-
SmallVector<Value> newBroadcasts;
352349

350+
SmallVector<Value> newBroadcastOps;
353351
for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {
354352
auto newBroadcast = rewriter.create<vector::BroadcastOp>(
355353
op.getLoc(), newResultType, adaptor.getOperands().front()[i]);
356354
xegpu::setLayoutAttr(newBroadcast->getResult(0),
357355
layout.dropSgLayoutAndData());
358-
newBroadcasts.push_back(newBroadcast.getResult());
356+
newBroadcastOps.push_back(newBroadcast.getResult());
359357
}
360358

361-
rewriter.replaceOpWithMultiple(op, {newBroadcasts});
359+
rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
362360
return success();
363361
}
364362
};
@@ -556,11 +554,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
556554

557555
target.addDynamicallyLegalOp<vector::BroadcastOp>(
558556
[=](vector::BroadcastOp op) -> bool {
559-
auto resultType = dyn_cast<VectorType>(op.getResult().getType());
560-
if (!resultType)
561-
return true;
562-
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
563-
return isLegal(layout);
557+
return isLegal(xegpu::getLayoutAttr(op.getResult()));
564558
});
565559

566560
target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ gpu.module @test_round_robin_assignment {
111111
%load = xegpu.load_nd %tdesc
112112
: !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
113113
-> vector<24x1xf32>
114+
// CHECK-COUNT-3: vector.broadcast {{.*}}
115+
// CHECK-SAME-COUNT-3: {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
116+
// CHECK-SAME-COUNT-3: : vector<2x1xf32> to vector<2x4xf32>
117+
// CHECK-NOT: vector.broadcast
114118
%broadcast = vector.broadcast %load
115119
{layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
116120
: vector<24x1xf32> to vector<24x8xf32>

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,15 +170,16 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
170170
gpu.return
171171
}
172172

173-
174-
// CHECK-LABEL: test_broadcast
173+
// CHECK-LABEL: test_broadcast
175174
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
176175
gpu.func @test_broadcast(%src: memref<24x1xf32>) {
177176
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
178177
-> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
179178
%load = xegpu.load_nd %tdesc
180179
: !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
181180
-> vector<24x1xf32>
181+
// CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
182+
// CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32>
182183
%broadcast = vector.broadcast %load
183184
{layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
184185
: vector<24x1xf32> to vector<24x8xf32>

0 commit comments

Comments
 (0)