From 5e07a46f98d641cc43f49780b9993940ef0c735f Mon Sep 17 00:00:00 2001 From: Harrison Hao Date: Tue, 5 Nov 2024 20:11:01 +0800 Subject: [PATCH 1/6] [mlir][Vector] Support 0-d vectors natively in VectorStoreToMemrefStoreLowering. --- .../Vector/Transforms/LowerVectorTransfer.cpp | 12 ++----- .../VectorToLLVM/vector-to-llvm.mlir | 5 ++- mlir/test/Dialect/SPIRV/IR/availability.mlir | 36 +++++++++---------- .../vector-transfer-to-vector-load-store.mlir | 8 ++--- 4 files changed, 26 insertions(+), 35 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index f9428a4ce2864..1cb3baaef82ba 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -530,15 +530,9 @@ struct VectorStoreToMemrefStoreLowering return rewriter.notifyMatchFailure(storeOp, "not single element vector"); Value extracted; - if (vecType.getRank() == 0) { - // TODO: Unifiy once ExtractOp supports 0-d vectors. - extracted = rewriter.create( - storeOp.getLoc(), storeOp.getValueToStore()); - } else { - SmallVector indices(vecType.getRank(), 0); - extracted = rewriter.create( - storeOp.getLoc(), storeOp.getValueToStore(), indices); - } + SmallVector indices(vecType.getRank(), 0); + extracted = rewriter.create( + storeOp.getLoc(), storeOp.getValueToStore(), indices); rewriter.replaceOpWithNewOp( storeOp, extracted, storeOp.getBase(), storeOp.getIndices()); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index c1de24fd0403c..abbdbe02ce6c1 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -2971,9 +2971,8 @@ func.func @vector_store_op_0d(%memref : memref<200x100xf32>, %i : index, %j : in // CHECK-LABEL: func @vector_store_op_0d // CHECK: %[[val:.*]] = arith.constant dense<1.100000e+01> : vector // CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[val]] : vector to vector<1xf32> -// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64 -// CHECK: %[[extracted:.*]] = llvm.extractelement %[[cast]][%[[c0]] : i64] : vector<1xf32> -// CHECK: memref.store %[[extracted]], %{{.*}}[%{{.*}}, %{{.*}}] +// CHECK: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[cast]] : vector<1xf32> to f32 +// CHECK: memref.store %[[cast2]], %{{.*}}[%{{.*}}, %{{.*}}] // ----- diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir index c583a48eba270..ceebeeffcf267 100644 --- a/mlir/test/Dialect/SPIRV/IR/availability.mlir +++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir @@ -58,7 +58,7 @@ func.func @module_physical_storage_buffer64_vulkan() { func.func @sdot_scalar_i32_i32(%a: i32) -> i32 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ] %r = spirv.SDot %a, %a, : i32 -> i32 return %r: i32 @@ -68,7 +68,7 @@ func.func @sdot_scalar_i32_i32(%a: i32) -> i32 { func.func @sdot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ] %r = spirv.SDot %a, %a: vector<4xi8> -> i64 return 
%r: i64 @@ -78,7 +78,7 @@ func.func @sdot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 { func.func @sdot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ] %r = spirv.SDot %a, %a: vector<4xi16> -> i64 return %r: i64 @@ -88,7 +88,7 @@ func.func @sdot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 { func.func @sudot_scalar_i32_i32(%a: i32) -> i32 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ] %r = spirv.SUDot %a, %a, : i32 -> i32 return %r: i32 @@ -98,7 +98,7 @@ func.func @sudot_scalar_i32_i32(%a: i32) -> i32 { func.func @sudot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ] %r = spirv.SUDot %a, %a: vector<4xi8> -> i64 return %r: i64 @@ -108,7 +108,7 @@ func.func @sudot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 { func.func @sudot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ] %r = spirv.SUDot %a, %a: vector<4xi16> -> i64 return %r: i64 @@ -118,7 +118,7 @@ func.func @sudot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 { func.func @udot_scalar_i32_i32(%a: i32) -> i32 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ] %r = spirv.UDot %a, %a, : i32 -> i32 return %r: i32 @@ -128,7 +128,7 @@ func.func @udot_scalar_i32_i32(%a: i32) -> i32 { func.func @udot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ] %r = spirv.UDot %a, %a: vector<4xi8> -> i64 return %r: i64 @@ -138,7 +138,7 @@ func.func @udot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 { func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ] %r = spirv.UDot %a, %a: vector<4xi16> -> i64 return %r: i64 @@ -148,7 +148,7 @@ func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 { func.func @sdot_acc_sat_scalar_i32_i32(%a: i32) -> i32 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ] %r = spirv.SDotAccSat %a, %a, %a, : i32 -> i32 return %r: i32 @@ -158,7 +158,7 @@ func.func @sdot_acc_sat_scalar_i32_i32(%a: i32) -> i32 { func.func @sdot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> 
i64 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ] %r = spirv.SDotAccSat %a, %a, %acc: vector<4xi8> -> i64 return %r: i64 @@ -168,7 +168,7 @@ func.func @sdot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 { func.func @sdot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ] %r = spirv.SDotAccSat %a, %a, %acc: vector<4xi16> -> i64 return %r: i64 @@ -178,7 +178,7 @@ func.func @sdot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 { func.func @sudot_acc_sat_scalar_i32_i32(%a: i32) -> i32 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ] %r = spirv.SUDotAccSat %a, %a, %a, : i32 -> i32 return %r: i32 @@ -188,7 +188,7 @@ func.func @sudot_acc_sat_scalar_i32_i32(%a: i32) -> i32 { func.func @sudot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ] %r = spirv.SUDotAccSat %a, %a, %acc: vector<4xi8> -> i64 return %r: i64 @@ -198,7 +198,7 @@ func.func @sudot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 { func.func @sudot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ] %r = spirv.SUDotAccSat %a, %a, %acc: vector<4xi16> -> i64 return %r: i64 @@ -208,7 +208,7 @@ func.func @sudot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 { func.func @udot_acc_sat_scalar_i32_i32(%a: i32) -> i32 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ] %r = spirv.UDotAccSat %a, %a, %a, : i32 -> i32 return %r: i32 @@ -218,7 +218,7 @@ func.func @udot_acc_sat_scalar_i32_i32(%a: i32) -> i32 { func.func @udot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ] %r = spirv.UDotAccSat %a, %a, %acc: vector<4xi8> -> i64 return %r: i64 @@ -228,7 +228,7 @@ func.func @udot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 { func.func @udot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ] %r = spirv.UDotAccSat %a, %a, %acc: vector<4xi16> 
-> i64 return %r: i64 diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir index f90111b4c8861..f75f8f8489efc 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -7,15 +7,13 @@ func.func @vector_transfer_ops_0d_memref(%mem: memref, %vec: vector<1x1x1xf %f0 = arith.constant 0.0 : f32 // CHECK-NEXT: %[[S:.*]] = memref.load %[[MEM]][] : memref -// CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector %0 = vector.transfer_read %mem[], %f0 : memref, vector -// CHECK-NEXT: %[[SS:.*]] = vector.extractelement %[[V]][] : vector -// CHECK-NEXT: memref.store %[[SS]], %[[MEM]][] : memref +// CHECK-NEXT: memref.store %[[S]], %[[MEM]][] : memref vector.transfer_write %0, %mem[] : vector, memref -// CHECK-NEXT: %[[VV:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32> -// CHECK-NEXT: memref.store %[[VV]], %[[MEM]][] : memref +// CHECK-NEXT: %[[V:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32> +// CHECK-NEXT: memref.store %[[V]], %[[MEM]][] : memref vector.store %vec, %mem[] : memref, vector<1x1x1xf32> return From 3f9f7b0753fabc59d4dda6072c72fb6939b64858 Mon Sep 17 00:00:00 2001 From: Harrison Hao Date: Tue, 5 Nov 2024 14:41:31 +0000 Subject: [PATCH 2/6] [MLIR] Fix the lit test failure issue. --- mlir/test/Dialect/SPIRV/IR/availability.mlir | 36 ++++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir index ceebeeffcf267..c583a48eba270 100644 --- a/mlir/test/Dialect/SPIRV/IR/availability.mlir +++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir @@ -58,7 +58,7 @@ func.func @module_physical_storage_buffer64_vulkan() { func.func @sdot_scalar_i32_i32(%a: i32) -> i32 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ] %r = spirv.SDot %a, %a, : i32 -> i32 return %r: i32 @@ -68,7 +68,7 @@ func.func @sdot_scalar_i32_i32(%a: i32) -> i32 { func.func @sdot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ] %r = spirv.SDot %a, %a: vector<4xi8> -> i64 return %r: i64 @@ -78,7 +78,7 @@ func.func @sdot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 { func.func @sdot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ] %r = spirv.SDot %a, %a: vector<4xi16> -> i64 return %r: i64 @@ -88,7 +88,7 @@ func.func @sdot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 { func.func @sudot_scalar_i32_i32(%a: i32) -> i32 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ] %r = spirv.SUDot %a, %a, : i32 -> i32 return %r: i32 @@ -98,7 +98,7 @@ func.func @sudot_scalar_i32_i32(%a: i32) -> i32 { 
func.func @sudot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ] %r = spirv.SUDot %a, %a: vector<4xi8> -> i64 return %r: i64 @@ -108,7 +108,7 @@ func.func @sudot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 { func.func @sudot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ] %r = spirv.SUDot %a, %a: vector<4xi16> -> i64 return %r: i64 @@ -118,7 +118,7 @@ func.func @sudot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 { func.func @udot_scalar_i32_i32(%a: i32) -> i32 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ] %r = spirv.UDot %a, %a, : i32 -> i32 return %r: i32 @@ -128,7 +128,7 @@ func.func @udot_scalar_i32_i32(%a: i32) -> i32 { func.func @udot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ] %r = spirv.UDot %a, %a: vector<4xi8> -> i64 return %r: i64 @@ -138,7 +138,7 @@ func.func @udot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 { func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ] %r = spirv.UDot %a, %a: vector<4xi16> -> i64 return %r: i64 @@ -148,7 +148,7 @@ func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 { func.func @sdot_acc_sat_scalar_i32_i32(%a: i32) -> i32 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ] %r = spirv.SDotAccSat %a, %a, %a, : i32 -> i32 return %r: i32 @@ -158,7 +158,7 @@ func.func @sdot_acc_sat_scalar_i32_i32(%a: i32) -> i32 { func.func @sdot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ] %r = spirv.SDotAccSat %a, %a, %acc: vector<4xi8> -> i64 return %r: i64 @@ -168,7 +168,7 @@ func.func @sdot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 { func.func @sdot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ] %r = spirv.SDotAccSat %a, %a, %acc: vector<4xi16> -> i64 return %r: i64 @@ -178,7 +178,7 @@ func.func @sdot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 { func.func 
@sudot_acc_sat_scalar_i32_i32(%a: i32) -> i32 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ] %r = spirv.SUDotAccSat %a, %a, %a, : i32 -> i32 return %r: i32 @@ -188,7 +188,7 @@ func.func @sudot_acc_sat_scalar_i32_i32(%a: i32) -> i32 { func.func @sudot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ] %r = spirv.SUDotAccSat %a, %a, %acc: vector<4xi8> -> i64 return %r: i64 @@ -198,7 +198,7 @@ func.func @sudot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 { func.func @sudot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ] %r = spirv.SUDotAccSat %a, %a, %acc: vector<4xi16> -> i64 return %r: i64 @@ -208,7 +208,7 @@ func.func @sudot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 { func.func @udot_acc_sat_scalar_i32_i32(%a: i32) -> i32 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ] %r = spirv.UDotAccSat %a, %a, %a, : i32 -> i32 return %r: i32 @@ -218,7 +218,7 @@ func.func @udot_acc_sat_scalar_i32_i32(%a: i32) -> i32 { func.func @udot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ] %r = spirv.UDotAccSat %a, %a, %acc: vector<4xi8> -> i64 return %r: i64 @@ -228,7 +228,7 @@ func.func @udot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 { func.func @udot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 { // CHECK: min version: v1.0 // CHECK: max version: v1.6 - // CHECK: extensions: [ [SPV_KHR_16bit_storage] ] + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ] %r = spirv.UDotAccSat %a, %a, %acc: vector<4xi16> -> i64 return %r: i64 From 5f61d162694bbb3e7d620d7dc1f4a100bc17716d Mon Sep 17 00:00:00 2001 From: Harrison Hao Date: Wed, 6 Nov 2024 15:48:41 +0000 Subject: [PATCH 3/6] [MLIR] Update comments. --- mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index 1cb3baaef82ba..6c50473232e1b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -518,7 +518,7 @@ struct VectorLoadToMemrefLoadLowering } }; -/// Replace a 0-d vector.store with a vector.extractelement + memref.store. +/// Replace a vector.store with a vector.extract + memref.store. 
struct VectorStoreToMemrefStoreLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; From 6db82e0463ce7a0852925ea615c73b1dae63b23d Mon Sep 17 00:00:00 2001 From: Harrison Hao Date: Wed, 6 Nov 2024 15:51:12 +0000 Subject: [PATCH 4/6] [MLIR] Update comments again. --- mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index 6c50473232e1b..6f033cbe02509 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -518,7 +518,7 @@ struct VectorLoadToMemrefLoadLowering } }; -/// Replace a vector.store with a vector.extract + memref.store. +/// Replace a vector.store with a vector.extractelement + memref.store. struct VectorStoreToMemrefStoreLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; From 6a5ac3ca2853b2b5822e8c642cd671038c1c6988 Mon Sep 17 00:00:00 2001 From: Harrison Hao Date: Sun, 10 Nov 2024 13:16:53 +0000 Subject: [PATCH 5/6] [MLIR][Vector] Remove 0-d corner case condition. --- .../Vector/Transforms/LowerVectorTransfer.cpp | 20 +++++++++---------- .../VectorToLLVM/vector-to-llvm.mlir | 5 +++-- .../vector-transfer-to-vector-load-store.mlir | 8 +++++--- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index 6f033cbe02509..a953b24220701 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -97,9 +97,6 @@ struct TransferReadPermutationLowering matchAndRewriteMaskableOp(vector::TransferReadOp op, MaskingOpInterface maskOp, PatternRewriter &rewriter) const override { - // TODO: support 0-d corner case. - if (op.getTransferRank() == 0) - return rewriter.notifyMatchFailure(op, "0-d corner case not supported"); // TODO: Support transfer_read inside MaskOp case. if (maskOp) return rewriter.notifyMatchFailure(op, "Masked case not supported"); @@ -326,9 +323,6 @@ struct TransferOpReduceRank matchAndRewriteMaskableOp(vector::TransferReadOp op, MaskingOpInterface maskOp, PatternRewriter &rewriter) const override { - // TODO: support 0-d corner case. - if (op.getTransferRank() == 0) - return rewriter.notifyMatchFailure(op, "0-d corner case not supported"); // TODO: support masked case. if (maskOp) return rewriter.notifyMatchFailure(op, "Masked case not supported"); @@ -518,7 +512,7 @@ struct VectorLoadToMemrefLoadLowering } }; -/// Replace a vector.store with a vector.extractelement + memref.store. +/// Replace a 0-d vector.store with a vector.extractelement + memref.store. struct VectorStoreToMemrefStoreLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -530,9 +524,15 @@ struct VectorStoreToMemrefStoreLowering return rewriter.notifyMatchFailure(storeOp, "not single element vector"); Value extracted; - SmallVector indices(vecType.getRank(), 0); - extracted = rewriter.create( - storeOp.getLoc(), storeOp.getValueToStore(), indices); + if (vecType.getRank() == 0) { + // TODO: Unifiy once ExtractOp supports 0-d vectors. 
+ extracted = rewriter.create( + storeOp.getLoc(), storeOp.getValueToStore()); + } else { + SmallVector indices(vecType.getRank(), 0); + extracted = rewriter.create( + storeOp.getLoc(), storeOp.getValueToStore(), indices); + } rewriter.replaceOpWithNewOp( storeOp, extracted, storeOp.getBase(), storeOp.getIndices()); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index abbdbe02ce6c1..c1de24fd0403c 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -2971,8 +2971,9 @@ func.func @vector_store_op_0d(%memref : memref<200x100xf32>, %i : index, %j : in // CHECK-LABEL: func @vector_store_op_0d // CHECK: %[[val:.*]] = arith.constant dense<1.100000e+01> : vector // CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[val]] : vector to vector<1xf32> -// CHECK: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[cast]] : vector<1xf32> to f32 -// CHECK: memref.store %[[cast2]], %{{.*}}[%{{.*}}, %{{.*}}] +// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK: %[[extracted:.*]] = llvm.extractelement %[[cast]][%[[c0]] : i64] : vector<1xf32> +// CHECK: memref.store %[[extracted]], %{{.*}}[%{{.*}}, %{{.*}}] // ----- diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir index f75f8f8489efc..f90111b4c8861 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -7,13 +7,15 @@ func.func @vector_transfer_ops_0d_memref(%mem: memref, %vec: vector<1x1x1xf %f0 = arith.constant 0.0 : f32 // CHECK-NEXT: %[[S:.*]] = memref.load %[[MEM]][] : memref +// CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector %0 = vector.transfer_read %mem[], %f0 : memref, vector -// CHECK-NEXT: memref.store %[[S]], %[[MEM]][] : memref +// CHECK-NEXT: %[[SS:.*]] = vector.extractelement %[[V]][] : vector +// CHECK-NEXT: memref.store %[[SS]], %[[MEM]][] : memref vector.transfer_write %0, %mem[] : vector, memref -// CHECK-NEXT: %[[V:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32> -// CHECK-NEXT: memref.store %[[V]], %[[MEM]][] : memref +// CHECK-NEXT: %[[VV:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32> +// CHECK-NEXT: memref.store %[[VV]], %[[MEM]][] : memref vector.store %vec, %mem[] : memref, vector<1x1x1xf32> return From 700b0ea025a57fe888f25b492b40de7502159b6f Mon Sep 17 00:00:00 2001 From: Harrison Hao Date: Fri, 15 Nov 2024 11:00:28 +0800 Subject: [PATCH 6/6] [MLIR] Remove Transfer vectore lower pattern. 
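A sketch of the effect of this change (illustrative only, not part of the original commit message; the function name below is made up, modeled on the transfer_read_write_1d test updated later in this patch): with populateVectorTransferLoweringPatterns no longer registered by the convert-vector-to-llvm pass, a rank-1 transfer is expected to remain a vector.transfer_read / vector.transfer_write after conversion instead of being rewritten into llvm.intr.masked.load/store, matching the updated CHECK lines below.

func.func @illustrative_read_write_1d(%mem : memref<?xf32>, %base : index) -> vector<17xf32> {
  // Padding value for out-of-bounds lanes of the read.
  %pad = arith.constant 7.0 : f32
  // With this patch, both ops below are expected to survive the pass unchanged.
  %v = vector.transfer_read %mem[%base], %pad : memref<?xf32>, vector<17xf32>
  vector.transfer_write %v, %mem[%base] : vector<17xf32>, memref<?xf32>
  return %v : vector<17xf32>
}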
--- .../VectorToLLVM/ConvertVectorToLLVM.cpp | 2 - .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 3 - .../Vector/Transforms/LowerVectorTransfer.cpp | 12 +- .../Conversion/GPUCommon/transfer_write.mlir | 5 +- .../VectorToLLVM/vector-to-llvm.mlir | 32 +- .../VectorToLLVM/vector-xfer-to-llvm.mlir | 319 ++---------------- .../test/Dialect/Vector/transform-vector.mlir | 3 +- .../vector-transfer-to-vector-load-store.mlir | 93 +++-- 8 files changed, 97 insertions(+), 372 deletions(-) diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 58ca84c8d7bca..155b2241b7a93 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1910,8 +1910,6 @@ void mlir::populateVectorToLLVMConversionPatterns( MaskedReductionOpConversion, VectorInterleaveOpLowering, VectorDeinterleaveOpLowering, VectorFromElementsLowering, VectorScalableStepOpLowering>(converter); - // Transfer ops with rank > 1 are handled by VectorToSCF. - populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); } void mlir::populateVectorToLLVMMatrixConversionPatterns( diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 4623b9667998c..7635e10822a34 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -74,8 +74,6 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorInterleaveLoweringPatterns(patterns); populateVectorTransposeLoweringPatterns(patterns, VectorTransformsOptions()); - // Vector transfer ops with rank > 1 should be lowered with VectorToSCF. 
- populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } @@ -84,7 +82,6 @@ void ConvertVectorToLLVMPass::runOnOperation() { LLVMTypeConverter converter(&getContext(), options); RewritePatternSet patterns(&getContext()); populateVectorMaskMaterializationPatterns(patterns, force32BitVectorIndices); - populateVectorTransferLoweringPatterns(patterns); populateVectorToLLVMMatrixConversionPatterns(converter, patterns); populateVectorToLLVMConversionPatterns( converter, patterns, reassociateFPReductions, force32BitVectorIndices); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index a953b24220701..484363c6b1d8d 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -636,10 +636,10 @@ struct TransferWriteToVectorStoreLowering void mlir::vector::populateVectorTransferLoweringPatterns( RewritePatternSet &patterns, std::optional maxTransferRank, PatternBenefit benefit) { - patterns.add(patterns.getContext(), - maxTransferRank, benefit); - patterns - .add( - patterns.getContext(), benefit); + // patterns.add(patterns.getContext(), + // maxTransferRank, benefit); + // patterns + // .add( + // patterns.getContext(), benefit); } diff --git a/mlir/test/Conversion/GPUCommon/transfer_write.mlir b/mlir/test/Conversion/GPUCommon/transfer_write.mlir index cd62b7b13fa9a..d1127e4203c7b 100644 --- a/mlir/test/Conversion/GPUCommon/transfer_write.mlir +++ b/mlir/test/Conversion/GPUCommon/transfer_write.mlir @@ -3,10 +3,7 @@ func.func @warp_extract(%arg0: index, %arg1: memref<1024x1024xf32>, %arg2: index, %arg3: vector<1xf32>) { %c0 = arith.constant 0 : index vector.warp_execute_on_lane_0(%arg0)[32] { - // CHECK:%[[val:[0-9]+]] = llvm.extractelement - // CHECK:%[[base:[0-9]+]] = llvm.extractvalue - // CHECK:%[[ptr:[0-9]+]] = llvm.getelementptr %[[base]] - // CHECK:llvm.store %[[val]], %[[ptr]] + // CHECK: vector.transfer_write %arg9, %[[MEM:.*]][%[[IDX:.*]], %[[IDX]]] {in_bounds = [true]} : vector<1xf32>, memref<1024x1024xf32> vector.transfer_write %arg3, %arg1[%c0, %c0] {in_bounds = [true]} : vector<1xf32>, memref<1024x1024xf32> } return diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index c1de24fd0403c..be230cfcbd6e5 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -2953,12 +2953,16 @@ func.func @vector_load_op_0d(%memref : memref<200x100xf32>, %i : index, %j : ind } // CHECK-LABEL: func @vector_load_op_0d -// CHECK: %[[load:.*]] = memref.load %{{.*}}[%{{.*}}, %{{.*}}] -// CHECK: %[[vec:.*]] = llvm.mlir.undef : vector<1xf32> -// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i32) : i32 -// CHECK: %[[inserted:.*]] = llvm.insertelement %[[load]], %[[vec]][%[[c0]] : i32] : vector<1xf32> -// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[inserted]] : vector<1xf32> to vector -// CHECK: return %[[cast]] : vector +// CHECK: %[[S0:.*]] = builtin.unrealized_conversion_cast %arg2 : index to i64 +// CHECK: %[[S1:.*]] = builtin.unrealized_conversion_cast %arg1 : index to i64 +// CHECK: %[[S2:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<200x100xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[S3:.*]] = llvm.extractvalue %[[S2]][1] : !llvm.struct<(ptr, ptr, 
i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[S4:.*]] = llvm.mlir.constant(100 : index) : i64 +// CHECK: %[[S5:.*]] = llvm.mul %[[S1]], %[[S4]] : i64 +// CHECK: %[[S6:.*]] = llvm.add %[[S5]], %[[S0]] : i64 +// CHECK: %[[S7:.*]] = llvm.getelementptr %[[S3]][%[[S6]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 +// CHECK: %[[S8:.*]] = llvm.load %[[S7]] {alignment = 4 : i64} : !llvm.ptr -> vector<1xf32> +// CHECK: %[[S9:.*]] = builtin.unrealized_conversion_cast %[[S8]] : vector<1xf32> to vector // ----- @@ -2969,11 +2973,17 @@ func.func @vector_store_op_0d(%memref : memref<200x100xf32>, %i : index, %j : in } // CHECK-LABEL: func @vector_store_op_0d -// CHECK: %[[val:.*]] = arith.constant dense<1.100000e+01> : vector -// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[val]] : vector to vector<1xf32> -// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64 -// CHECK: %[[extracted:.*]] = llvm.extractelement %[[cast]][%[[c0]] : i64] : vector<1xf32> -// CHECK: memref.store %[[extracted]], %{{.*}}[%{{.*}}, %{{.*}}] +// CHECK: %[[S0:.*]] = builtin.unrealized_conversion_cast %arg2 : index to i64 +// CHECK: %[[S1:.*]] = builtin.unrealized_conversion_cast %arg1 : index to i64 +// CHECK: %[[S2:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<200x100xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[S3:.*]] = arith.constant dense<1.100000e+01> : vector +// CHECK: %[[S4:.*]] = builtin.unrealized_conversion_cast %[[S3]] : vector to vector<1xf32> +// CHECK: %[[S5:.*]] = llvm.extractvalue %[[S2]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[S6:.*]] = llvm.mlir.constant(100 : index) : i64 +// CHECK: %[[S7:.*]] = llvm.mul %[[S1]], %[[S6]] : i64 +// CHECK: %[[S8:.*]] = llvm.add %[[S7]], %[[S0]] : i64 +// CHECK: %[[S9:.*]] = llvm.getelementptr %[[S5]][%[[S8]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 +// CHECK: llvm.store %[[S4]], %[[S9]] {alignment = 4 : i64} : vector<1xf32>, !llvm.ptr // ----- diff --git a/mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir index 8f01cc2b8d44c..112e868e12107 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir @@ -12,67 +12,11 @@ func.func @transfer_read_write_1d(%A : memref, %base: index) -> vector<17 return %f: vector<17xf32> } // CHECK-LABEL: func @transfer_read_write_1d -// CHECK-SAME: %[[MEM:.*]]: memref, -// CHECK-SAME: %[[BASE:.*]]: index) -> vector<17xf32> -// CHECK: %[[C7:.*]] = arith.constant 7.0 -// -// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset) -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref -// CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE]] : index -// -// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. -// CHECK: %[[linearIndex:.*]] = arith.constant dense -// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> : vector<17x[[$IDX_TYPE]]> -// -// 3. Create bound vector to compute in-bound mask: -// [ 0 .. vector_length - 1 ] < [ dim - offset .. 
dim - offset ] -// CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : -// CMP32-SAME: index to i32 -// CMP64-SAME: index to i64 -// CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]] -// CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]] -// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]] : vector<17x[[$IDX_TYPE]]> -// CMP64-SAME: : vector<17xi64> -// -// 4. Create pass-through vector. -// CHECK: %[[PASS_THROUGH:.*]] = arith.constant dense<7.{{.*}}> : vector<17xf32> -// -// 5. Bitcast to vector form. -// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}} : -// CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32 -// -// 6. Rewrite as a masked read. -// CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[gep]], %[[mask]], -// CHECK-SAME: %[[PASS_THROUGH]] {alignment = 4 : i32} : -// CHECK-SAME: -> vector<17xf32> -// -// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset) -// CHECK: %[[C0_b:.*]] = arith.constant 0 : index -// CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0_b]] : memref -// CHECK: %[[BOUND_b:.*]] = arith.subi %[[DIM_b]], %[[BASE]] : index -// -// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. -// CHECK: %[[linearIndex_b:.*]] = arith.constant dense -// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> : vector<17x[[$IDX_TYPE]]> -// -// 3. Create bound vector to compute in-bound mask: -// [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ] -// CHECK: %[[btrunc_b:.*]] = arith.index_cast %[[BOUND_b]] -// CMP32-SAME: index to i32 -// CHECK: %[[boundVecInsert_b:.*]] = llvm.insertelement %[[btrunc_b]] -// CHECK: %[[boundVect_b:.*]] = llvm.shufflevector %[[boundVecInsert_b]] -// CHECK: %[[mask_b:.*]] = arith.cmpi slt, %[[linearIndex_b]], -// CHECK-SAME: %[[boundVect_b]] : vector<17x[[$IDX_TYPE]]> -// -// 4. Bitcast to vector form. -// CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} : -// CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32 -// -// 5. Rewrite as a masked write. -// CHECK: llvm.intr.masked.store %[[loaded]], %[[gep_b]], %[[mask_b]] -// CHECK-SAME: {alignment = 4 : i32} : -// CHECK-SAME: vector<17xf32>, vector<17xi1> into !llvm.ptr +// CHECK: %[[CST:.*]] = arith.constant 7.000000e+00 : f32 +// CHECK: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1], %[[CST]] : memref, vector<17xf32> +// CHECK: vector.transfer_write %[[TRANSFER_READ]], %arg0[%arg1] : vector<17xf32>, memref + +// ----- func.func @transfer_read_write_1d_scalable(%A : memref, %base: index) -> vector<[17]xf32> { %f7 = arith.constant 7.0: f32 @@ -85,62 +29,9 @@ func.func @transfer_read_write_1d_scalable(%A : memref, %base: index) -> return %f: vector<[17]xf32> } // CHECK-LABEL: func @transfer_read_write_1d_scalable -// CHECK-SAME: %[[MEM:.*]]: memref, -// CHECK-SAME: %[[BASE:.*]]: index) -> vector<[17]xf32> -// CHECK: %[[C7:.*]] = arith.constant 7.0 -// -// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset) -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref -// CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE]] : index -// -// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. -// CHECK: %[[linearIndex:.*]] = llvm.intr.stepvector : vector<[17]x[[$IDX_TYPE]]> -// -// 3. Create bound vector to compute in-bound mask: -// [ 0 .. vector_length - 1 ] < [ dim - offset .. 
dim - offset ] -// CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to [[$IDX_TYPE]] -// CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]] -// CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]] -// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]] -// CHECK-SAME: : vector<[17]x[[$IDX_TYPE]]> -// -// 4. Create pass-through vector. -// CHECK: %[[PASS_THROUGH:.*]] = arith.constant dense<7.{{.*}}> : vector<[17]xf32> -// -// 5. Bitcast to vector form. -// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}} : -// CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32 -// -// 6. Rewrite as a masked read. -// CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[gep]], %[[mask]], -// CHECK-SAME: %[[PASS_THROUGH]] {alignment = 4 : i32} : -// CHECK-SAME: -> vector<[17]xf32> -// -// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset) -// CHECK: %[[C0_b:.*]] = arith.constant 0 : index -// CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0_b]] : memref -// CHECK: %[[BOUND_b:.*]] = arith.subi %[[DIM_b]], %[[BASE]] : index -// -// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. -// CHECK: %[[linearIndex_b:.*]] = llvm.intr.stepvector : vector<[17]x[[$IDX_TYPE]]> -// -// 3. Create bound vector to compute in-bound mask: -// [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ] -// CHECK: %[[btrunc_b:.*]] = arith.index_cast %[[BOUND_b]] : index to [[$IDX_TYPE]] -// CHECK: %[[boundVecInsert_b:.*]] = llvm.insertelement %[[btrunc_b]] -// CHECK: %[[boundVect_b:.*]] = llvm.shufflevector %[[boundVecInsert_b]] -// CHECK: %[[mask_b:.*]] = arith.cmpi slt, %[[linearIndex_b]], -// CHECK-SAME: %[[boundVect_b]] : vector<[17]x[[$IDX_TYPE]]> -// -// 4. Bitcast to vector form. -// CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} : -// CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32 -// -// 5. Rewrite as a masked write. 
-// CHECK: llvm.intr.masked.store %[[loaded]], %[[gep_b]], %[[mask_b]] -// CHECK-SAME: {alignment = 4 : i32} : -// CHECK-SAME: vector<[17]xf32>, vector<[17]xi1> into !llvm.ptr +// CHECK: %[[CST:.*]] = arith.constant 7.000000e+00 : f32 +// CHECK: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1], %[[CST]] : memref, vector<[17]xf32> +// CHECK: vector.transfer_write %[[TRANSFER_READ]], %arg0[%arg1] : vector<[17]xf32>, memref // ----- @@ -155,15 +46,9 @@ func.func @transfer_read_write_index_1d(%A : memref, %base: index) -> v return %f: vector<17xindex> } // CHECK-LABEL: func @transfer_read_write_index_1d -// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xindex> -// CHECK: %[[SPLAT:.*]] = arith.constant dense<7> : vector<17xindex> -// CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[SPLAT]] : vector<17xindex> to vector<17xi64> - -// CHECK: %[[loaded:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : -// CHECK-SAME: (!llvm.ptr, vector<17xi1>, vector<17xi64>) -> vector<17xi64> - -// CHECK: llvm.intr.masked.store %[[loaded]], %{{.*}}, %{{.*}} {alignment = 8 : i32} : -// CHECK-SAME: vector<17xi64>, vector<17xi1> into !llvm.ptr +// CHECK: %[[CST:.*]] = arith.constant 7 : index +// CHECK: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1], %[[CST]] : memref, vector<17xindex> +// CHECK: vector.transfer_write %[[TRANSFER_READ]], %arg0[%arg1] : vector<17xindex>, memref func.func @transfer_read_write_index_1d_scalable(%A : memref, %base: index) -> vector<[17]xindex> { %f7 = arith.constant 7: index @@ -175,16 +60,10 @@ func.func @transfer_read_write_index_1d_scalable(%A : memref, %base: in vector<[17]xindex>, memref return %f: vector<[17]xindex> } -// CHECK-LABEL: func @transfer_read_write_index_1d -// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<[17]xindex> -// CHECK: %[[SPLAT:.*]] = arith.constant dense<7> : vector<[17]xindex> -// CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[SPLAT]] : vector<[17]xindex> to vector<[17]xi64> - -// CHECK: %[[loaded:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : -// CHECK-SAME: (!llvm.ptr, vector<[17]xi1>, vector<[17]xi64>) -> vector<[17]xi64> - -// CHECK: llvm.intr.masked.store %[[loaded]], %{{.*}}, %{{.*}} {alignment = 8 : i32} : -// CHECK-SAME: vector<[17]xi64>, vector<[17]xi1> into !llvm.ptr +// CHECK-LABEL: func @transfer_read_write_index_1d_scalable +// CHECK: %[[CST:.*]] = arith.constant 7 : index +// CHECK: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1], %[[CST]] : memref, vector<[17]xindex> +// CHECK: vector.transfer_write %[[TRANSFER_READ]], %arg0[%arg1] : vector<[17]xindex>, memref // ----- @@ -196,24 +75,8 @@ func.func @transfer_read_2d_to_1d(%A : memref, %base0: index, %base1: i return %f: vector<17xf32> } // CHECK-LABEL: func @transfer_read_2d_to_1d -// CHECK-SAME: %[[BASE_0:[a-zA-Z0-9]*]]: index, %[[BASE_1:[a-zA-Z0-9]*]]: index) -> vector<17xf32> -// CHECK: %[[c1:.*]] = arith.constant 1 : index -// CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c1]] : memref -// -// Compute the in-bound index (dim - offset) -// CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE_1]] : index -// -// Create a vector with linear indices [ 0 .. vector_length - 1 ]. -// CHECK: %[[linearIndex:.*]] = arith.constant dense -// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> : -// CHECK-SAME: vector<17x[[$IDX_TYPE]]> -// -// Create bound vector to compute in-bound mask: -// [ 0 .. vector_length - 1 ] < [ dim - offset .. 
dim - offset ] -// CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to [[$IDX_TYPE]] -// CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]] -// CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]] -// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]] +// CHECK: %[[CST:.*]] = arith.constant 7.000000e+00 : f32 +// CHECK: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1, %arg2], %[[CST]] : memref, vector<17xf32> func.func @transfer_read_2d_to_1d_scalable(%A : memref, %base0: index, %base1: index) -> vector<[17]xf32> { %f7 = arith.constant 7.0: f32 @@ -222,23 +85,10 @@ func.func @transfer_read_2d_to_1d_scalable(%A : memref, %base0: index, memref, vector<[17]xf32> return %f: vector<[17]xf32> } -// CHECK-LABEL: func @transfer_read_2d_to_1d -// CHECK-SAME: %[[BASE_0:[a-zA-Z0-9]*]]: index, %[[BASE_1:[a-zA-Z0-9]*]]: index) -> vector<[17]xf32> -// CHECK: %[[c1:.*]] = arith.constant 1 : index -// CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c1]] : memref -// -// Compute the in-bound index (dim - offset) -// CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE_1]] : index -// -// Create a vector with linear indices [ 0 .. vector_length - 1 ]. -// CHECK: %[[linearIndex:.*]] = llvm.intr.stepvector : vector<[17]x[[$IDX_TYPE]]> -// -// Create bound vector to compute in-bound mask: -// [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ] -// CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to [[$IDX_TYPE]] -// CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]] -// CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]] -// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]] +// CHECK-LABEL: func @transfer_read_2d_to_1d_scalable +// CHECK: %[[CST:.*]] = arith.constant 7.000000e+00 : f32 +// CHECK: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1, %arg2], %[[CST]] : memref, vector<[17]xf32> +// CHECK: return %[[TRANSFER_READ]] : vector<[17]xf32> // ----- @@ -253,126 +103,7 @@ func.func @transfer_read_write_1d_non_zero_addrspace(%A : memref, %bas return %f: vector<17xf32> } // CHECK-LABEL: func @transfer_read_write_1d_non_zero_addrspace -// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xf32> -// -// 1. Check address space for GEP is correct. -// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} : -// CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32 -// -// 2. Check address space of the memref is correct. -// CHECK: %[[c0:.*]] = arith.constant 0 : index -// CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c0]] : memref -// -// 3. Check address space for GEP is correct. -// CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} : -// CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32 - -func.func @transfer_read_write_1d_non_zero_addrspace_scalable(%A : memref, %base: index) -> vector<[17]xf32> { - %f7 = arith.constant 7.0: f32 - %f = vector.transfer_read %A[%base], %f7 - {permutation_map = affine_map<(d0) -> (d0)>} : - memref, vector<[17]xf32> - vector.transfer_write %f, %A[%base] - {permutation_map = affine_map<(d0) -> (d0)>} : - vector<[17]xf32>, memref - return %f: vector<[17]xf32> -} -// CHECK-LABEL: func @transfer_read_write_1d_non_zero_addrspace_scalable -// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<[17]xf32> -// -// 1. Check address space for GEP is correct. -// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} : -// CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32 -// -// 2. Check address space of the memref is correct. 
-// CHECK: %[[c0:.*]] = arith.constant 0 : index -// CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c0]] : memref -// -// 3. Check address space for GEP is correct. -// CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} : -// CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32 - -// ----- - -func.func @transfer_read_1d_inbounds(%A : memref, %base: index) -> vector<17xf32> { - %f7 = arith.constant 7.0: f32 - %f = vector.transfer_read %A[%base], %f7 {in_bounds = [true]} : - memref, vector<17xf32> - return %f: vector<17xf32> -} -// CHECK-LABEL: func @transfer_read_1d_inbounds -// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xf32> -// -// 1. Bitcast to vector form. -// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} : -// CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32 -// -// 2. Rewrite as a load. -// CHECK: %[[loaded:.*]] = llvm.load %[[gep]] {alignment = 4 : i64} : !llvm.ptr -> vector<17xf32> - -func.func @transfer_read_1d_inbounds_scalable(%A : memref, %base: index) -> vector<[17]xf32> { - %f7 = arith.constant 7.0: f32 - %f = vector.transfer_read %A[%base], %f7 {in_bounds = [true]} : - memref, vector<[17]xf32> - return %f: vector<[17]xf32> -} -// CHECK-LABEL: func @transfer_read_1d_inbounds_scalable -// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<[17]xf32> -// -// 1. Bitcast to vector form. -// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} : -// CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32 -// -// 2. Rewrite as a load. -// CHECK: %[[loaded:.*]] = llvm.load %[[gep]] {alignment = 4 : i64} : !llvm.ptr -> vector<[17]xf32> - -// ----- - -// CHECK-LABEL: func @transfer_read_write_1d_mask -// CHECK: %[[mask1:.*]] = arith.constant dense<[false, false, true, false, true]> -// CHECK: %[[cmpi:.*]] = arith.cmpi slt -// CHECK: %[[mask2:.*]] = arith.andi %[[cmpi]], %[[mask1]] -// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %[[mask2]] -// CHECK: %[[cmpi_1:.*]] = arith.cmpi slt -// CHECK: %[[mask3:.*]] = arith.andi %[[cmpi_1]], %[[mask1]] -// CHECK: llvm.intr.masked.store %[[r]], %{{.*}}, %[[mask3]] -// CHECK: return %[[r]] -func.func @transfer_read_write_1d_mask(%A : memref, %base : index) -> vector<5xf32> { - %m = arith.constant dense<[0, 0, 1, 0, 1]> : vector<5xi1> - %f7 = arith.constant 7.0: f32 - %f = vector.transfer_read %A[%base], %f7, %m : memref, vector<5xf32> - vector.transfer_write %f, %A[%base], %m : vector<5xf32>, memref - return %f: vector<5xf32> -} - -// CHECK-LABEL: func @transfer_read_write_1d_mask_scalable -// CHECK-SAME: %[[mask:[a-zA-Z0-9]*]]: vector<[5]xi1> -// CHECK: %[[cmpi:.*]] = arith.cmpi slt -// CHECK: %[[mask1:.*]] = arith.andi %[[cmpi]], %[[mask]] -// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %[[mask1]] -// CHECK: %[[cmpi_1:.*]] = arith.cmpi slt -// CHECK: %[[mask2:.*]] = arith.andi %[[cmpi_1]], %[[mask]] -// CHECK: llvm.intr.masked.store %[[r]], %{{.*}}, %[[mask2]] -// CHECK: return %[[r]] -func.func @transfer_read_write_1d_mask_scalable(%A : memref, %base : index, %m : vector<[5]xi1>) -> vector<[5]xf32> { - %f7 = arith.constant 7.0: f32 - %f = vector.transfer_read %A[%base], %f7, %m : memref, vector<[5]xf32> - vector.transfer_write %f, %A[%base], %m : vector<[5]xf32>, memref - return %f: vector<[5]xf32> -} - -// ----- - -// Can't lower xfer_read/xfer_write on tensors, but this shouldn't crash - -// CHECK-LABEL: func @transfer_read_write_tensor -// CHECK: vector.transfer_read -// CHECK: vector.transfer_write -func.func @transfer_read_write_tensor(%A: tensor, %base : index) -> vector<4xf32> { - %f7 = arith.constant 7.0: f32 - %c0 = arith.constant 
0: index - %f = vector.transfer_read %A[%base], %f7 : tensor, vector<4xf32> - %w = vector.transfer_write %f, %A[%c0] : vector<4xf32>, tensor - "test.some_use"(%w) : (tensor) -> () - return %f : vector<4xf32> -} +// CHECK: %[[CST:.*]] = arith.constant 7.000000e+00 : f32 +// CHECK: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1], %[[CST]] : memref, vector<17xf32> +// CHECK: vector.transfer_write %[[TRANSFER_READ]], %arg0[%arg1] : vector<17xf32>, memref +// CHECK: return %[[TRANSFER_READ]] : diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir index 4b38db79bff3e..e590e462c728a 100644 --- a/mlir/test/Dialect/Vector/transform-vector.mlir +++ b/mlir/test/Dialect/Vector/transform-vector.mlir @@ -5,8 +5,9 @@ func.func @matmul_tensors( %arg0: tensor<8x16xf32>, %arg1: tensor<16x32xf32>, %arg2: tensor<8x32xf32>) -> tensor<8x32xf32> { // CHECK-NOT: linalg +// CHECK: vector.transfer_read {{.*}} : memref<8x16xf32>, vector<2xf32> // CHECK: vector.extract {{.*}} : vector<4xf32> from vector<8x4xf32> -// CHECK: vector.store {{.*}} : memref<8x32xf32>, vector<4xf32> +// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, memref<8x32xf32> %0 = linalg.matmul ins(%arg0, %arg1: tensor<8x16xf32>, tensor<16x32xf32>) outs(%arg2: tensor<8x32xf32>) -> tensor<8x32xf32> diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir index f90111b4c8861..7acfdad930b8e 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -1,23 +1,17 @@ // RUN: mlir-opt %s --transform-interpreter -canonicalize --split-input-file | FileCheck %s -// CHECK-LABEL: func @vector_transfer_ops_0d_memref( +// CHECK-LABEL: func @vector_transfer_ops_0d_memref // CHECK-SAME: %[[MEM:.*]]: memref // CHECK-SAME: %[[VEC:.*]]: vector<1x1x1xf32> -func.func @vector_transfer_ops_0d_memref(%mem: memref, %vec: vector<1x1x1xf32>) { +// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: %[[V:.*]] = vector.transfer_read %arg0[], %[[CST]] : memref, vector +// CHECK-NEXT: vector.transfer_write %0, %arg0[] : vector, memref +// CHECK-NEXT: vector.store %arg1, %arg0[] : memref, vector<1x1x1xf32> +func.func @vector_transfer_ops_0d_memref(%M: memref, %v: vector<1x1x1xf32>) { %f0 = arith.constant 0.0 : f32 - -// CHECK-NEXT: %[[S:.*]] = memref.load %[[MEM]][] : memref -// CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector - %0 = vector.transfer_read %mem[], %f0 : memref, vector - -// CHECK-NEXT: %[[SS:.*]] = vector.extractelement %[[V]][] : vector -// CHECK-NEXT: memref.store %[[SS]], %[[MEM]][] : memref - vector.transfer_write %0, %mem[] : vector, memref - -// CHECK-NEXT: %[[VV:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32> -// CHECK-NEXT: memref.store %[[VV]], %[[MEM]][] : memref - vector.store %vec, %mem[] : memref, vector<1x1x1xf32> - + %0 = vector.transfer_read %M[], %f0 : memref, vector + vector.transfer_write %0, %M[] : vector, memref + vector.store %v, %M[] : memref, vector<1x1x1xf32> return } @@ -36,13 +30,11 @@ func.func @vector_transfer_ops_0d_tensor(%src: tensor) -> vector<1xf32> { } // transfer_read/write are lowered to vector.load/store -// CHECK-LABEL: func @transfer_to_load( -// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>, -// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> { -// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : 
memref<8x8xf32>, vector<4xf32> -// CHECK-NEXT: vector.store %[[RES:.*]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<4xf32> -// CHECK-NEXT: return %[[RES]] : vector<4xf32> -// CHECK-NEXT: } +// CHECK-LABEL: func @transfer_to_load( +// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]] {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32> +// CHECK-NEXT: vector.transfer_write %0, %arg0[%arg1, %arg1] {in_bounds = [true]} : vector<4xf32>, memref<8x8xf32> +// CHECK-NEXT: return %[[TRANSFER_READ]] : vector<4xf32> func.func @transfer_to_load(%mem : memref<8x8xf32>, %idx : index) -> vector<4xf32> { %cf0 = arith.constant 0.0 : f32 @@ -70,12 +62,10 @@ func.func @masked_transfer_to_load(%mem : memref<8x8xf32>, %idx : index, %mask : // n-D results are also supported. // CHECK-LABEL: func @transfer_2D( -// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>, -// CHECK-SAME: %[[IDX:.*]]: index) -> vector<2x4xf32> { -// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<2x4xf32> -// CHECK-NEXT: vector.store %[[RES:.*]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<2x4xf32> -// CHECK-NEXT: return %[[RES]] : vector<2x4xf32> -// CHECK-NEXT: } +// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]] {in_bounds = [true, true]} : memref<8x8xf32>, vector<2x4xf32> +// CHECK-NEXT: vector.transfer_write %[[TRANSFER_READ]], %arg0[%arg1, %arg1] {in_bounds = [true, true]} : vector<2x4xf32>, memref<8x8xf32> +// CHECK-NEXT: return %[[TRANSFER_READ]] : vector<2x4xf32> func.func @transfer_2D(%mem : memref<8x8xf32>, %idx : index) -> vector<2x4xf32> { %cf0 = arith.constant 0.0 : f32 @@ -88,10 +78,10 @@ func.func @transfer_2D(%mem : memref<8x8xf32>, %idx : index) -> vector<2x4xf32> // CHECK-LABEL: func @transfer_vector_element( // CHECK-SAME: %[[MEM:.*]]: memref<8x8xvector<2x4xf32>>, // CHECK-SAME: %[[IDX:.*]]: index) -> vector<2x4xf32> { -// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xvector<2x4xf32>>, vector<2x4xf32> -// CHECK-NEXT: vector.store %[[RES:.*]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xvector<2x4xf32>>, vector<2x4xf32> -// CHECK-NEXT: return %[[RES]] : vector<2x4xf32> -// CHECK-NEXT: } +// CHECK-NEXT: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x4xf32> +// CHECK-NEXT: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]] : memref<8x8xvector<2x4xf32>>, vector<2x4xf32> +// CHECK-NEXT: vector.transfer_write %[[TRANSFER_READ]], %arg0[%arg1, %arg1] : vector<2x4xf32>, memref<8x8xvector<2x4xf32>> +// CHECK-NEXT: return %[[TRANSFER_READ]] : vector<2x4xf32> func.func @transfer_vector_element(%mem : memref<8x8xvector<2x4xf32>>, %idx : index) -> vector<2x4xf32> { %cf0 = arith.constant dense<0.0> : vector<2x4xf32> @@ -157,10 +147,10 @@ func.func @transfer_not_inbounds(%mem : memref<8x8xf32>, %idx : index) -> vector // CHECK-LABEL: func @transfer_nondefault_layout( // CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32, #{{.*}}>, // CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> { -// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32, #{{.*}}>, vector<4xf32> -// CHECK-NEXT: vector.store %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32, #{{.*}}>, vector<4xf32> -// CHECK-NEXT: return %[[RES]] : vector<4xf32> -// CHECK-NEXT: } +// CHECK-NEXT: %[[CST:.*]] = arith.constant 
0.000000e+00 : f32 +// CHECK-NEXT: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]] {in_bounds = [true]} : memref<8x8xf32, #map>, vector<4xf32> +// CHECK-NEXT: vector.transfer_write %[[TRANSFER_READ]], %arg0[%arg1, %arg1] {in_bounds = [true]} : vector<4xf32>, memref<8x8xf32, #map> +// CHECK-NEXT: return %[[TRANSFER_READ]] : vector<4xf32> #layout = affine_map<(d0, d1) -> (d0*16 + d1)> func.func @transfer_nondefault_layout(%mem : memref<8x8xf32, #layout>, %idx : index) -> vector<4xf32> { @@ -191,11 +181,11 @@ func.func @transfer_perm_map(%mem : memref<8x8xf32>, %idx : index) -> vector<4xf // CHECK-LABEL: func @transfer_broadcasting( // CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>, // CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> { -// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32> -// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4xf32> -// CHECK-NEXT: return %[[RES]] : vector<4xf32> +// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]] : memref<8x8xf32>, vector +// CHECK-NEXT: %[[BROADCAST:.*]] = vector.broadcast %[[TRANSFER_READ]] : vector to vector<4xf32> +// CHECK-NEXT: return %[[BROADCAST]] : vector<4xf32> // CHECK-NEXT: } - #broadcast_1d = affine_map<(d0, d1) -> (0)> func.func @transfer_broadcasting(%mem : memref<8x8xf32>, %idx : index) -> vector<4xf32> { %cf0 = arith.constant 0.0 : f32 @@ -208,9 +198,9 @@ func.func @transfer_broadcasting(%mem : memref<8x8xf32>, %idx : index) -> vector // CHECK-LABEL: func @transfer_scalar( // CHECK-SAME: %[[MEM:.*]]: memref, // CHECK-SAME: %[[IDX:.*]]: index) -> vector<1xf32> { -// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref -// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<1xf32> -// CHECK-NEXT: return %[[RES]] : vector<1xf32> +// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]] {in_bounds = [true]} : memref, vector<1xf32> +// CHECK-NEXT: return %[[TRANSFER_READ]] : vector<1xf32> // CHECK-NEXT: } func.func @transfer_scalar(%mem : memref, %idx : index) -> vector<1xf32> { %cf0 = arith.constant 0.0 : f32 @@ -222,9 +212,10 @@ func.func @transfer_scalar(%mem : memref, %idx : index) -> vector<1xf32 // CHECK-LABEL: func @transfer_broadcasting_2D( // CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>, // CHECK-SAME: %[[IDX:.*]]: index) -> vector<4x4xf32> { -// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32> -// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4x4xf32> -// CHECK-NEXT: return %[[RES]] : vector<4x4xf32> +// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]] : memref<8x8xf32>, vector +// CHECK-NEXT: %[[BROADCAST:.*]] = vector.broadcast %[[TRANSFER_READ]] : vector to vector<4x4xf32> +// CHECK-NEXT: return %[[BROADCAST]] : vector<4x4xf32> // CHECK-NEXT: } #broadcast_2d = affine_map<(d0, d1) -> (0, 0)> @@ -240,9 +231,9 @@ func.func @transfer_broadcasting_2D(%mem : memref<8x8xf32>, %idx : index) -> vec // CHECK-LABEL: func @transfer_broadcasting_complex( // CHECK-SAME: %[[MEM:.*]]: memref<10x20x30x8x8xf32>, // CHECK-SAME: %[[IDX:.*]]: index) -> vector<3x2x4x5xf32> { -// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]], 
%[[IDX]]] : memref<10x20x30x8x8xf32>, vector<3x1x1x5xf32> -// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : vector<3x1x1x5xf32> to vector<3x2x4x5xf32> -// CHECK-NEXT: return %[[RES]] : vector<3x2x4x5xf32> +// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1, %arg1, %arg1, %arg1, %arg1], %[[CST]] {in_bounds = [true, true, true, true], permutation_map = #map2} : memref<10x20x30x8x8xf32>, vector<3x2x4x5xf32> +// CHECK-NEXT: return %[[TRANSFER_READ]] : vector<3x2x4x5xf32> // CHECK-NEXT: } #broadcast_2d_in_4d = affine_map<(d0, d1, d2, d3, d4) -> (d1, 0, 0, d4)> @@ -322,8 +313,8 @@ func.func @transfer_read_permutations(%mem_0 : memref, %mem_1 : memref< // CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32> %6 = vector.transfer_read %mem_0[%c0, %c0], %cst {in_bounds = [true], permutation_map = #map6} : memref, vector<8xf32> -// CHECK: memref.load %{{.*}}[%[[C0]], %[[C0]]] : memref -// CHECK: vector.broadcast %{{.*}} : f32 to vector<8xf32> +// CHECK: vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CF0]] : memref, vector +// CHECK: vector.broadcast %{{.*}} : vector to vector<8xf32> return %0, %1, %2, %3, %4, %5, %6 : vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
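For the 0-d case that motivated the series, the updated vector_transfer_ops_0d_memref test above expects the same behavior; the sketch below restates it on its own (illustrative only, with made-up names and the element types written out in full): after this series, 0-d transfers and single-element stores are left for later lowering rather than being rewritten by these patterns.

func.func @illustrative_0d(%mem : memref<f32>, %vec : vector<1x1x1xf32>) {
  %pad = arith.constant 0.0 : f32
  // Expected to stay a vector.transfer_read producing a 0-d vector.
  %v = vector.transfer_read %mem[], %pad : memref<f32>, vector<f32>
  // Expected to stay a vector.transfer_write.
  vector.transfer_write %v, %mem[] : vector<f32>, memref<f32>
  // Expected to stay a vector.store of the single-element vector.
  vector.store %vec, %mem[] : memref<f32>, vector<1x1x1xf32>
  return
}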