[mlir][spirv] Add GpuToLLVM cconv suited to Vulkan, migrate last tests #123384
```diff
@@ -428,18 +428,18 @@ class LegalizeLaunchFuncOpPattern
 public:
   LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
                               bool kernelBarePtrCallConv,
-                              bool typeCheckKernelArgs)
+                              bool kernelIntersperseSizeCallConv)
       : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
         kernelBarePtrCallConv(kernelBarePtrCallConv),
-        typeCheckKernelArgs(typeCheckKernelArgs) {}
+        kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}

 private:
   LogicalResult
   matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;

   bool kernelBarePtrCallConv;
-  bool typeCheckKernelArgs;
+  bool kernelIntersperseSizeCallConv;
 };

 /// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
```
```diff
@@ -566,8 +566,9 @@ void GpuToLLVMConversionPass::runOnOperation() {
   populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
   populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                     target);
-  populateGpuToLLVMConversionPatterns(
-      converter, patterns, kernelBarePtrCallConv, typeCheckKernelArgs);
+  populateGpuToLLVMConversionPatterns(converter, patterns,
+                                      kernelBarePtrCallConv,
+                                      kernelIntersperseSizeCallConv);

   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
```
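For context, a sketch of what the corresponding declaration change presumably looks like; the parameter types and default values below are inferred from this call site, not copied from the PR:

```cpp
// Presumed declaration change (sketch only; the const-ness of the converter
// and the default arguments are assumptions inferred from the call above).
void populateGpuToLLVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    bool kernelBarePtrCallConv = false,
    bool kernelIntersperseSizeCallConv = false);
```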
```diff
@@ -970,33 +971,55 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
   else if (launchOp.getAsyncToken())
     stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();

-  if (typeCheckKernelArgs) {
-    // The current non-bare-pointer ABI is a bad fit for `mgpuLaunchKernel`,
-    // which takes an untyped list of arguments. The type check here prevents
-    // accidentally violating the assumption made in
-    // vulkan-runtime-wrappers.cpp and creating an unchecked runtime ABI
-    // mismatch.
-    // TODO(https://github.com/llvm/llvm-project/issues/73457): Change the ABI
-    // here to remove the need for this type check.
-    for (Value arg : launchOp.getKernelOperands()) {
-      if (auto memrefTy = dyn_cast<MemRefType>(arg.getType())) {
-        if (memrefTy.getRank() != 1 ||
-            memrefTy.getElementTypeBitWidth() != 32) {
-          return rewriter.notifyMatchFailure(
-              launchOp, "Operand to launch op is not a rank-1 memref with "
-                        "32-bit element type.");
-        }
-      } else {
+  // Lower the kernel operands to match kernel parameters.
+  // Note: If `useBarePtrCallConv` is set in the type converter's options,
+  // the value of `kernelBarePtrCallConv` will be ignored.
+  OperandRange origArguments = launchOp.getKernelOperands();
+  SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
+      loc, origArguments, adaptor.getKernelOperands(), rewriter,
+      /*useBarePtrCallConv=*/kernelBarePtrCallConv);
+  SmallVector<Value, 8> llvmArgumentsWithSizes;
+
+  // Intersperse size information if requested.
+  if (kernelIntersperseSizeCallConv) {
+    if (origArguments.size() != llvmArguments.size()) {
+      // This shouldn't happen if the bare-pointer calling convention is used.
+      return rewriter.notifyMatchFailure(
+          launchOp,
+          "Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
+    }
+
+    llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
+    for (auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
+      auto memrefTy = dyn_cast<MemRefType>(origArg.getType());
+      if (!memrefTy) {
+        return rewriter.notifyMatchFailure(
+            launchOp, "Operand to launch op is not a memref.");
+      }
+
+      if (!memrefTy.hasStaticShape() ||
+          !memrefTy.getElementType().isIntOrFloat()) {
+        return rewriter.notifyMatchFailure(
+            launchOp, "Operand to launch op is not a memref with a static "
+                      "shape and an integer or float element type.");
+      }
+
+      unsigned bitwidth = memrefTy.getElementTypeBitWidth();
+      if (bitwidth % 8 != 0) {
+        return rewriter.notifyMatchFailure(
+            launchOp, "Operand to launch op is not a memref with a "
+                      "byte-aligned element type.");
+      }
```

Review thread on the `SmallVector<Value, 8>` declarations above:

**Contributor:** nit: don't need to explicitly specify …

**Contributor (Author):** I prefer the explicit size in this case because …
Review thread on the byte-alignment check, quoting existing code that pads non-1-bit dense elements to whole bytes:

```cpp
// Non 1-bit dense elements are padded to 8-bits.
size_t storageSize = llvm::divideCeil(elementWidth, CHAR_BIT);
assert(((data.size() / storageSize) == numElements) &&
       "data does not hold expected number of elements");
```

I think it's fine to reject such memrefs. If someone wants to support them in the future, we can work with them on adding that.
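To make the accepted cases concrete: the size a memref operand contributes is its static element count times its element width in whole bytes, and sub-byte widths are rejected by the pattern rather than padded. A minimal standalone sketch of that computation (plain C++; the helper name is hypothetical, not part of the patch):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the size computation the pattern performs
// for each statically shaped memref operand. Element bitwidths that are not
// a multiple of 8 are rejected by the pattern, so no padding happens here.
static uint64_t staticMemrefSizeInBytes(const std::vector<int64_t> &shape,
                                        unsigned elementBitwidth) {
  assert(elementBitwidth % 8 == 0 && "sub-byte element types are rejected");
  uint64_t numElements = 1;
  for (int64_t dim : shape)
    numElements *= static_cast<uint64_t>(dim);
  return numElements * (elementBitwidth / 8);
}
// e.g. staticMemrefSizeInBytes({8, 8}, 32) == 256
```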
New test file (25 added lines):

```mlir
// RUN: mlir-opt %s --gpu-to-llvm="use-bare-pointers-for-kernels=1 intersperse-sizes-for-kernels=1" -split-input-file | FileCheck %s

module attributes {gpu.container_module, spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
  llvm.func @malloc(i64) -> !llvm.ptr
  gpu.binary @kernels [#gpu.object<#spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>, "">]
  func.func @main() attributes {llvm.emit_c_interface} {
    // CHECK: [[RANK1UMD:%.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %rank1UndefMemrefDescriptor = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    // CHECK: [[RANK2UMD:%.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
    %rank2UndefMemrefDescriptor = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
    %c1 = arith.constant 1 : index
    // CHECK: [[PTR1:%.*]] = llvm.extractvalue [[RANK1UMD]][1]
    // CHECK: [[PTR2:%.*]] = llvm.extractvalue [[RANK2UMD]][1]
    // CHECK: [[PTR3:%.*]] = llvm.extractvalue [[RANK2UMD]][1]
    // CHECK: [[SIZE1:%.*]] = llvm.mlir.constant(32 : index) : i64
    // CHECK: [[SIZE2:%.*]] = llvm.mlir.constant(256 : index) : i64
    // CHECK: [[SIZE3:%.*]] = llvm.mlir.constant(48 : index) : i64
    %6 = builtin.unrealized_conversion_cast %rank1UndefMemrefDescriptor : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<8xf32>
    %10 = builtin.unrealized_conversion_cast %rank2UndefMemrefDescriptor : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<8x8xi32>
    %14 = builtin.unrealized_conversion_cast %rank2UndefMemrefDescriptor : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<4x12xi8>
    // CHECK: gpu.launch_func @kernels::@kernel_add blocks in ({{.*}}) threads in ({{.*}}) : i64 args([[PTR1]] : !llvm.ptr, [[SIZE1]] : i64, [[PTR2]] : !llvm.ptr, [[SIZE2]] : i64, [[PTR3]] : !llvm.ptr, [[SIZE3]] : i64)
    gpu.launch_func @kernels::@kernel_add blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%6 : memref<8xf32>, %10 : memref<8x8xi32>, %14 : memref<4x12xi8>)
    return
  }
}
```
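The checked size constants follow directly from the element counts and byte widths: `memref<8xf32>` is 8 × 4 = 32 bytes, `memref<8x8xi32>` is 64 × 4 = 256 bytes, and `memref<4x12xi8>` is 48 × 1 = 48 bytes; each size is passed as an `i64` immediately after the corresponding bare pointer.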
Review thread on the pass options:

**Contributor:** Maybe better to make it an enum instead of two bools?

**Contributor (Author):** That might not be a bad idea, but I'm really unsure whether it's possible for an option to be an enum, or what I would have to do to get that. (Specifically, the `Option` defined in `Passes.td`. I know that it somehow leverages the `cl` infrastructure, so enums "should" be possible, but I haven't managed to figure it out from staring at the headers and tablegen definitions; there are too many layers.)

**Contributor:** Here is an example: https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Conversion/Passes.td#L1145, but we can do it in a separate PR.
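For readers who hit the same wall: the underlying `llvm::cl` machinery does support enum-valued options, which is what the tablegen `Option` ultimately lowers to. A minimal standalone sketch of that mechanism (the enum, flag name, and value names here are hypothetical, not taken from the linked `Passes.td` example):

```cpp
#include "llvm/Support/CommandLine.h"

// Hypothetical enum that could replace the two bools discussed above.
enum class KernelCallConv { Packed, BarePtr, BarePtrWithSizes };

static llvm::cl::opt<KernelCallConv> kernelCallConv(
    "kernel-call-conv",
    llvm::cl::desc("Calling convention used for kernel arguments"),
    llvm::cl::values(
        clEnumValN(KernelCallConv::Packed, "packed",
                   "Pass full memref descriptors"),
        clEnumValN(KernelCallConv::BarePtr, "bare-ptr",
                   "Pass only the bare allocated pointers"),
        clEnumValN(KernelCallConv::BarePtrWithSizes, "bare-ptr-with-sizes",
                   "Pass bare pointers interleaved with buffer sizes")),
    llvm::cl::init(KernelCallConv::Packed));
```

In MLIR's `Passes.td`, the equivalent `cl::values(...)` list is supplied through the `Option`'s additional-flags string, which is the pattern the linked example follows.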