diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp index 4cbc3dfdae223..1fbc5a03987e8 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp @@ -59,7 +59,8 @@ using namespace mlir; MAP_FN(spirv::StorageClass::UniformConstant, 8) \ MAP_FN(spirv::StorageClass::Input, 9) \ MAP_FN(spirv::StorageClass::Output, 10) \ - MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 11) + MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 11) \ + MAP_FN(spirv::StorageClass::Image, 12) std::optional spirv::mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr) { diff --git a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class-vk.mlir b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class-vk.mlir new file mode 100644 index 0000000000000..3b2c1ae799c57 --- /dev/null +++ b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class-vk.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt --allow-unregistered-dialect --map-memref-spirv-storage-class='client-api=vulkan' %s | FileCheck %s + +// Vulkan Specific Mappings: +// 8 -> UniformConstant +// 9 -> Input +// 10 -> Output +// 11 -> PhysicalStorageBuffer +// 12 -> Image + +/// Check that Vulkan specific memory space indices get converted into the correct +/// SPIR-V storage class. If mappings to OpenCL address spaces are added for these +/// indices then those test case should be moved into the common test file. + +// CHECK-LABEL: func @test_vk_specific_memory_spaces +func.func @test_vk_specific_memory_spaces() { + // CHECK: memref<4xi32, #spirv.storage_class> + %1 = "dialect.memref_producer"() : () -> (memref<4xi32, 8>) + // CHECK: memref<4xi32, #spirv.storage_class> + %2 = "dialect.memref_producer"() : () -> (memref<4xi32, 9>) + // CHECK: memref<4xi32, #spirv.storage_class> + %3 = "dialect.memref_producer"() : () -> (memref<4xi32, 10>) + // CHECK: memref<4xi32, #spirv.storage_class> + %4 = "dialect.memref_producer"() : () -> (memref<4xi32, 11>) + // CHECK: memref<4xi32, #spirv.storage_class> + %5 = "dialect.memref_producer"() : () -> (memref<4xi32, 12>) + return +} diff --git a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir index f0956b62760a2..fdc69b8119994 100644 --- a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir @@ -1,5 +1,6 @@ // RUN: mlir-opt -split-input-file -allow-unregistered-dialect -map-memref-spirv-storage-class='client-api=vulkan' -verify-diagnostics %s -o - | FileCheck %s --check-prefix=VULKAN // RUN: mlir-opt -split-input-file -allow-unregistered-dialect -map-memref-spirv-storage-class='client-api=opencl' -verify-diagnostics %s -o - | FileCheck %s --check-prefix=OPENCL +// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -map-memref-spirv-storage-class -verify-diagnostics %s -o - | FileCheck %s // Vulkan Mappings: // 0 -> StorageBuffer @@ -7,6 +8,14 @@ // 2 -> [null] // 3 -> Workgroup // 4 -> Uniform +// 5 -> Private +// 6 -> Function +// 7 -> PushConstant +// 8 -> UniformConstant +// 9 -> Input +// 10 -> Output +// 11 -> PhysicalStorageBuffer +// 12 -> Image // OpenCL Mappings: // 0 -> CrossWorkgroup @@ -14,6 +23,9 @@ // 2 -> [null] // 3 -> Workgroup // 4 -> UniformConstant +// 5 -> Private +// 6 -> Function +// 7 -> Image // VULKAN-LABEL: func @operand_result // OPENCL-LABEL: func @operand_result @@ -30,6 +42,15 @@ func.func @operand_result() { // VULKAN: memref<*xf16, #spirv.storage_class> // OPENCL: memref<*xf16, #spirv.storage_class> %3 = "dialect.memref_producer"() : () -> (memref<*xf16, 4>) + // VULKAN: memref<*xf16, #spirv.storage_class> + // OPENCL: memref<*xf16, #spirv.storage_class> + %4 = "dialect.memref_producer"() : () -> (memref<*xf16, 5>) + // VULKAN: memref<*xf16, #spirv.storage_class> + // OPENCL: memref<*xf16, #spirv.storage_class> + %5 = "dialect.memref_producer"() : () -> (memref<*xf16, 6>) + // VULKAN: memref<*xf16, #spirv.storage_class> + // OPENCL: memref<*xf16, #spirv.storage_class> + %6 = "dialect.memref_producer"() : () -> (memref<*xf16, 7>) "dialect.memref_consumer"(%0) : (memref) -> () @@ -42,6 +63,15 @@ func.func @operand_result() { // VULKAN: memref<*xf16, #spirv.storage_class> // OPENCL: memref<*xf16, #spirv.storage_class> "dialect.memref_consumer"(%3) : (memref<*xf16, 4>) -> () + // VULKAN: memref<*xf16, #spirv.storage_class> + // OPENCL: memref<*xf16, #spirv.storage_class> + "dialect.memref_consumer"(%4) : (memref<*xf16, 5>) -> () + // VULKAN: memref<*xf16, #spirv.storage_class> + // OPENCL: memref<*xf16, #spirv.storage_class> + "dialect.memref_consumer"(%5) : (memref<*xf16, 6>) -> () + // VULKAN: memref<*xf16, #spirv.storage_class> + // OPENCL: memref<*xf16, #spirv.storage_class> + "dialect.memref_consumer"(%6) : (memref<*xf16, 7>) -> () return } @@ -166,4 +196,4 @@ func.func @operand_result() { "dialect.memref_consumer"(%3) : (memref<*xf16, 4>) -> () return } -} \ No newline at end of file +}