Skip to content

Commit d90bc3b

Browse files
authored
[mlir][XeGPU][VectorToXeGPU] Use 'xegpu.load' to lower 1D 'vector.transfer_read' for PVC & BMG (#168910)
The PR changes the `TransferReadLowering` to always use `xegpu.load` (and not `xegpu.load_nd`) for 1D cases, as it has a more developed interface (e.g. layout capabilities). Signed-off-by: dchigarev <[email protected]>
1 parent d416289 commit d90bc3b

File tree

2 files changed

+69
-20
lines changed

2 files changed

+69
-20
lines changed

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,8 +519,13 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
519519
return lowerToScatteredLoadOp(readOp, rewriter);
520520
}
521521

522-
// Perform common data transfer checks.
523522
VectorType vecTy = readOp.getVectorType();
523+
524+
// Lower using load.gather in 1D case
525+
if (vecTy.getRank() == 1 && !readOp.hasOutOfBoundsDim())
526+
return lowerToScatteredLoadOp(readOp, rewriter);
527+
528+
// Perform common data transfer checks.
524529
if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
525530
return failure();
526531

mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir

Lines changed: 63 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector
1111

1212
// LOAD-ND-LABEL: @load_1D_vector(
1313
// LOAD-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
14-
// LOAD-ND-SAME: %[[OFFSET:.+]]: index
15-
// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
16-
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
17-
// LOAD-ND-SAME: %[[COLLAPSED]]
18-
// LOAD-ND-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
19-
// LOAD-ND-SAME: boundary_check = false
20-
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]]]{{.*}}-> vector<8xf32>
21-
// LOAD-ND: return %[[VEC]]
14+
// LOAD-ND: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
15+
// LOAD-ND: %[[STEP:.+]] = vector.step : vector<8xindex>
16+
// LOAD-ND-COUNT2: arith.muli {{.*}} : index
17+
// LOAD-ND-COUNT2: arith.addi {{.*}} : index
18+
// LOAD-ND: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
19+
// LOAD-ND: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
20+
// LOAD-ND: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
21+
// LOAD-ND: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
22+
// LOAD-ND: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
2223

2324
// LOAD-GATHER-LABEL: @load_1D_vector(
2425
// LOAD-GATHER-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
@@ -404,27 +405,31 @@ gpu.func @no_load_unsupported_map(%source: memref<16x32x64xf32>,
404405

405406
// -----
406407
gpu.module @xevm_module {
407-
gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8xf16> {
408+
gpu.func @load_from_subview_1D(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8xf16> {
408409
%c0 = arith.constant 0.0 : f16
409410
%subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
410411
%0 = vector.transfer_read %subview[%off2, %off2], %c0
411412
{in_bounds = [true]} : memref<256x256xf16, strided<[4096, 1], offset: ?>>, vector<8xf16>
412413
gpu.return %0 : vector<8xf16>
413414
}
414415

415-
// LOAD-ND-LABEL: @load_from_subview(
416+
// LOAD-ND-LABEL: @load_from_subview_1D(
416417
// LOAD-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
417418
// LOAD-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
419+
// LOAD-ND: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
418420
// LOAD-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
419-
// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SUBVIEW]][%[[OFF2]], 0]
420-
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
421-
// LOAD-ND-SAME: %[[COLLAPSED]]
422-
// LOAD-ND-SAME: memref<256xf16, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
423-
// LOAD-ND-SAME: boundary_check = false
424-
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF2]]]{{.*}}-> vector<8xf16>
425-
// LOAD-ND: return %[[VEC]]
426-
427-
// LOAD-GATHER-LABEL: @load_from_subview(
421+
// LOAD-ND: %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
422+
// LOAD-ND: %[[STEP:.+]] = vector.step : vector<8xindex>
423+
// LOAD-ND: arith.muli {{.*}} : index
424+
// LOAD-ND: arith.addi %[[OFFSET]]{{.*}} : index
425+
// LOAD-ND: arith.addi {{.*}} : index
426+
// LOAD-ND: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
427+
// LOAD-ND: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
428+
// LOAD-ND: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
429+
// LOAD-ND: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
430+
// LOAD-ND: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>
431+
432+
// LOAD-GATHER-LABEL: @load_from_subview_1D(
428433
// LOAD-GATHER-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
429434
// LOAD-GATHER-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
430435
// LOAD-GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
@@ -440,3 +445,42 @@ gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2:
440445
// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
441446
// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>
442447
}
448+
449+
// -----
450+
gpu.module @xevm_module {
451+
gpu.func @load_from_subview_2D(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8x16xf16> {
452+
%c0 = arith.constant 0.0 : f16
453+
%subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
454+
%0 = vector.transfer_read %subview[%off2, %off2], %c0
455+
{in_bounds = [true, true]} : memref<256x256xf16, strided<[4096, 1], offset: ?>>, vector<8x16xf16>
456+
gpu.return %0 : vector<8x16xf16>
457+
}
458+
459+
// LOAD-ND-LABEL: @load_from_subview_2D(
460+
// LOAD-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
461+
// LOAD-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
462+
// LOAD-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
463+
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
464+
// LOAD-ND-SAME: %[[SUBVIEW]]
465+
// LOAD-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf16,
466+
// LOAD-ND-SAME: boundary_check = false
467+
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF2]], %[[OFF2]]]{{.*}}-> vector<8x16xf16>
468+
// LOAD-ND: return %[[VEC]]
469+
470+
// LOAD-GATHER-LABEL: @load_from_subview_2D(
471+
// LOAD-GATHER-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
472+
// LOAD-GATHER-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
473+
// LOAD-GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
474+
// LOAD-GATHER: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
475+
// LOAD-GATHER: %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
476+
// LOAD-GATHER-COUNT2: vector.step
477+
// LOAD-GATHER-COUNT2: vector.shape_cast
478+
// LOAD-GATHER-COUNT2: vector.broadcast
479+
// LOAD-GATHER-COUNT2: arith.muli {{.*}} : index
480+
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
481+
// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
482+
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<8x16xindex>
483+
// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
484+
// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
485+
// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
486+
}

0 commit comments

Comments
 (0)