38 changes: 33 additions & 5 deletions mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -102,18 +102,46 @@ static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
xegpu::TensorDescType descType,
TypedValue<MemRefType> src) {
MemRefType srcTy = src.getType();
assert(srcTy.isStrided() && "Expected strided memref type");
auto [strides, offset] = srcTy.getStridesAndOffset();
bool isStatic = true;

// A memref is dynamic if its shape, offset, or any of its strides is dynamic.
if (!srcTy.hasStaticShape())
isStatic = false;

if (!ShapedType::isStatic(offset))
isStatic = false;

for (auto stride : strides) {
if (!ShapedType::isStatic(stride)) {
isStatic = false;
break;
}
}

xegpu::CreateNdDescOp ndDesc;
if (srcTy.hasStaticShape()) {
if (isStatic) {
ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
} else {
// In case of any dynamic shapes, source's shape and strides have to be
// In case of a ranked dynamic memref, instead of passing the memref itself,
// the i64 base address (adjusted by the source's offset), shape, and strides have to be
// explicitly provided.
auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
meta.getConstifiedMixedSizes(),
meta.getConstifiedMixedStrides());
auto baseAddrIndex = memref::ExtractAlignedPointerAsIndexOp::create(
rewriter, loc, meta.getBaseBuffer());
auto offset = meta.getOffset();
auto elemByteSize = srcTy.getElementTypeBitWidth() / 8;
auto offsetInBytes = arith::MulIOp::create(
rewriter, loc, offset,
arith::ConstantIndexOp::create(rewriter, loc, elemByteSize));
auto adjustedBaseAddr = arith::AddIOp::create(
rewriter, loc, baseAddrIndex.getResult(), offsetInBytes);
auto adjustedAddrI64 = arith::IndexCastOp::create(
rewriter, loc, rewriter.getI64Type(), adjustedBaseAddr);
ndDesc = xegpu::CreateNdDescOp::create(
rewriter, loc, descType, adjustedAddrI64,
meta.getConstifiedMixedSizes(), meta.getConstifiedMixedStrides());
}

return ndDesc;
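For reference, a minimal sketch of the IR the dynamic branch above is expected to emit for a 2-D strided f32 source (value names and the 8x16 tile shape are illustrative, not taken verbatim from the tests; the constant 4 assumes a 4-byte element type):

    // Recover base buffer, offset, sizes, and strides from the source memref.
    %base, %off, %sizes:2, %strides:2 = memref.extract_strided_metadata %src
        : memref<?x?xf32, strided<[?, 1], offset: ?>>
        -> memref<f32>, index, index, index, index, index
    %ptr = memref.extract_aligned_pointer_as_index %base : memref<f32> -> index
    // Convert the element offset into a byte offset and fold it into the base address.
    %c4 = arith.constant 4 : index
    %off_bytes = arith.muli %off, %c4 : index
    %addr = arith.addi %ptr, %off_bytes : index
    %addr_i64 = arith.index_cast %addr : index to i64
    // Build the descriptor from the raw i64 address plus explicit shape and strides
    // (the unit innermost stride is constant-folded).
    %desc = xegpu.create_nd_tdesc %addr_i64, shape : [%sizes#0, %sizes#1],
        strides : [%strides#0, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>

The descriptor is then consumed by xegpu.load_nd / xegpu.store_nd exactly as in the static case; only the descriptor creation path changes, as the tests below verify.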
35 changes: 27 additions & 8 deletions mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -9,10 +9,17 @@ func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vecto
// CHECK-LABEL: @load_1D_vector(
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[ELEM_BYTES:.+]] = arith.constant 4 : index
// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[COLLAPSED]]
// CHECK-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
// CHECK: %[[BASE_BUFFER:.+]], %[[OFFSET1:.+]], %[[SIZES:.+]], %[[STRIDES:.+]] = memref.extract_strided_metadata %[[COLLAPSED]]
// CHECK-SAME: : memref<32xf32, strided<[1], offset: ?>> -> memref<f32>, index, index, index
// CHECK: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
// CHECK-SAME: : memref<f32> -> index
// CHECK: %[[MUL:.+]] = arith.muli %[[OFFSET1]], %[[ELEM_BYTES]] : index
// CHECK: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [32],
// CHECK-SAME: strides : [1] : i64 -> !xegpu.tensor_desc<8xf32,
// CHECK-SAME: boundary_check = false
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]]]{{.*}}-> vector<8xf32>
// CHECK: return %[[VEC]]
@@ -29,10 +36,16 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
// CHECK-LABEL: @load_2D_vector(
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[ELEM_BYTES:.+]] = arith.constant 4 : index
// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[COLLAPSED]]
// CHECK-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
// CHECK-SAME: : memref<f32> -> index
// CHECK: %[[MUL:.+]] = arith.muli %[[OFF1]], %[[ELEM_BYTES]] : index
// CHECK: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [16, 32],
// CHECK-SAME: strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]

@@ -48,9 +61,15 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// CHECK-LABEL: @load_dynamic_source(
// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
// CHECK-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// CHECK: %[[ELEM_BYTES:.+]] = arith.constant 4 : index
// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
// CHECK: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
// CHECK: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// CHECK: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
// CHECK: %[[MUL:.+]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
// CHECK: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
// CHECK-SAME: strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]

35 changes: 27 additions & 8 deletions mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -11,10 +11,17 @@ func.func @store_1D_vector(%vec: vector<8xf32>,
// CHECK-SAME: %[[VEC:.+]]: vector<8xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[ELEM_BYTES:.*]] = arith.constant 4 : index
// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[COLLAPSED]]
// CHECK-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
// CHECK: %[[BASE_BUFFER:.+]], %[[OFFSET1:.+]], %[[SIZES:.+]], %[[STRIDES:.+]] = memref.extract_strided_metadata %[[COLLAPSED]]
// CHECK-SAME: : memref<32xf32, strided<[1], offset: ?>> -> memref<f32>, index, index, index
// CHECK: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
// CHECK-SAME: : memref<f32> -> index
// CHECK: %[[MUL:.+]] = arith.muli %[[OFFSET1]], %[[ELEM_BYTES]] : index
// CHECK: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [32],
// CHECK-SAME: strides : [1] : i64 -> !xegpu.tensor_desc<8xf32,
// CHECK-SAME: boundary_check = false
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>

@@ -31,10 +38,16 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>,
// CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[ELEM_BYTES:.*]] = arith.constant 4 : index
// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[COLLAPSED]]
// CHECK-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
// CHECK-SAME: : memref<f32> -> index
// CHECK: %[[MUL:.+]] = arith.muli %[[OFF1]], %[[ELEM_BYTES]] : index
// CHECK: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [16, 32],
// CHECK-SAME: strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>

// -----
@@ -50,9 +63,15 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
// CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
// CHECK-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// CHECK: %[[ELEM_BYTES:.*]] = arith.constant 4 : index
// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
// CHECK: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
// CHECK: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// CHECK: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
// CHECK: %[[MUL:.+]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
// CHECK: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
// CHECK-SAME: strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>

// -----
44 changes: 34 additions & 10 deletions mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -48,10 +48,16 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
// LOAD-ND-LABEL: @load_2D_vector(
// LOAD-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// LOAD-ND-SAME: %[[OFFSET:.+]]: index
// LOAD-ND: %[[ELEM_BYTES:.+]] = arith.constant 4 : index
// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
// LOAD-ND-SAME: %[[COLLAPSED]]
// LOAD-ND-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32,
// LOAD-ND: %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// LOAD-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
// LOAD-ND-SAME: : memref<f32> -> index
// LOAD-ND: %[[MUL:.*]] = arith.muli %[[OFF1]], %[[ELEM_BYTES]] : index
// LOAD-ND: %[[ADD:.*]] = arith.addi %[[INTPTR]], %[[MUL]] : index
// LOAD-ND: %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [16, 32],
// LOAD-ND-SAME: strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
// LOAD-ND-SAME: boundary_check = false
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]
@@ -147,9 +153,16 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// LOAD-ND-LABEL: @load_dynamic_source(
// LOAD-ND-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
// LOAD-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// LOAD-ND: %[[ELEM_BYTES:.+]] = arith.constant 4 : index
// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
// LOAD-ND: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
// LOAD-ND: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// LOAD-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
// LOAD-ND: %[[MUL:.*]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
// LOAD-ND: %[[ADD:.*]] = arith.addi %[[INTPTR]], %[[MUL]] : index
// LOAD-ND: %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
// LOAD-ND-SAME: strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
// LOAD-ND-SAME: #xegpu.block_tdesc_attr<boundary_check = false>>
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]

@@ -184,8 +197,15 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
// LOAD-ND-LABEL: @load_dynamic_source2(
// LOAD-ND-SAME: %[[SRC:.+]]: memref<?x8x16xf32>,
// LOAD-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// LOAD-ND: %[[ELEM_BYTES:.+]] = arith.constant 4 : index
// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32, strided<[16, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
// LOAD-ND: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// LOAD-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
// LOAD-ND: %[[MUL:.*]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
// LOAD-ND: %[[ADD:.*]] = arith.addi %[[INTPTR]], %[[MUL]] : index
// LOAD-ND: %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
// LOAD-ND: %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [8, 16], strides : [16, 1] :
// LOAD-ND-SAME: i64 -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
// LOAD-ND: return %[[VEC]] : vector<8x16xf32>

@@ -459,11 +479,15 @@ gpu.func @load_from_subview_2D(%source: memref<4096x4096xf16>, %off1: index, %of
// LOAD-ND-LABEL: @load_from_subview_2D(
// LOAD-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
// LOAD-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// LOAD-ND: %[[ELEM_BYTES:.+]] = arith.constant 2 : index
// LOAD-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
// LOAD-ND-SAME: %[[SUBVIEW]]
// LOAD-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf16,
// LOAD-ND-SAME: boundary_check = false
// LOAD-ND: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[SUBVIEW]]
// LOAD-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
// LOAD-ND: %[[MUL:.*]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
// LOAD-ND: %[[ADD:.*]] = arith.addi %[[INTPTR]], %[[MUL]] : index
// LOAD-ND: %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
// LOAD-ND: %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [256, 256], strides : [4096, 1] :
// LOAD-ND-SAME: i64 -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF2]], %[[OFF2]]]{{.*}}-> vector<8x16xf16>
// LOAD-ND: return %[[VEC]]
