diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index fdf799a20efdd..ff5b762a969d8 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -307,6 +307,17 @@ class CastPattern final : public OpConversionPattern { } }; +/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU. +class ExtractAlignedPointerAsIndexOpPattern final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; } // namespace //===----------------------------------------------------------------------===// @@ -921,6 +932,20 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite( return success(); } +//===----------------------------------------------------------------------===// +// ExtractAlignedPointerAsIndexOp +//===----------------------------------------------------------------------===// + +LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite( + memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto &typeConverter = *getTypeConverter(); + Type indexType = typeConverter.getIndexType(); + rewriter.replaceOpWithNewOp(extractOp, indexType, + adaptor.getSource()); + return success(); +} + //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// @@ -928,10 +953,11 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite( namespace mlir { void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns.add(typeConverter, - patterns.getContext()); + patterns + .add( + typeConverter, patterns.getContext()); } } // namespace mlir diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir index 8906de9db3724..d0ddac8cd801c 100644 --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -1,4 +1,6 @@ -// RUN: mlir-opt --split-input-file --convert-memref-to-spirv="bool-num-bits=8" --cse %s | FileCheck %s +// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8}, cse)" %s | FileCheck %s +// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8 use-64bit-index=true}, cse)" %s \ +// RUN: | FileCheck --check-prefix=CHECK64 %s // Check that with proper compute and storage extensions, we don't need to // perform special tricks. @@ -420,6 +422,43 @@ func.func @cast_to_static_zero_elems(%arg: memref, #spirv.resource_limits<>> +} { +// CHECK-LABEL: func @extract_aligned_pointer_as_index_kernel +func.func @extract_aligned_pointer_as_index_kernel(%m: memref>) -> index { + %0 = memref.extract_aligned_pointer_as_index %m: memref> -> index + // CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr to i32 + // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index + // CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr to i64 + // CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index + + // CHECK: return %[[R:.*]] : index + return %0: index +} +} + +// ----- + +module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { +// CHECK-LABEL: func @extract_aligned_pointer_as_index_shader +func.func @extract_aligned_pointer_as_index_shader(%m: memref>) -> index { + %0 = memref.extract_aligned_pointer_as_index %m: memref> -> index + // CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr)>, CrossWorkgroup> to i32 + // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index + // CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr)>, CrossWorkgroup> to i64 + // CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index + + // CHECK: return %[[R:.*]] : index + return %0: index +} +} + + // ----- // Check nontemporal attribute