
Commit bf37af1

add one more unit test

1 parent b2032a4 commit bf37af1

File tree: 2 files changed (+38, -23 lines)

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 17 additions & 23 deletions
@@ -329,13 +329,13 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
 
 // Handles UnrealizedConversionCastOp generated during
 // SCFStructuralTypeConversions (step 1). This op may appear as either a
-// target or source materialization for Vector or TensorDesc, e.g.:
-// 1. unrealized_conversion_cast %1 : tensor_desc<16xf16> to
-// tensor_desc<128xf16, ...>
-// 2. unrealized_conversion_cast %1 : vector<256xf32> to vector<16xf32>, ...
-// 3. unrealized_conversion_cast %1 : vector<16xf32>, ... to vector<256xf32>
-// In all cases, the pattern simply forwards the inputs to the outputs with
-// one-to-one or one-to-n patterns.
+// target or source materialization for Vector or TensorDesc values, e.g.:
+// 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
+// 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
+// The cast can be 1:1, 1:N, or N:1. In all cases, the pattern simply
+// forwards the inputs to the outputs through the 1:1 or 1:N interface.
+// TODO: remove this once a context-aware type converter is ready.
+// It is safe only when the input IR contains no UnrealizedConversionCastOp.
 struct UnrealizedConversionCastOpPattern
     : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
   using OpConversionPattern<
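
For context, the two materialization shapes the new comment describes look like this in IR. A minimal sketch, reduced to a 2-way split for brevity; the %wg/%sg names and the 128-element pieces are illustrative, not taken from this commit:

  // 1:N - one workgroup-level vector materialized as per-subgroup pieces
  %sg:2 = builtin.unrealized_conversion_cast %wg : vector<256xf32> to vector<128xf32>, vector<128xf32>
  // N:1 - per-subgroup pieces materialized back to the workgroup-level type
  %wg2 = builtin.unrealized_conversion_cast %sg#0, %sg#1 : vector<128xf32>, vector<128xf32> to vector<256xf32>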
@@ -346,34 +346,28 @@ struct UnrealizedConversionCastOpPattern
                   ConversionPatternRewriter &rewriter) const override {
     SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());
 
-    // Handles the case where cast %1 : tensor_desc<16xf16> to
-    // tensor_desc<128xf16, ...> The input values provided by the adaptor should
-    // already be distributed.
-    if (op.getNumOperands() == 1 && op.getNumResults() == 1 &&
-        isa<xegpu::TensorDescType>(op->getOperand(0).getType()) &&
-        isa<xegpu::TensorDescType>(op->getResult(0).getType())) {
-      rewriter.replaceOp(op, inputs);
-      return success();
-    }
+    auto inputTy = inputs[0].getType();
+    auto outputTy = op->getOpResult(0).getType();
+
+    if (!llvm::all_equal(op->getResultTypes()) ||
+        !llvm::all_equal(ValueRange(inputs).getTypes()) ||
+        !isa<VectorType, xegpu::TensorDescType>(inputTy) ||
+        !isa<VectorType, xegpu::TensorDescType>(outputTy))
+      return failure();
 
     // Handles the case where cast %1 : vector<256xf32> to vector<16xf32>, ...
     // the input values provided by the adaptor should already be distributed,
     // and their types should correspond exactly to the result types of the
     // operation.
-    if (op.getNumOperands() == 1 &&
-        llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
+    if (op.getNumOperands() == 1) {
       rewriter.replaceOp(op, inputs);
       return success();
     }
 
     // Handles the case where cast %1 : vector<16xf32>, ... to vector<256xf32>.
     // All input values must have the same vector type, and their shape must be
     // evenly divisible by the output vector's shape.
-    auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
-    auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
-    if (op.getNumResults() == 1 && inputTy && outputTy &&
-        llvm::all_equal(ValueRange(inputs).getTypes()) &&
-        computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
+    if (op.getNumResults() == 1) {
       rewriter.replaceOpWithMultiple(op, {inputs});
       return success();
     }
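
To illustrate the rewrite's effect, a hedged before/after sketch (the some.consumer op is a placeholder, not from this commit): once the guard passes, the 1:N case replaces the cast's results with the already-distributed inputs one-to-one, while the N:1 case uses replaceOpWithMultiple so that consumers converted later receive all distributed inputs through their one-to-n adaptors:

  // Before: an N:1 materialization left behind by the SCF structural conversion
  %r = builtin.unrealized_conversion_cast %a, %b : vector<128xf32>, vector<128xf32> to vector<256xf32>
  "some.consumer"(%r) : (vector<256xf32>) -> ()
  // After the pattern runs, the cast is gone and the consumer's one-to-n
  // adaptor sees %a and %b directly as the distributed values.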

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir

Lines changed: 21 additions & 0 deletions
@@ -103,6 +103,27 @@ gpu.module @test_round_robin_assignment {
     gpu.return
   }
 
+  gpu.func @test_scf_for(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
+    %c1 = arith.constant 1 : index
+    %c10 = arith.constant 10 : index
+    %c0 = arith.constant 0 : index
+    %c256 = arith.constant 256 : index
+    %c1024 = arith.constant 1024 : index
+    %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    %1 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    // CHECK-LABEL: scf.for
+    // CHECK-SAME: (!xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>)
+    %2:2 = scf.for %arg2 = %c0 to %c1024 step %c256 iter_args(%arg3 = %0, %arg4 = %1)
+        -> (!xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>) {
+      %3 = xegpu.load_nd %0 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32>
+      xegpu.store_nd %3, %arg3 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+      %4 = xegpu.update_nd_offset %arg3, [256] : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+      %5 = xegpu.update_nd_offset %arg4, [256] : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+      scf.yield %4, %5 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    }
+    gpu.return
+  }
+
   gpu.func @test_scf_while_and_condition(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
     %c1_i32 = arith.constant 1 : i32
     %c10_i32 = arith.constant 10 : i32

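A note on exercising the new test: assuming the file keeps its existing lit setup (the actual RUN line sits at the top of the file, outside this hunk), it would run through the pass along these lines:

  // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s

The CHECK-SAME line checks the distributed loop signature: with sg_layout = [8] and sg_data = [16], each 256-element descriptor is split round-robin into two !xegpu.tensor_desc<16xf32> values per subgroup, so the two iter_args become the four loop-carried descriptors.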