Skip to content

Commit 1a223e9

Browse files
authored
[Test] Add SIMT GEMM + transpose B case and restructure the integration tests. (#1095)
1 parent bc95b47 commit 1a223e9

File tree

101 files changed

+605
-23
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

101 files changed

+605
-23
lines changed

lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -179,48 +179,49 @@ class CreateNdDescToXeVMPattern
179179
Value baseShapeH;
180180
Value offsetW;
181181
Value offsetH;
182+
auto convertToValue = [&](OpFoldResult ofr) -> Value {
183+
Value val;
184+
if (auto v = llvm::dyn_cast_if_present<Value>(ofr)) {
185+
val = rewriter.create<arith::IndexCastOp>(loc, i64Ty, v);
186+
val = rewriter.create<arith::TruncIOp>(loc, payloadElemTy, val);
187+
} else {
188+
int32_t off = llvm::cast<IntegerAttr>(cast<Attribute>(ofr)).getInt();
189+
val = rewriter.create<arith::ConstantIntOp>(loc, payloadElemTy, off);
190+
}
191+
return val;
192+
};
193+
194+
int rank = op.getMixedOffsets().size();
195+
if (rank != 2) {
196+
op.emitError() << "Expected 2D offsets, got " << rank << "D offsets.";
197+
return mlir::failure();
198+
}
199+
offsetW = convertToValue(op.getMixedOffsets()[rank - 1]);
200+
offsetH = convertToValue(op.getMixedOffsets()[rank - 2]);
182201

183202
if (auto sourceTy = source.getType(); isa<MemRefType>(sourceTy)) {
184203
baseAddr =
185204
rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(loc, source);
205+
baseAddr = rewriter.create<arith::IndexCastUIOp>(loc, i64Ty, baseAddr);
186206
auto sourceMemrefTy = cast<MemRefType>(sourceTy);
187207
if (!sourceMemrefTy.hasStaticShape()) {
188208
op.emitError() << "Expected static memref shape.";
189209
return mlir::failure();
190210
}
191211
auto rank = sourceMemrefTy.getRank();
192-
if (rank != 2) {
193-
op.emitError() << "Expected a 2D memref.";
194-
return mlir::failure();
195-
}
196-
auto createOffset = [&](unsigned idx) -> Value {
197-
Value val;
198-
OpFoldResult ofr = op.getMixedOffsets()[idx];
199-
if (auto v = llvm::dyn_cast_if_present<Value>(ofr)) {
200-
val = rewriter.create<arith::IndexCastOp>(loc, i64Ty, v);
201-
val = rewriter.create<arith::TruncIOp>(loc, payloadElemTy, val);
202-
} else {
203-
int32_t off = llvm::cast<IntegerAttr>(cast<Attribute>(ofr)).getInt();
204-
val = rewriter.create<arith::ConstantIntOp>(loc, payloadElemTy, off);
205-
}
206-
return val;
207-
};
208-
offsetW = createOffset(rank - 1);
209-
offsetH = createOffset(rank - 2);
210212
baseShapeW = rewriter.create<arith::ConstantIntOp>(
211213
loc, payloadElemTy, sourceMemrefTy.getDimSize(rank - 1));
212214
baseShapeH = rewriter.create<arith::ConstantIntOp>(
213215
loc, payloadElemTy, sourceMemrefTy.getDimSize(rank - 2));
214216
} else if (isa<IntegerType>(sourceTy)) {
215-
op.emitError()
216-
<< "Integer as source are currently not supported by the pass.";
217-
return mlir::failure();
217+
baseAddr = source;
218+
baseShapeW = convertToValue(op.getMixedSizes()[rank - 1]);
219+
baseShapeH = convertToValue(op.getMixedSizes()[rank - 2]);
218220
} else {
219221
op.emitError() << "Unknown source type.";
220222
return mlir::failure();
221223
}
222224

223-
baseAddr = rewriter.create<arith::IndexCastUIOp>(loc, i64Ty, baseAddr);
224225
Value payLoadAsI64 =
225226
rewriter.create<vector::BitCastOp>(loc, payloadI64Ty, payload);
226227
payLoadAsI64 = rewriter.create<vector::InsertOp>(

test/Conversion/XeGPUToXeVM/loadstore_nd.mlir

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: imex-opt -convert-xegpu-to-xevm %s | FileCheck %s
1+
// RUN: imex-opt -convert-xegpu-to-xevm -allow-unregistered-dialect %s | FileCheck %s
22

33
gpu.module @load_store_check {
44
gpu.func @load_store(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
@@ -66,4 +66,31 @@ gpu.module @load_store_check {
6666
xegpu.store_nd %loaded_modified, %dst_tdesc <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
6767
gpu.return
6868
}
69+
70+
gpu.func @create_nd_tdesc_integer_source(%src: i64, %src_h : index, %src_w : index) kernel {
71+
%c1 = arith.constant 1 : index
72+
%c4 = arith.constant 4 : index
73+
%c8 = arith.constant 8 : index
74+
%c0 = arith.constant 0 : index
75+
// CHECK: %[[PAYLOAD:.*]] = arith.constant dense<0> : vector<8xi32>
76+
// CHECK: %[[T0:.*]] = arith.index_cast %{{.*}} : index to i64
77+
// CHECK: %[[T1:.*]] = arith.trunci %[[T0]] : i64 to i32
78+
// CHECK: %[[T2:.*]] = arith.index_cast %{{.*}} : index to i64
79+
// CHECK: %[[T3:.*]] = arith.trunci %[[T2]] : i64 to i32
80+
// CHECK: %[[T4:.*]] = arith.index_cast %{{.*}} : index to i64
81+
// CHECK: %[[T5:.*]] = arith.trunci %[[T4]] : i64 to i32
82+
// CHECK: %[[T6:.*]] = arith.index_cast %{{.*}} : index to i64
83+
// CHECK: %[[T7:.*]] = arith.trunci %[[T6]] : i64 to i32
84+
// CHECK: %[[T8:.*]] = vector.bitcast %[[PAYLOAD]] : vector<8xi32> to vector<4xi64>
85+
// CHECK: %[[T9:.*]] = vector.insert %{{.*}}, %[[T8]] [0] : i64 into vector<4xi64>
86+
// CHECK: %[[T10:.*]] = vector.bitcast %[[T9]] : vector<4xi64> to vector<8xi32>
87+
// CHECK: %[[T11:.*]] = vector.insert %[[T5]], %[[T10]] [2] : i32 into vector<8xi32>
88+
// CHECK: %[[T12:.*]] = vector.insert %[[T7]], %[[T11]] [3] : i32 into vector<8xi32>
89+
// CHECK: %[[T13:.*]] = vector.insert %[[T1]], %[[T12]] [4] : i32 into vector<8xi32>
90+
// CHECK: %[[T14:.*]] = vector.insert %[[T3]], %[[T13]] [5] : i32 into vector<8xi32>
91+
%src_tdesc = xegpu.create_nd_tdesc %src [%c4, %c8], [%src_h, %src_w], [%src_w, %c1] : i64
92+
-> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
93+
"some_op"(%src_tdesc) : (!xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>) -> ()
94+
gpu.return
95+
}
6996
}

0 commit comments

Comments
 (0)