
Commit 912b0fe

add tests and code changes

1 parent 93e18db commit 912b0fe

2 files changed: 30 additions & 1 deletion

mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp

Lines changed: 1 addition & 1 deletion
@@ -284,7 +284,7 @@ class XeGPUCreateNdDescOpPattern final
 
     // If the source is a static memref, we need to extract the pointer to
     // base address.
-    if (memrefType && memrefType.hasStaticShape()) {
+    if (memrefType) {
       auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
           rewriter, createNdOp.getLoc(), source);
       source = arith::IndexCastOp::create(rewriter, createNdOp.getLoc(),
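
The functional change: previously only statically shaped memrefs had their aligned base pointer extracted, so dynamically shaped sources never reached this path; now any memref source is lowered to a raw address, with shape and strides supplied explicitly on xegpu.create_nd_tdesc (as the new test below exercises). A minimal sketch of the extraction step, assuming the surrounding pattern's rewriter; the helper name getI64BaseAddress is hypothetical, not part of the pass:

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical helper, for illustration only: lower any memref value
// (static or dynamic shape) to its aligned base address as an i64.
static Value getI64BaseAddress(PatternRewriter &rewriter, Location loc,
                               Value source) {
  // memref.extract_aligned_pointer_as_index yields the base pointer as index.
  auto extractOp =
      memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
  // Cast index -> i64 so xegpu.create_nd_tdesc can consume a raw address.
  return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
                                    extractOp.getResult());
}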

mlir/test/Dialect/XeGPU/optimize-transpose.mlir renamed to mlir/test/Dialect/XeGPU/optimize-block-loads.mlir

Lines changed: 29 additions & 0 deletions
@@ -278,3 +278,32 @@ gpu.func @array_length(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg
     gpu.return
   }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @dynamic_memref(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf16>, %{{.*}}: vector<8x16xf16>) -> vector<8x16xf32> {
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+// CHECK-NEXT: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<?x?xf16> -> index
+// CHECK-NEXT: %[[T0:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK-NEXT: %[[T1:.*]] = xegpu.create_nd_tdesc %[[T0]], shape : [64, %[[C32]]], strides : [%[[C32]], 1] : i64
+// CHECK-SAME: -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK-NEXT: %[[T2:.*]] = xegpu.load_nd %[[T1]][%{{.*}}, %[[C16]]] {layout_result_0 =
+// CHECK-SAME: #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x8xi32,
+// CHECK-SAME: #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK-NEXT: %{{.*}} = vector.bitcast %[[T2]] {layout_result_0 =
+// CHECK-SAME: #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : vector<16x8xi32> to vector<16x16xf16>
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+gpu.module @xevm_module {
+  gpu.func @dynamic_memref(%arg0: memref<?x?xf16>, %arg1: vector<8x16xf16>) -> vector<8x16xf32> {
+    %c0 = arith.constant 0 : index
+    %c32 = arith.constant 32 : index
+    %0 = xegpu.create_nd_tdesc %arg0, shape : [64, 64], strides : [64, 1] : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
+    %1 = xegpu.load_nd %0[%c0, %c32] {layout_result_0 = #b} : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16>
+    %2 = vector.transpose %1, [1, 0] {layout_result_0 = #bt} : vector<16x16xf16> to vector<16x16xf16>
+    %3 = xegpu.dpas %arg1, %2 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+    gpu.return %3 : vector<8x16xf32>
+  }
+}
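
What the new test pins down: even with a dynamically shaped source, the transposed 16x16xf16 load is rewritten into a 16x8xi32 load plus a vector.bitcast back to f16, so every inner-dimension quantity shrinks by the f16-to-i32 packing factor of 2: shape [64, 64] becomes [64, 32], strides [64, 1] become [32, 1], and the column offset 32 becomes 16. A self-contained sketch of that arithmetic (the standalone function and names are illustrative, not the pass's actual code):

#include <cassert>
#include <cstdint>

// Inner-dimension block parameters, all in element units.
struct BlockParams {
  int64_t innerSize;   // contiguous (inner) extent of the shape
  int64_t rowStride;   // stride of the outer dimension
  int64_t innerOffset; // load offset along the inner dimension
};

// Repacking a narrow element type into i32 lanes: 16-bit elements pack two
// per i32, so all inner-dimension quantities divide by that factor.
static BlockParams repackToI32(BlockParams p, unsigned srcBits) {
  unsigned factor = 32 / srcBits;
  assert(p.innerSize % factor == 0 && p.rowStride % factor == 0 &&
         p.innerOffset % factor == 0);
  return {p.innerSize / factor, p.rowStride / factor, p.innerOffset / factor};
}

int main() {
  // Values from the test: shape [64, 64], strides [64, 1], offset [0, 32].
  BlockParams p = repackToI32({64, 64, 32}, /*srcBits=*/16);
  // Matches the expected shape [64, 32], strides [32, 1], and offset %[[C16]].
  assert(p.innerSize == 32 && p.rowStride == 32 && p.innerOffset == 16);
  return 0;
}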
