diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index b866afbce98b0..e78dca75116c2 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -243,6 +243,16 @@ class LoadOpPattern final : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override; }; +/// Converts memref.load to spirv.Image + spirv.ImageFetch +class ImageLoadOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + /// Converts memref.store to spirv.Store on integers. class IntStoreOpPattern final : public OpConversionPattern { public: @@ -527,6 +537,17 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, if (!memrefType.getElementType().isSignlessInteger()) return failure(); + auto memorySpaceAttr = + dyn_cast_if_present(memrefType.getMemorySpace()); + if (!memorySpaceAttr) + return rewriter.notifyMatchFailure( + loadOp, "missing memory space SPIR-V storage class attribute"); + + if (memorySpaceAttr.getValue() == spirv::StorageClass::Image) + return rewriter.notifyMatchFailure( + loadOp, + "failed to lower memref in image storage class to storage buffer"); + const auto &typeConverter = *getTypeConverter(); Value accessChain = spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(), @@ -643,6 +664,18 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, auto memrefType = cast(loadOp.getMemref().getType()); if (memrefType.getElementType().isSignlessInteger()) return failure(); + + auto memorySpaceAttr = + dyn_cast_if_present(memrefType.getMemorySpace()); + if (!memorySpaceAttr) + return rewriter.notifyMatchFailure( + loadOp, "missing memory space SPIR-V storage class attribute"); + + if 
(memorySpaceAttr.getValue() == spirv::StorageClass::Image) + return rewriter.notifyMatchFailure( + loadOp, + "failed to lower memref in image storage class to storage buffer"); + Value loadPtr = spirv::getElementPtr( *getTypeConverter(), memrefType, adaptor.getMemref(), adaptor.getIndices(), loadOp.getLoc(), rewriter); @@ -661,6 +694,87 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, return success(); } +LogicalResult +ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto memrefType = cast(loadOp.getMemref().getType()); + + auto memorySpaceAttr = + dyn_cast_if_present(memrefType.getMemorySpace()); + if (!memorySpaceAttr) + return rewriter.notifyMatchFailure( + loadOp, "missing memory space SPIR-V storage class attribute"); + + if (memorySpaceAttr.getValue() != spirv::StorageClass::Image) + return rewriter.notifyMatchFailure( + loadOp, "failed to lower memref in non-image storage class to image"); + + Value loadPtr = adaptor.getMemref(); + auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp); + if (failed(memoryRequirements)) + return rewriter.notifyMatchFailure( + loadOp, "failed to determine memory requirements"); + + const auto [memoryAccess, alignment] = *memoryRequirements; + + if (!loadOp.getMemRefType().hasRank()) + return rewriter.notifyMatchFailure( + loadOp, "cannot lower unranked memrefs to SPIR-V images"); + + // We currently only support lowering of scalar memref elements to texels in + // the R[16|32][f|i|ui] formats. Future work will enable lowering of vector + // elements to texels in richer formats. 
+ if (!isa(loadOp.getMemRefType().getElementType())) + return rewriter.notifyMatchFailure( + loadOp, + "cannot lower memrefs whose element type is not a SPIR-V scalar type " + "to SPIR-V images"); + + // We currently only support sampled images since OpImageFetch does not work + // for plain images and the OpImageRead instruction needs to be materialized + // instead or texels need to be accessed via atomics through a texel pointer. + // Future work will generalize support to plain images. + auto convertedPointeeType = cast( + getTypeConverter()->convertType(loadOp.getMemRefType())); + if (!isa(convertedPointeeType.getPointeeType())) + return rewriter.notifyMatchFailure(loadOp, + "cannot lower memrefs which do not " + "convert to SPIR-V sampled images"); + + // Materialize the lowering. + Location loc = loadOp->getLoc(); + auto imageLoadOp = + spirv::LoadOp::create(rewriter, loc, loadPtr, memoryAccess, alignment); + // Extract the image from the sampled image. + auto imageOp = spirv::ImageOp::create(rewriter, loc, imageLoadOp); + + // Build a vector of coordinates or just a scalar index if we have a 1D image. + Value coords; + if (memrefType.getRank() != 1) { + auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()}, + adaptor.getIndices().getType()[0]); + coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType, + adaptor.getIndices()); + } else { + coords = adaptor.getIndices()[0]; + } + + // Fetch the value out of the image. + auto resultVectorType = VectorType::get({4}, loadOp.getType()); + auto fetchOp = spirv::ImageFetchOp::create( + rewriter, loc, resultVectorType, imageOp, coords, + mlir::spirv::ImageOperandsAttr{}, ValueRange{}); + + // Note that because OpImageFetch returns a rank 4 vector we need to extract + // the element corresponding to the load, which, since we only support the + // R[16|32][f|i|ui] formats, will always be the R (red) 0th vector element. 
+ auto compositeExtractOp = + spirv::CompositeExtractOp::create(rewriter, loc, fetchOp, 0); + + rewriter.replaceOp(loadOp, compositeExtractOp); + return success(); +} + LogicalResult IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -952,11 +1066,11 @@ LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite( namespace mlir { void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns - .add( - typeConverter, patterns.getContext()); + patterns.add(typeConverter, + patterns.getContext()); } } // namespace mlir diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index f70b3325f8725..a7a4a1ff6d921 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -560,6 +560,45 @@ static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv, return wrapInStructAndGetPointer(arrayType, storageClass); } +static spirv::Dim convertRank(int64_t rank) { + switch (rank) { + case 1: + return spirv::Dim::Dim1D; + case 2: + return spirv::Dim::Dim2D; + case 3: + return spirv::Dim::Dim3D; + default: + llvm_unreachable("Invalid memref rank!"); + } +} + +static spirv::ImageFormat getImageFormat(Type elementType) { + return llvm::TypeSwitch(elementType) + .Case([](Float16Type) { return spirv::ImageFormat::R16f; }) + .Case([](Float32Type) { return spirv::ImageFormat::R32f; }) + .Case([](IntegerType intType) { + auto const isSigned = intType.isSigned() || intType.isSignless(); +#define BIT_WIDTH_CASE(BIT_WIDTH) \ + case BIT_WIDTH: \ + return isSigned ? 
spirv::ImageFormat::R##BIT_WIDTH##i \ + : spirv::ImageFormat::R##BIT_WIDTH##ui + + switch (intType.getWidth()) { + BIT_WIDTH_CASE(16); + BIT_WIDTH_CASE(32); + default: + llvm_unreachable("Unhandled integer type!"); + } + }) + .Default([](Type) { + llvm_unreachable("Unhandled element type!"); + // We need to return something here to satisfy the type switch. + return spirv::ImageFormat::R32f; + }); +#undef BIT_WIDTH_CASE +} + static Type convertMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, MemRefType type) { @@ -575,6 +614,41 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv, } spirv::StorageClass storageClass = attr.getValue(); + // Images are a special case since they are an opaque type from which elements + // may be accessed via image specific ops or directly through a texture + // pointer. + if (storageClass == spirv::StorageClass::Image) { + const int64_t rank = type.getRank(); + if (rank < 1 || rank > 3) { + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: cannot lower memref of rank " << rank + << " to a SPIR-V Image\n"); + return nullptr; + } + + // Note that we currently only support lowering to single element texels + // e.g. R32f. + auto elementType = type.getElementType(); + if (!isa(elementType)) { + LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot lower memref of " + << elementType << " to a SPIR-V Image\n"); + return nullptr; + } + + // Currently every memref in the image storage class is converted to a + // sampled image so we can hardcode the NeedSampler field. Future work + // will generalize this to support regular non-sampled images. 
+ auto spvImageType = spirv::ImageType::get( + elementType, convertRank(rank), spirv::ImageDepthInfo::DepthUnknown, + spirv::ImageArrayedInfo::NonArrayed, + spirv::ImageSamplingInfo::SingleSampled, + spirv::ImageSamplerUseInfo::NeedSampler, getImageFormat(elementType)); + auto spvSampledImageType = spirv::SampledImageType::get(spvImageType); + auto imagePtrType = spirv::PointerType::get( + spvSampledImageType, spirv::StorageClass::UniformConstant); + return imagePtrType; + } + if (isa(type.getElementType())) { if (type.getElementTypeBitWidth() == 1) return convertBoolMemrefType(targetEnv, options, type, storageClass); diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir index d0ddac8cd801c..2a7be0be7477a 100644 --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -488,3 +488,191 @@ module attributes { return } } + +// ----- + +// Check Image Support. 
+ +module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + // CHECK-LABEL: @load_from_image_1D( + // CHECK-SAME: %[[ARG0:.*]]: memref<1xf32, #spirv.storage_class>, %[[ARG1:.*]]: memref<1xf32, #spirv.storage_class> + func.func @load_from_image_1D(%arg0: memref<1xf32, #spirv.storage_class>, %arg1: memref<1xf32, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1xf32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + %cst = arith.constant 0 : index + // CHECK: %[[COORDS:.*]] = builtin.unrealized_conversion_cast %{{.*}} : index to i32 + // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> + // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> + // CHECK-NOT: spirv.CompositeConstruct + // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, i32 -> vector<4xf32> + // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32> + %0 = memref.load %arg0[%cst] : memref<1xf32, #spirv.storage_class> + // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32 + memref.store %0, %arg1[%cst] : memref<1xf32, #spirv.storage_class> + return + } + + // CHECK-LABEL: @load_from_image_2D( + // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xf32, #spirv.storage_class>, %[[ARG1:.*]]: memref<1x1xf32, #spirv.storage_class> + func.func @load_from_image_2D(%arg0: memref<1x1xf32, #spirv.storage_class>, %arg1: memref<1x1xf32, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xf32, #spirv.storage_class> to !spirv.ptr>, 
UniformConstant> + %cst = arith.constant 0 : index + // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> + // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32> + // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xf32> + // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32> + %0 = memref.load %arg0[%cst, %cst] : memref<1x1xf32, #spirv.storage_class> + // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32 + memref.store %0, %arg1[%cst, %cst] : memref<1x1xf32, #spirv.storage_class> + return + } + + // CHECK-LABEL: @load_from_image_3D( + // CHECK-SAME: %[[ARG0:.*]]: memref<1x1x1xf32, #spirv.storage_class>, %[[ARG1:.*]]: memref<1x1x1xf32, #spirv.storage_class> + func.func @load_from_image_3D(%arg0: memref<1x1x1xf32, #spirv.storage_class>, %arg1: memref<1x1x1xf32, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1x1xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1x1xf32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + %cst = arith.constant 0 : index + // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> + // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}}, %{{.*}} : (i32, i32, i32) -> vector<3xi32> + // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<3xi32> -> vector<4xf32> + // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32> + %0 = memref.load %arg0[%cst, %cst, %cst] : memref<1x1x1xf32, #spirv.storage_class> + // CHECK: 
spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32 + memref.store %0, %arg1[%cst, %cst, %cst] : memref<1x1x1xf32, #spirv.storage_class> + return + } + + // CHECK-LABEL: @load_from_image_2D_f16( + // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xf16, #spirv.storage_class>, %[[ARG1:.*]]: memref<1x1xf16, #spirv.storage_class> + func.func @load_from_image_2D_f16(%arg0: memref<1x1xf16, #spirv.storage_class>, %arg1: memref<1x1xf16, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xf16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xf16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + %cst = arith.constant 0 : index + // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> + // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32> + // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xf16> + // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf16> + %0 = memref.load %arg0[%cst, %cst] : memref<1x1xf16, #spirv.storage_class> + // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f16 + memref.store %0, %arg1[%cst, %cst] : memref<1x1xf16, #spirv.storage_class> + return + } + + // CHECK-LABEL: @load_from_image_2D_i32( + // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xi32, #spirv.storage_class>, %[[ARG1:.*]]: memref<1x1xi32, #spirv.storage_class> + func.func @load_from_image_2D_i32(%arg0: memref<1x1xi32, #spirv.storage_class>, %arg1: memref<1x1xi32, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xi32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = 
builtin.unrealized_conversion_cast %arg0 : memref<1x1xi32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + %cst = arith.constant 0 : index + // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> + // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32> + // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xi32> + // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xi32> + %0 = memref.load %arg0[%cst, %cst] : memref<1x1xi32, #spirv.storage_class> + // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : i32 + memref.store %0, %arg1[%cst, %cst] : memref<1x1xi32, #spirv.storage_class> + return + } + + // CHECK-LABEL: @load_from_image_2D_ui32( + // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xui32, #spirv.storage_class>, %[[ARG1:.*]]: memref<1x1xui32, #spirv.storage_class> + func.func @load_from_image_2D_ui32(%arg0: memref<1x1xui32, #spirv.storage_class>, %arg1: memref<1x1xui32, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xui32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xui32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + %cst = arith.constant 0 : index + // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> + // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32> + // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xui32> + // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xui32> + %0 = memref.load 
%arg0[%cst, %cst] : memref<1x1xui32, #spirv.storage_class> + // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : ui32 + memref.store %0, %arg1[%cst, %cst] : memref<1x1xui32, #spirv.storage_class> + return + } + + // CHECK-LABEL: @load_from_image_2D_i16( + // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xi16, #spirv.storage_class>, %[[ARG1:.*]]: memref<1x1xi16, #spirv.storage_class> + func.func @load_from_image_2D_i16(%arg0: memref<1x1xi16, #spirv.storage_class>, %arg1: memref<1x1xi16, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xi16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xi16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + %cst = arith.constant 0 : index + // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> + // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32> + // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xi16> + // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xi16> + %0 = memref.load %arg0[%cst, %cst] : memref<1x1xi16, #spirv.storage_class> + // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : i16 + memref.store %0, %arg1[%cst, %cst] : memref<1x1xi16, #spirv.storage_class> + return + } + + // CHECK-LABEL: @load_from_image_2D_ui16( + // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xui16, #spirv.storage_class>, %[[ARG1:.*]]: memref<1x1xui16, #spirv.storage_class> + func.func @load_from_image_2D_ui16(%arg0: memref<1x1xui16, #spirv.storage_class>, %arg1: memref<1x1xui16, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xui16, #spirv.storage_class> to !spirv.ptr [0])>, 
StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xui16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + %cst = arith.constant 0 : index + // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> + // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32> + // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xui16> + // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xui16> + %0 = memref.load %arg0[%cst, %cst] : memref<1x1xui16, #spirv.storage_class> + // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : ui16 + memref.store %0, %arg1[%cst, %cst] : memref<1x1xui16, #spirv.storage_class> + return + } + + // CHECK-LABEL: @load_from_image_2D_rank0( + func.func @load_from_image_2D_rank0(%arg0: memref>, %arg1: memref>) { + %cst = arith.constant 0 : index + // CHECK-NOT: spirv.Image + // CHECK-NOT: spirv.ImageFetch + %0 = memref.load %arg0[] : memref> + memref.store %0, %arg1[] : memref> + return + } + + // CHECK-LABEL: @load_from_image_2D_rank4( + func.func @load_from_image_2D_rank4(%arg0: memref<1x1x1x1xf32, #spirv.storage_class>, %arg1: memref<1x1x1x1xf32, #spirv.storage_class>) { + %cst = arith.constant 0 : index + // CHECK-NOT: spirv.Image + // CHECK-NOT: spirv.ImageFetch + %0 = memref.load %arg0[%cst, %cst, %cst, %cst] : memref<1x1x1x1xf32, #spirv.storage_class> + memref.store %0, %arg1[%cst, %cst, %cst, %cst] : memref<1x1x1x1xf32, #spirv.storage_class> + return + } + + // CHECK-LABEL: @load_from_image_2D_vector( + func.func @load_from_image_2D_vector(%arg0: memref<1xvector<1xf32>, #spirv.storage_class>, %arg1: memref<1xvector<1xf32>, #spirv.storage_class>) { + %cst = arith.constant 0 : index + // CHECK-NOT: spirv.Image + // CHECK-NOT: spirv.ImageFetch + %0 
= memref.load %arg0[%cst] : memref<1xvector<1xf32>, #spirv.storage_class> + memref.store %0, %arg1[%cst] : memref<1xvector<1xf32>, #spirv.storage_class> + return + } +}