From 470e7447ef3d3ce3cb1d2a30d26b0642206e0fdc Mon Sep 17 00:00:00 2001 From: Copybara Bot Date: Mon, 30 Jun 2025 13:18:04 -0700 Subject: [PATCH] Integrate internal changes Author: Sagar Shelke [executor] Add complex type support to `ScalarValue` Previously, ScalarValue which represents scalar runtime value did not support complex type. This MR adds support for complex type by making storage union of real and complex data instaed of just real. MLIR tests are added via constant subgraph execution. Author: Christopher Bate [compiler] Enable more `stablehlo.dot_general` to TensorRT using `tensorrt.einsum` Previously, we relied on canonicalization of `stablehlo.dot_general` to put all such contraction operations into a form that could be converted to `tensorrt.matrix_multiply`. Based on recent experiments, this can actually produce very inefficient TensorRT programs due to the number of reshapes and transpositions that must be inserted to coerce general `stablehlo.dot_general` into batched matrix multiplications. This change enables conversion of `stablehlo.dot_general` to `tensorrt.einsum`, and the pass and patterns now contain configurable parameters to control whether `tensorrt.einsum` is used as the primary method or only for fallback when conversion to `tensorrt.matrix_multiply` is not possible. A follow on change will revamp the Stablehlo preprocessing that we perform on 'stablehlo.dot_general' to avoid creating inefficient patterns and enable wider use of this pattern. Author: Christopher Bate [compiler] Fix stablehlo-to-scf scalarization heuristics Fixes an issue where float tensors in the 'before' region of converted while loops where scalarized. The transform should only scalarize operands which are likely to be for-style induction variables. Author: Christopher Bate [compiler] NFC: Drop dead code from StablehloToExecutableTask Author: Chris Bate [compiler] Add `plan-promote-host-tensors-to-host-pinned` pass Adds a simple pass to promote "host" tensors to "host-pinned" tensors in common cases where we know a tensor will be transferred between host and device spaces. This pass runs after `plan-optimize-memory-spaces` since the former is sensitive to mismatching host spaces for patterns related to moving tranfers out of loops. Author: Sagar Shelke [executor] Handle elided dense resource elements attr during translation Translation to executable (which is flatbuffer) uses MLIR attr serialization to serialize `ElementsAttr`. However, this doesn't work when attr is elided dense resource and results in segfault. This MR handles this situation by replacing elided resource with `DenseElementsAttr` of all `one`s (`true` in case of boolean). IR with elided resource is usally seen only during testing of passes and not useful for e2e functional execution. Testing of `ExecuteConstantFoldableSubgraphs` pass is such case. Thus, MLIR test cases for this pass are added. Author: Chris Bate [tensorrt] Fix TRT layer name generation function The TRT layer naming had some faulty logic that could cause the layer name to grow very large in the process to create a unique name. Fix the issue and use a static counter to reduce time spent in the loop. Author: Christopher Bate Further fixes to LIT configs Previously, we were setting `lit_config.parallelism_group` instead of `config.parallelism_group`. Apparently, the previous method does nothing, only `config.parallelism_group` has any effect. Author: Chris Bate Update LIT test parallelism configs In more recent versions of TensorRT (10.11+ at least), the builder is taking a much larger amount of host memory. This can cause OOM when running the LIT test suites under their existing configurations. This change updates all LIT configs: - Make sure to use `%pick-one-gpu` in the LIT command line to ensure we stall if there are not enough GPU or host resources available. Add a hard limit that there must be at least 5GB of host memory available. - Update configurations to reduce the amount of estimated parallelism by increasing host memory requirements and reducing the amount of host memory to 50% for the purposes of the parallelism calculation. - Force all tests to use a common parallelism group unless otherwise specified in the test config. Author: Christopher Bate [compiler] Fix failure case in stablehlo-to-scf Fixes a failure case due to one of the recently introduced rewrites in `stablehlo-to-scf`. Author: Christopher Bate [compiler] Further improvements to plan bufferization pipeline - Split `plan-assign-memory-spaces` into three passes: - `plan-assign-memory-spaces` - `plan-optimize-memory-spaces` - `plan-materialize-explicit-transfers` - The last one is the only new code: `plan-materialize-explicit-transfers` converts `tensor.cast` ops that change the memory space encoding into explicit `bufferization.alloc_tensor` + `bufferization.materialize_in_destination` operations. - Improve handling of `bufferization.alloc_tensor` and optimization of `scf.for` iteration args in `plan-assign-memory-spaces`. - Improve handling of `tensor.reshape` in `plan-assign-memory-spaces`. - Fix handling of `tensor.reshape` when rewriting functions to be in DPS style in `plan-alloc-tensors`. This change also updates the LLVM dependencies in order to cherry-pick fix to the `tensor.reshape` bufferization interface that I merged upstream (https://github.com/llvm/llvm-project/pull/128590). In addition, fix APInt assertions in `plan-execute-constant-foldable-subgraphs`. Author: Chris Bate [compiler] Enable While-to-For conversion in Stablehlo-to-Scf pass This change adds some patterns to the Stablehlo-to-Scf pass to enable While-to-For conversion after the Stablehlo-to-Scf conversion. This transformation is combined with the Stablehlo-to-Scf conversion because the While-to-For patterns require first scalarizing block arguments of the While operation. The heuristics for which block arguments should be scalarized are implemented as control callbacks for the scalarization patterns. These callbacks need Stablehlo-specific logic, so it makes sense to test the combined conversion as a single pass. From the pass users' perspective, it gives the appearence of going directly from `stablehlo.while` to `scf.for`. The test cases are updated to cover the new patterns. Author: Chris Bate [compiler] Fix assign-memory-spaces pass to respect function-level constraints Fixes an issue where the `plan.memory_space` attribute on a function was not being respected when converting function signatures. MR: initialdl/mlir-tensorrt!2146 Author: Chris Bate [compiler] Update scf.while detensorization to increase flexibility In order to incorporate the upstream "uplift scf.while to scf.for" transformation as part of the `stablehlo-to-scf` conversion, we need to detensorize the operands of `scf.while` that are likely to correspond to the loop induction variable. This change refactors our existing 'scf.while' detensorization transformation to give more flexibility and control. The TensorKindAnalysis is no longer required in order to use the pattern(s). Detensorization of `after` and `before` arguments of `scf.while` are now controlled separately. Author: Chris Bate [compiler] Improve handling of memory space constraints in the Plan dialect This Plan dialect. Constraints are now specified using a common attribute 'plan.memory_space' that can be applied to functions or individual arguments/results. In addition, patterns in `plan-alloc-tensors` and `plan-assign-memory-spaces` are updated to avoid introducing unnecessary transfers between memory spaces. Author: Chris Bate [compiler] Add plan-buffer-results-to-out-params pass This change adds a new Plan dialect pass `plan-buffer-results-to-out-params`. This pass is based on the upstream Bufferization pass `buffer-results-to-out-params`, but it can handle a wider number of cases (such as promoting dynamic allocations) and uses alias analysis utilities to guard against failure cases that the upstream pass currently cannot handle. These improvements should eventually be upstreamed back to the Bufferization dialect. Author: Chris Bate [compiler] Update func conversion in host-to-emitc In the EmitC conversion/translation process, you can use `func.func` or `emitc.func` to define functions. Previously, we converted all `func.func` to `emitc.func`. However, `emitc.func` does not have a path for supporting multiple return values. Therefore, prefer use of type conversions on `func.func` instead of converting the entire op to `emitc.func`. Add tests to verify that we can support multiple return values. Author: Chris Bate [compiler] Fix two host-to-emitc bugs This change fixes two bugs exposed by new 'host-to-emitc' conversion testing: - The `!emitc.size_t` type does not have DataLayout information specified upstream. Therefore, to ensure that the type can be queried using DataLayout, we add a DataLayoutTypeInterface external model to the type. All queries are simply mapped to queries to the `index` type. - The upstream `func.call` conversion has a bug where it does not correctly convert the result types of the call operation, which can lead to a type mismatch for any type that does not have an identity conversion. Additional tests are added to `host-to-emitc`. Eventually the fixes for both these issues should be moved upstream. Author: Chris Bate [common] Add Linalg-to-loops (on tensors) implementation and conversion pass Adds a ToLoopsOpInterface implementation and for Linalg operations. In addition, a conversion pass is added that converts ToLoopOpInterface operations to loops. Author: Chris Bate NFC: Move ToLoopsOpInterface to 'mlir-tensorrt-common' Moves the ToLoopsOpInterface to the 'mlir-tensorrt-common' project. This is in preperation for enabling the ToLoopsOpInterface on LinalgOp (lowering while still using Tensor types) to replace the `convert-stablehlo-arith-to-scalar` pipeline. MR: initialdl/mlir-tensorrt!2137 Author: Christopher Bate NFC: Fix formatting across several files Author: Chris Bate [executor] Introduce RuntimeSession "features" to control loading of runtime modules Previously, the RuntimeSession would always load all available runtime modules. This causes some inefficiences. For example, in certain integration tests for the Executor runtime, we don't use CUDA at all. However, because CUDA is still initialized by default, we would still require a GPU to be present just to run the integration test. Furthermore, some experimental modules (e.g. Lua cublas module) are not ready for "production" use and are only really invoked inside special integration tests. This change inroduces a notion of "features" to the RuntimeSession and RuntimeSessionOptions. A feature is just a string that identifies a particular runtime component. The particular semantic of a "feature" depends on the the actual runtime implementation. For example, for the LuaRuntimeSession, the feature names correspond to the available Lua "modules" (a module is just a group of C++ Lua extension functions), e.g. "core", "cuda", "tensorrt", etc. The RuntimeSessionOptions gains methods for enabling/disabling features. Certain features cause others to be added to the set automatically, e.g. "tensorrt" and "nccl" both require "cuda" to be added. The API is piped through all the way to the Python bindings to allow control of loaded modules at all levels. To preserve existing behavior, RuntimeSessions created from Python will load all available modules by default, but the `executor-runner|mlir-tensorrt-runner` tools now require features to be explicitly specified. Author: Christopher Bate NFC: Fix include guard for 'mlir-executor/Support/Status.h' Author: Sagar Shelke [compiler/lib] Add stablehlo composite to call pass to pre-processing pipeline This MR adds `StablehloLegalizeCompositeToCallPass` to the pre-processing pipeline. MLIR test is added. Author: Chris Bate [compiler] Add "default memory space" to ClusterKindAttrInterface Adds a new method to the ClusterKindAttrInterface so that backends can control the default tensor encoding (#plan.memory_space<..>) assigned by the `plan.assign-memory-spaces` pass at a function-scope level. In addition, we also allow an attribute to override the default space at function argument/results. This override mechnanism was previously lacking and will help resolve a long-standing issue where users cannot control the memory space of arguments/results reliably. Author: Christopher Bate [compiler] Fix some issues related to pipeline extension mechanism The StablehloToExecutableTensorRTExtension had both 'disable' and an inherited 'disabled' member variable. Delete the inherited one such it should not have been introduced and was not bound to any option. Further, remove unused 'extensions' vector from CompilationTaskOptionsBase. Author: Christopher Bate [executor] Fix ptrtoint and inttoptr op translation to Lua Previously, we could generate conflicting function types (due to pointer address space) when converting `executor.ptrtoint` and `executor.inttoptr` ops to opaque calls. Instead, defer the conversion to function call until the actual Lua translation point. At that point we can generate a function name without having to consider the pointer address space. Author: Chris Bate Introduce 'MLIRTensorRTCommmon' sub-project Certain targets need to be used across multiple sub-projects. For example, the 'TensorRTDynamicLoader' target is used in all sub-projects. In addition, the sub-projects need to be independently buildable. This change introduces another sub-project under the 'common' directory where shared code can be placed. This allows us to use `find_package` to declare the dependency, and downstream consumers to meet the requirement using any number of techniques to fullfill the 'find_package' call. Author: Chris Bate [compiler] Harden `stablehlo.constant` to `arith.constant` conversion There is a utility pass that runs in the stablehlo-to-executable pipeline that converts `stablehlo.constant` to `arith.constant`. This pass can temporarily create invalid IR due to `arith.constant` not supporting signful integer types. If the "verify-each" option is off, then the issue will not be caught since it happens to be self-correcting. However, the issue can still cause verification failures while debugging. This change fixes the issue by adding a `builtin.unrealized_conversion_cast` operation to bridge the type change between signless-and-signfull integer types. Author: Chris Bate Integrate LLVM at f137c3d592e96330e450a8fd63ef7e8877fc1908 Author: Christopher Bate Fix build with BUILD_SHARED_LIBS=ON The new InferTensorValueRangeInterface was used without correctly specifying the library dependency the PlanIR and StablehloExtIR libraries. Author: Sagar Shelke [compiler] Maintain output order in TensorRT engine. For TensorRT engine conversion, first step in lowering a cluster containing TensorRT ops is created inline group op. Operands to the yield op (i.e. terminator) of inline group op are values from the cluster that are used outside the cluster. These values are collected by getting uses of each op (with `op->getUses()`) and checking if they are outside the cluster. However, this use order is not deterministic and sometimes it is desired to get yield results in a certian order. This MR makes the following changes, 1. Add a function callback option named `ReorderRegionOpYieldValues` to `mlir::createRegionOpFromCluster` method. This callback function has signature `std::function &yieldValues, SmallVectorImpl &yieldTypes)>` which takes cluster values used outside the cluster (in SetVector) and their types. By default this is set to nullptr. 2. TensorRTToExecutable task is used in cases where a single `func.func` represents a single TensorRT engine. In this case, `ReorderRegionOpYieldValues` callback is implemented to make sure inline group op yield value order is same as func.func return values order. Valid MLIR test is added. GitOrigin-RevId: 630a69d8e14506db43cfefe4be2c790f9352da4f DependencyProvider.cmake # modified: build_tools/cmake/Dependencies.cmake # modified: build_tools/patches/mlir/0005-mlir-memref-Fix-memref.global-overly-constrained-ver.patch build_tools/patches/mlir/0006-mlir-emitc-Fix-two-EmitC-bugs.patch # deleted: build_tools/patches/mlir/0008-MLIR-Remove-unnecessary-include-from-MathToEmitC.h-t.patch build_tools/patches/mlir/0009-mlir-Support-FileLineColRange-in-LLVM-debug-translat.patch build_tools/patches/mlir/0010-MLIR-Fix-LLVMIRTransforms-build-failure-125485.patch build_tools/patches/mlir/0011-MLIR-Fix-bufferization-interface-for-tensor-reshape.patch build_tools/patches/stablehlo/0001-Fix-a-couple-missing-checks-for-static-shapes-in-sta.patch build_tools/patches/stablehlo/0002-cmake-Update-usage-of-HandleLLVMOptions-and-LLVM_DEF.patch build_tools/patches/stablehlo/0003-Don-t-insert-unnecessary-arith.index_cast-ops.patch build_tools/patches/stablehlo/0004-Fix-ZeroExtent-condition-in-simplification-pattern.patch build_tools/patches/stablehlo/0005-Fix-crash-on-ComplexType-in-PointwiseToLinalgMapConv.patch build_tools/patches/stablehlo/0006-Remove-explicit-use-of-LLVMSupport.patch build_tools/patches/stablehlo/0007-Fix-circular-dependence-between-StablehloPasses-and-.patch build_tools/patches/torch_mlir/0001-cmake-Allow-finding-Stablehlo-via-find_package.patch build_tools/patches/torch_mlir/0002-Make-compatible-with-more-recent-Stablehlo-version.patch build_tools/patches/torch_mlir/0003-Fix-some-configuration-paths-in-LIT-cfg.patch common/include/mlir-tensorrt-common/CMakeLists.txt # renamed: executor/include/mlir-executor/Runtime/Backend/Lua/LuaRegistration.h -> common/include/mlir-tensorrt-common/Conversion/Passes.h # new file: common/include/mlir-tensorrt-common/Conversion/Passes.td # new file: common/include/mlir-tensorrt-common/Dialect/EmitCExt/IR/DataLayoutImpl.h common/include/mlir-tensorrt-common/Dialect/LinalgExt/Transforms/ToLoopsOpInterfaceImpl.h common/include/mlir-tensorrt-common/Interfaces/ToLoopsOpInterface.h # new file: common/include/mlir-tensorrt-common/Interfaces/ToLoopsOpInterface.td # new file: common/lib/CMakeLists.txt # new file: common/lib/Conversion/CMakeLists.txt # new file: common/lib/Conversion/ToLoops/CMakeLists.txt # new file: common/lib/Conversion/ToLoops/ConvertToLoops.cpp # new file: common/lib/Dialect/CMakeLists.txt # new file: common/lib/Dialect/EmitCExt/CMakeLists.txt # new file: common/lib/Dialect/EmitCExt/DataLayoutImpl.cpp # new file: common/lib/Dialect/LinalgExt/CMakeLists.txt # new file: common/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt # new file: common/lib/Dialect/LinalgExt/Transforms/ToLoopsOpInterfaceImpl.cpp # new file: common/lib/Interfaces/CMakeLists.txt # new file: common/lib/Interfaces/ToLoopsOpInterface.cpp # new file: common/lib/Utils/CMakeLists.txt # renamed: executor/lib/Utils/TensorRTDynamicLoader/CMakeLists.txt -> common/lib/Utils/TensorRTDynamicLoader/CMakeLists.txt # renamed: executor/lib/Utils/TensorRTDynamicLoader/TensorRTDynamicLoader.cpp -> common/lib/Utils/TensorRTDynamicLoader/TensorRTDynamicLoader.cpp # modified: compiler/CMakeLists.txt # modified: compiler/include/mlir-tensorrt/Backends/Host/HostBackend.td # modified: compiler/include/mlir-tensorrt/Compiler/Extension.h # modified: compiler/include/mlir-tensorrt/Compiler/OptionsProviders.h compiler/include/mlir-tensorrt/Compiler/StablehloToExecutable/StablehloToExecutable.h compiler/include/mlir-tensorrt/Compiler/StablehloToExecutable/TensorRTExtension.h modified: compiler/include/mlir-tensorrt/Conversion/StablehloToTensorRT/StablehloToTensorRT.h compiler/include/mlir-tensorrt/Conversion/TensorRTCommon/ConvertToTensorRTCommon.h compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanDialect.td # new file: compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanEnums.h # modified: compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.h # modified: compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.td # modified: compiler/include/mlir-tensorrt/Dialect/Plan/Transforms/Passes.td # modified: compiler/include/mlir-tensorrt/InitAllDialects.h # modified: compiler/include/mlir-tensorrt/InitAllPasses.h # modified: compiler/include/mlir-tensorrt/Transforms/Transforms.h # modified: compiler/lib/Backends/Host/HostBackend.cpp # modified: compiler/lib/CAPI/Compiler/Registration/RegisterAllDialects.cpp # modified: compiler/lib/Compiler/OptionsProviders.cpp # modified: compiler/lib/Compiler/StablehloToExecutable/Passes.cpp # modified: compiler/lib/Compiler/StablehloToExecutable/StableHloInputPipelines.cpp compiler/lib/Compiler/StablehloToExecutable/StablehloToExecutable.cpp modified: compiler/lib/Conversion/HostToEmitC/HostToEmitC.cpp # modified: compiler/lib/Conversion/StablehloToScf/CMakeLists.txt # modified: compiler/lib/Conversion/StablehloToScf/StablehloToScf.cpp compiler/lib/Conversion/StablehloToTensorRT/CMakeLists.txt # modified: compiler/lib/Conversion/StablehloToTensorRT/Matchers.h # new file: compiler/lib/Conversion/StablehloToTensorRT/ReductionConversions.cpp # modified: compiler/lib/Conversion/StablehloToTensorRT/StablehloToTensorRT.cpp # modified: compiler/lib/Dialect/Plan/Transforms/AllocTensors.cpp # modified: compiler/lib/Dialect/Plan/Transforms/AssignMemorySpaces.cpp # modified: compiler/lib/Dialect/Plan/Transforms/CMakeLists.txt # modified: compiler/lib/Dialect/Plan/Transforms/CreateShapeFuncs.cpp compiler/lib/Dialect/Plan/Transforms/MaterializeExplicitTransfers.cpp compiler/lib/Dialect/Plan/Transforms/ModuleBufferization/BufferResultsToOutParams.cpp compiler/lib/Dialect/Plan/Transforms/ModuleBufferization/ModuleBufferizationAnalysis.cpp compiler/lib/Dialect/Plan/Transforms/OptimizeMemorySpaces.cpp # modified: compiler/lib/Dialect/Plan/Transforms/Passes.cpp # new file: compiler/lib/Dialect/Plan/Transforms/PromoteHostTensorsToHostPinned.cpp compiler/lib/Transforms/SCFDetensorizeLoops/SCFDetensorizeLoops.cpp # new file: compiler/test/Conversion/HostToEmitC/func-to-emitc.mlir # modified: compiler/test/Conversion/HostToEmitC/memref-to-emitc.mlir compiler/test/Conversion/StablehloToArith/stablehlo-constant-to-arith.mlir compiler/test/Conversion/StablehloToScf/stablehlo-to-scf.mlir # new file: compiler/test/Conversion/StablehloToTensorRT/dot-to-einsum.mlir # modified: compiler/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt-invalid.mlir compiler/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt-trt10.mlir compiler/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt.mlir file: compiler/test/Dialect/Plan/assign-and-optimize-memory-spaces.mlir # deleted: compiler/test/Dialect/Plan/assign-memory-spaces.mlir # new file: compiler/test/Dialect/Plan/buffer-results-to-out-params.mlir # new file: compiler/test/Dialect/Plan/materialize-explicit-transfers.mlir compiler/test/Dialect/Plan/materialize-shape-calculations-composite.mlir compiler/test/Dialect/Plan/materialize-shape-calculations.mlir # modified: compiler/test/Dialect/Plan/plan-bufferize-pipeline.mlir # new file: compiler/test/Dialect/Plan/promote-host-tensors-to-host-pinned.mlir # new file: compiler/test/Pipelines/StableHloInputPipeline/preprocessing-pipeline.mlir compiler/test/Target/Lua/IntegrationTests/ClusteringDynamicShape/end-to-end-binary.mlir compiler/test/Target/Lua/IntegrationTests/ClusteringDynamicShape/end-to-end-unary.mlir compiler/test/Target/Lua/IntegrationTests/buffer-ops-bf16.mlir # modified: compiler/test/Target/Lua/IntegrationTests/buffer-ops-dynamic.mlir # modified: compiler/test/Target/Lua/IntegrationTests/buffer-ops-f16.mlir # modified: compiler/test/Target/Lua/IntegrationTests/buffer-ops-f32.mlir # modified: compiler/test/Target/Lua/IntegrationTests/buffer-ops-f8E4M3FN.mlir # modified: compiler/test/Target/Lua/IntegrationTests/buffer-ops-i1.mlir # modified: compiler/test/Target/Lua/IntegrationTests/buffer-ops-i4.mlir # new file: compiler/test/Target/Lua/IntegrationTests/lit.local.cfg # modified: compiler/test/Target/Lua/IntegrationTests/memcpy-strided.mlir # modified: compiler/test/Target/Lua/IntegrationTests/memcpy.mlir # modified: compiler/test/Transforms/SCFDetensorizeLoops/scf-detensorize-loops.mlir compiler/test/python/IntegrationTests/Torch/test_torch_add.py # modified: compiler/test/python/IntegrationTests/lit.local.cfg # modified: compiler/test/python/IntegrationTests/test_call_validation.py # modified: compiler/test/python/IntegrationTests/test_non_dps_cconv.py # modified: compiler/test/python/IntegrationTests/test_return_allocation_loop.py # modified: compiler/test/python/IntegrationTests/test_stablehlo_add.py # modified: compiler/test/python/IntegrationTests/test_stablehlo_dynamic.py # modified: compiler/test/python/IntegrationTests/test_stablehlo_dynamic_iota.py # modified: compiler/test/python/IntegrationTests/test_tensorrt10_data_type_support.py compiler/test/python/IntegrationTests/test_tensorrt_add.py # modified: compiler/test/python/mlir_tensorrt_compiler/compiler_api/test_compiler_api.py compiler/test/python/mlir_tensorrt_compiler/compiler_api/test_compiler_debug_dump.py compiler/test/python/mlir_tensorrt_compiler/compiler_api/test_plugin_schema_api.py compiler/test/python/mlir_tensorrt_runtime/test_runtime_api.py # modified: compiler/test/python/mlir_tensorrt_runtime/test_runtime_debug_dump.py executor/cmake/ExecutorDependencies.cmake # modified: executor/include/mlir-executor-c/Runtime/Runtime.h # modified: executor/include/mlir-executor/Conversion/ConvertToExecutorCommon.h # modified: executor/include/mlir-executor/Executor/IR/ExecutorOps.td modified: executor/include/mlir-executor/Runtime/API/API.h # modified: executor/include/mlir-executor/Runtime/Backend/Lua/LuaExtensionRegistry.h executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h # modified: executor/include/mlir-executor/Runtime/Backend/Utils/NvtxUtils.h # modified: executor/include/mlir-executor/Support/Status.h # modified: executor/lib/CAPI/Runtime/Runtime.cpp # modified: executor/lib/Executor/IR/Executor.cpp # modified: executor/lib/Executor/Transforms/Passes.cpp # modified: executor/lib/Runtime/API/API.cpp # modified: executor/lib/Runtime/Backend/Lua/LuaExtensionRegistry.cpp # modified: executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp # modified: executor/lib/Target/Lua/TranslateToLua.cpp # modified: executor/lib/Target/Lua/TranslateToRuntimeExecutable.cpp # modified: executor/lib/Tools/ExecutorRunnerMain.cpp # modified: executor/lib/Utils/CMakeLists.txt # modified: executor/test/Executor/lower-builtins.mlir # modified: executor/test/IntegrationTests/arithmetic.mlir # modified: executor/test/IntegrationTests/assertion.mlir # modified: executor/test/IntegrationTests/complex.mlir # modified: executor/test/IntegrationTests/control-flow-nested.mlir # modified: executor/test/IntegrationTests/control-flow.mlir # modified: executor/test/IntegrationTests/coroutine.mlir # modified: executor/test/IntegrationTests/fill-device-f32.mlir # modified: executor/test/IntegrationTests/fill-f32.mlir # modified: executor/test/IntegrationTests/fill-i1.mlir # modified: executor/test/IntegrationTests/host-buffer-c32.mlir # modified: executor/test/IntegrationTests/host-buffer-i4.mlir # modified: executor/test/IntegrationTests/load-globals.mlir # modified: executor/test/IntegrationTests/pointer-cast-ops.mlir # new file: executor/test/IntegrationTests/ptr-to-int.mlir # modified: executor/test/IntegrationTests/stream.mlir # modified: executor/test/Unit/Runtime/LuaRuntime/ExecuteFunctionWithLuaBackendTests.cpp modified: integrations/python/bindings/Runtime/RuntimePyBind.cpp # modified: integrations/python/mlir_tensorrt_runtime/mlir_tensorrt/runtime/_mlir_libs/_api.pyi integrations/python/mlir_tensorrt_tools/mlir_tensorrt/tools/gpu_tools.py tensorrt/include/mlir-tensorrt-dialect/Target/TensorRTEncodingOpInterface/NetworkEncoder.h tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td # modified: tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp # modified: tensorrt/test/lit.cfg.py # new file: third_party/torch-mlir-cmake/CMakeLists.txt # new file: third_party/torch-mlir-cmake/TorchMLIRModule.cpp # --- mlir-tensorrt/CMakeLists.txt | 2 + mlir-tensorrt/DependencyProvider.cmake | 74 ++- .../build_tools/cmake/Dependencies.cmake | 46 +- ...memref.global-overly-constrained-ver.patch | 17 +- .../0006-mlir-emitc-Fix-two-EmitC-bugs.patch | 16 +- ...cessary-include-from-MathToEmitC.h-t.patch | 29 - ...eLineColRange-in-LLVM-debug-translat.patch | 13 +- ...LVMIRTransforms-build-failure-125485.patch | 44 -- ...ization-interface-for-tensor-reshape.patch | 71 +++ ...sing-checks-for-static-shapes-in-sta.patch | 57 +- ...ge-of-HandleLLVMOptions-and-LLVM_DEF.patch | 39 +- ...ert-unnecessary-arith.index_cast-ops.patch | 53 -- ...-condition-in-simplification-pattern.patch | 15 +- ...plexType-in-PointwiseToLinalgMapConv.patch | 95 --- ...6-Remove-explicit-use-of-LLVMSupport.patch | 10 +- ...endence-between-StablehloPasses-and-.patch | 20 +- ...w-finding-Stablehlo-via-find_package.patch | 60 -- ...e-with-more-recent-Stablehlo-version.patch | 41 -- ...-some-configuration-paths-in-LIT-cfg.patch | 58 -- mlir-tensorrt/common/CMakeLists.txt | 38 ++ .../mlir-tensorrt-common/CMakeLists.txt | 0 .../mlir-tensorrt-common/Conversion/Passes.h} | 32 +- .../mlir-tensorrt-common/Conversion/Passes.td | 18 + .../Dialect/EmitCExt/IR/DataLayoutImpl.h | 36 ++ .../Transforms/ToLoopsOpInterfaceImpl.h | 53 ++ .../Interfaces/ToLoopsOpInterface.h | 38 ++ .../Interfaces/ToLoopsOpInterface.td | 28 + mlir-tensorrt/common/lib/CMakeLists.txt | 5 + .../common/lib/Conversion/CMakeLists.txt | 11 + .../lib/Conversion/ToLoops/CMakeLists.txt | 15 + .../lib/Conversion/ToLoops/ConvertToLoops.cpp | 75 +++ .../common/lib/Dialect/CMakeLists.txt | 2 + .../lib/Dialect/EmitCExt/CMakeLists.txt | 9 + .../lib/Dialect/EmitCExt/DataLayoutImpl.cpp | 65 ++ .../lib/Dialect/LinalgExt/CMakeLists.txt | 1 + .../LinalgExt/Transforms/CMakeLists.txt | 12 + .../Transforms/ToLoopsOpInterfaceImpl.cpp | 194 ++++++ .../common/lib/Interfaces/CMakeLists.txt | 35 ++ .../lib/Interfaces/ToLoopsOpInterface.cpp | 13 + mlir-tensorrt/common/lib/Utils/CMakeLists.txt | 3 + .../TensorRTDynamicLoader/CMakeLists.txt | 2 +- .../TensorRTDynamicLoader.cpp | 0 mlir-tensorrt/compiler/CMakeLists.txt | 1 + .../Backends/Host/HostBackend.td | 2 +- .../mlir-tensorrt/Compiler/Extension.h | 5 - .../mlir-tensorrt/Compiler/OptionsProviders.h | 9 +- .../StablehloToExecutable.h | 7 - .../StablehloToExecutable/TensorRTExtension.h | 4 +- .../mlir-tensorrt/Conversion/Passes.td | 9 +- .../StablehloToTensorRT/StablehloToTensorRT.h | 15 +- .../TensorRTCommon/ConvertToTensorRTCommon.h | 44 +- .../mlir-tensorrt/Dialect/Plan/IR/Plan.h | 6 +- .../Dialect/Plan/IR/PlanDialect.td | 8 + .../mlir-tensorrt/Dialect/Plan/IR/PlanEnums.h | 29 + .../Dialect/Plan/IR/PlanInterfaces.h | 1 + .../Dialect/Plan/IR/PlanInterfaces.td | 10 + .../Dialect/Plan/Transforms/Passes.td | 101 ++++ .../include/mlir-tensorrt/InitAllDialects.h | 4 + .../include/mlir-tensorrt/InitAllPasses.h | 6 +- .../mlir-tensorrt/Transforms/Transforms.h | 19 + .../lib/Backends/Host/HostBackend.cpp | 4 + .../Registration/RegisterAllDialects.cpp | 24 +- .../lib/Compiler/OptionsProviders.cpp | 5 +- .../Compiler/StablehloToExecutable/Passes.cpp | 59 +- .../StableHloInputPipelines.cpp | 3 + .../StablehloToExecutable.cpp | 61 +- .../lib/Conversion/HostToEmitC/CMakeLists.txt | 1 + .../Conversion/HostToEmitC/HostToEmitC.cpp | 121 ++-- .../Conversion/StablehloToScf/CMakeLists.txt | 6 + .../StablehloToScf/StablehloToScf.cpp | 344 ++++++++++- .../StablehloToTensorRT/CMakeLists.txt | 3 +- .../Conversion/StablehloToTensorRT/Matchers.h | 44 +- .../ReductionConversions.cpp | 529 ++++++++++++++++ .../StablehloToTensorRT.cpp | 327 +--------- .../Dialect/Plan/Transforms/AllocTensors.cpp | 35 +- .../Plan/Transforms/AssignMemorySpaces.cpp | 567 ++++++++++-------- .../Dialect/Plan/Transforms/CMakeLists.txt | 4 + .../Plan/Transforms/CreateShapeFuncs.cpp | 14 +- .../MaterializeExplicitTransfers.cpp | 181 ++++++ .../BufferResultsToOutParams.cpp | 517 ++++++++++++++++ .../ModuleBufferizationAnalysis.cpp | 61 +- .../Plan/Transforms/OptimizeMemorySpaces.cpp | 552 +++++++++++++++++ .../lib/Dialect/Plan/Transforms/Passes.cpp | 17 +- .../PromoteHostTensorsToHostPinned.cpp | 135 +++++ .../SCFDetensorizeLoops.cpp | 279 +++++---- .../Conversion/HostToEmitC/func-to-emitc.mlir | 49 ++ .../HostToEmitC/memref-to-emitc.mlir | 56 +- .../stablehlo-constant-to-arith.mlir | 23 + .../StablehloToScf/stablehlo-to-scf.mlir | 282 ++++++--- .../StablehloToTensorRT/dot-to-einsum.mlir | 141 +++++ .../stablehlo-to-tensorrt-invalid.mlir | 17 - .../stablehlo-to-tensorrt-trt10.mlir | 3 +- .../stablehlo-to-tensorrt.mlir | 14 +- .../test/Dialect/LinalgExt/to-loops.mlir | 124 ++++ .../assign-and-optimize-memory-spaces.mlir | 238 ++++++++ .../Dialect/Plan/assign-memory-spaces.mlir | 73 --- .../Plan/buffer-results-to-out-params.mlir | 179 ++++++ .../Plan/materialize-explicit-transfers.mlir | 62 ++ ...erialize-shape-calculations-composite.mlir | 2 +- .../Plan/materialize-shape-calculations.mlir | 18 +- .../Dialect/Plan/plan-bufferize-pipeline.mlir | 231 +++++-- .../promote-host-tensors-to-host-pinned.mlir | 30 + .../preprocessing-pipeline.mlir | 19 + .../end-to-end-binary.mlir | 5 +- .../end-to-end-unary.mlir | 4 +- .../Lua/IntegrationTests/buffer-ops-bf16.mlir | 2 +- .../IntegrationTests/buffer-ops-dynamic.mlir | 2 +- .../Lua/IntegrationTests/buffer-ops-f16.mlir | 2 +- .../Lua/IntegrationTests/buffer-ops-f32.mlir | 2 +- .../IntegrationTests/buffer-ops-f8E4M3FN.mlir | 2 +- .../Lua/IntegrationTests/buffer-ops-i1.mlir | 2 +- .../Lua/IntegrationTests/buffer-ops-i4.mlir | 2 +- .../Target/Lua/IntegrationTests/lit.local.cfg | 1 + .../Lua/IntegrationTests/memcpy-strided.mlir | 2 +- .../Target/Lua/IntegrationTests/memcpy.mlir | 2 +- .../scf-detensorize-loops.mlir | 4 +- mlir-tensorrt/compiler/test/lit.cfg.py | 21 +- .../IntegrationTests/Torch/test_torch_add.py | 2 +- .../python/IntegrationTests/lit.local.cfg | 2 - .../IntegrationTests/test_call_validation.py | 2 +- .../IntegrationTests/test_non_dps_cconv.py | 2 +- .../test_return_allocation_loop.py | 2 +- .../IntegrationTests/test_stablehlo_add.py | 2 +- .../test_stablehlo_dynamic.py | 2 +- .../test_stablehlo_dynamic_iota.py | 2 +- .../test_tensorrt10_data_type_support.py | 2 +- .../IntegrationTests/test_tensorrt_add.py | 10 +- .../compiler_api/test_compiler_api.py | 2 +- .../compiler_api/test_compiler_debug_dump.py | 2 +- .../compiler_api/test_plugin_schema_api.py | 1 - .../mlir_tensorrt_runtime/test_runtime_api.py | 5 +- .../test_runtime_debug_dump.py | 1 - mlir-tensorrt/executor/CMakeLists.txt | 4 +- .../executor/cmake/ExecutorDependencies.cmake | 22 - .../include/mlir-executor-c/Runtime/Runtime.h | 5 + .../Conversion/ConvertToExecutorCommon.h | 53 +- .../mlir-executor/Executor/IR/ExecutorOps.td | 30 +- .../include/mlir-executor/InitAllPasses.h | 6 +- .../include/mlir-executor/Runtime/API/API.h | 100 ++- .../Backend/Lua/LuaExtensionRegistry.h | 13 +- .../Runtime/Backend/Lua/LuaRuntime.h | 4 +- .../Runtime/Backend/Utils/NvtxUtils.h | 14 - .../include/mlir-executor/Support/Status.h | 6 +- .../executor/lib/CAPI/Runtime/Runtime.cpp | 6 + .../executor/lib/Executor/IR/Executor.cpp | 11 - .../lib/Executor/Transforms/Passes.cpp | 4 +- .../executor/lib/Runtime/API/API.cpp | 79 ++- .../Backend/Lua/LuaExtensionRegistry.cpp | 25 +- .../lib/Runtime/Backend/Lua/LuaRuntime.cpp | 80 +-- .../lib/Target/Lua/TranslateToLua.cpp | 52 +- .../Lua/TranslateToRuntimeExecutable.cpp | 63 +- .../executor/lib/Tools/ExecutorRunnerMain.cpp | 71 ++- .../executor/lib/Utils/CMakeLists.txt | 4 - .../test/Executor/lower-builtins.mlir | 26 - .../test/IntegrationTests/arithmetic.mlir | 8 +- .../test/IntegrationTests/assertion.mlir | 2 +- .../test/IntegrationTests/complex.mlir | 2 +- .../IntegrationTests/control-flow-nested.mlir | 2 +- .../test/IntegrationTests/control-flow.mlir | 2 +- .../test/IntegrationTests/coroutine.mlir | 2 +- .../IntegrationTests/fill-device-f32.mlir | 2 +- .../test/IntegrationTests/fill-f32.mlir | 2 +- .../test/IntegrationTests/fill-i1.mlir | 2 +- .../IntegrationTests/host-buffer-c32.mlir | 2 +- .../test/IntegrationTests/host-buffer-i4.mlir | 2 +- .../test/IntegrationTests/load-globals.mlir | 2 +- .../IntegrationTests/pointer-cast-ops.mlir | 2 +- .../test/IntegrationTests/ptr-to-int.mlir | 32 + .../test/IntegrationTests/stream.mlir | 2 +- .../ExecuteFunctionWithLuaBackendTests.cpp | 23 +- .../test/executor-runner/invalid.mlir | 2 +- .../python/bindings/Runtime/RuntimePyBind.cpp | 21 +- .../mlir_tensorrt/runtime/_mlir_libs/_api.pyi | 6 +- .../mlir_tensorrt/tools/gpu_tools.py | 25 +- mlir-tensorrt/tensorrt/CMakeLists.txt | 1 + .../NetworkEncoder.h | 12 +- .../TensorRT/IR/TensorRTOps.td | 2 +- .../NetworkEncoder.cpp | 27 +- mlir-tensorrt/tensorrt/test/lit.cfg.py | 2 +- .../torch-mlir-cmake/CMakeLists.txt | 262 ++++++++ .../torch-mlir-cmake/TorchMLIRModule.cpp | 33 + 181 files changed, 6286 insertions(+), 2145 deletions(-) delete mode 100644 mlir-tensorrt/build_tools/patches/mlir/0008-MLIR-Remove-unnecessary-include-from-MathToEmitC.h-t.patch delete mode 100644 mlir-tensorrt/build_tools/patches/mlir/0010-MLIR-Fix-LLVMIRTransforms-build-failure-125485.patch create mode 100644 mlir-tensorrt/build_tools/patches/mlir/0011-MLIR-Fix-bufferization-interface-for-tensor-reshape.patch delete mode 100644 mlir-tensorrt/build_tools/patches/stablehlo/0003-Don-t-insert-unnecessary-arith.index_cast-ops.patch delete mode 100644 mlir-tensorrt/build_tools/patches/stablehlo/0005-Fix-crash-on-ComplexType-in-PointwiseToLinalgMapConv.patch delete mode 100644 mlir-tensorrt/build_tools/patches/torch_mlir/0001-cmake-Allow-finding-Stablehlo-via-find_package.patch delete mode 100644 mlir-tensorrt/build_tools/patches/torch_mlir/0002-Make-compatible-with-more-recent-Stablehlo-version.patch delete mode 100644 mlir-tensorrt/build_tools/patches/torch_mlir/0003-Fix-some-configuration-paths-in-LIT-cfg.patch create mode 100644 mlir-tensorrt/common/CMakeLists.txt create mode 100644 mlir-tensorrt/common/include/mlir-tensorrt-common/CMakeLists.txt rename mlir-tensorrt/{executor/include/mlir-executor/Runtime/Backend/Lua/LuaRegistration.h => common/include/mlir-tensorrt-common/Conversion/Passes.h} (51%) create mode 100644 mlir-tensorrt/common/include/mlir-tensorrt-common/Conversion/Passes.td create mode 100644 mlir-tensorrt/common/include/mlir-tensorrt-common/Dialect/EmitCExt/IR/DataLayoutImpl.h create mode 100644 mlir-tensorrt/common/include/mlir-tensorrt-common/Dialect/LinalgExt/Transforms/ToLoopsOpInterfaceImpl.h create mode 100644 mlir-tensorrt/common/include/mlir-tensorrt-common/Interfaces/ToLoopsOpInterface.h create mode 100644 mlir-tensorrt/common/include/mlir-tensorrt-common/Interfaces/ToLoopsOpInterface.td create mode 100644 mlir-tensorrt/common/lib/CMakeLists.txt create mode 100644 mlir-tensorrt/common/lib/Conversion/CMakeLists.txt create mode 100644 mlir-tensorrt/common/lib/Conversion/ToLoops/CMakeLists.txt create mode 100644 mlir-tensorrt/common/lib/Conversion/ToLoops/ConvertToLoops.cpp create mode 100644 mlir-tensorrt/common/lib/Dialect/CMakeLists.txt create mode 100644 mlir-tensorrt/common/lib/Dialect/EmitCExt/CMakeLists.txt create mode 100644 mlir-tensorrt/common/lib/Dialect/EmitCExt/DataLayoutImpl.cpp create mode 100644 mlir-tensorrt/common/lib/Dialect/LinalgExt/CMakeLists.txt create mode 100644 mlir-tensorrt/common/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt create mode 100644 mlir-tensorrt/common/lib/Dialect/LinalgExt/Transforms/ToLoopsOpInterfaceImpl.cpp create mode 100644 mlir-tensorrt/common/lib/Interfaces/CMakeLists.txt create mode 100644 mlir-tensorrt/common/lib/Interfaces/ToLoopsOpInterface.cpp create mode 100644 mlir-tensorrt/common/lib/Utils/CMakeLists.txt rename mlir-tensorrt/{executor => common}/lib/Utils/TensorRTDynamicLoader/CMakeLists.txt (64%) rename mlir-tensorrt/{executor => common}/lib/Utils/TensorRTDynamicLoader/TensorRTDynamicLoader.cpp (100%) create mode 100644 mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanEnums.h create mode 100644 mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/ReductionConversions.cpp create mode 100644 mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/MaterializeExplicitTransfers.cpp create mode 100644 mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/ModuleBufferization/BufferResultsToOutParams.cpp create mode 100644 mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OptimizeMemorySpaces.cpp create mode 100644 mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/PromoteHostTensorsToHostPinned.cpp create mode 100644 mlir-tensorrt/compiler/test/Conversion/HostToEmitC/func-to-emitc.mlir create mode 100644 mlir-tensorrt/compiler/test/Conversion/StablehloToArith/stablehlo-constant-to-arith.mlir create mode 100644 mlir-tensorrt/compiler/test/Conversion/StablehloToTensorRT/dot-to-einsum.mlir create mode 100644 mlir-tensorrt/compiler/test/Dialect/LinalgExt/to-loops.mlir create mode 100644 mlir-tensorrt/compiler/test/Dialect/Plan/assign-and-optimize-memory-spaces.mlir delete mode 100644 mlir-tensorrt/compiler/test/Dialect/Plan/assign-memory-spaces.mlir create mode 100644 mlir-tensorrt/compiler/test/Dialect/Plan/buffer-results-to-out-params.mlir create mode 100644 mlir-tensorrt/compiler/test/Dialect/Plan/materialize-explicit-transfers.mlir create mode 100644 mlir-tensorrt/compiler/test/Dialect/Plan/promote-host-tensors-to-host-pinned.mlir create mode 100644 mlir-tensorrt/compiler/test/Pipelines/StableHloInputPipeline/preprocessing-pipeline.mlir create mode 100644 mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/lit.local.cfg create mode 100644 mlir-tensorrt/executor/test/IntegrationTests/ptr-to-int.mlir create mode 100644 mlir-tensorrt/third_party/torch-mlir-cmake/CMakeLists.txt create mode 100644 mlir-tensorrt/third_party/torch-mlir-cmake/TorchMLIRModule.cpp diff --git a/mlir-tensorrt/CMakeLists.txt b/mlir-tensorrt/CMakeLists.txt index 2cb09c49c..00a316ef1 100644 --- a/mlir-tensorrt/CMakeLists.txt +++ b/mlir-tensorrt/CMakeLists.txt @@ -190,6 +190,8 @@ if(MLIR_TRT_ENABLE_PYTHON) mlir_tensorrt_find_dlpack() endif() +find_package(MLIRTensorRTCommon REQUIRED) + #-------------------------------------------------- # Diagnostics #-------------------------------------------------- diff --git a/mlir-tensorrt/DependencyProvider.cmake b/mlir-tensorrt/DependencyProvider.cmake index c7d36e001..bed343064 100644 --- a/mlir-tensorrt/DependencyProvider.cmake +++ b/mlir-tensorrt/DependencyProvider.cmake @@ -17,7 +17,7 @@ if("${MLIR_TRT_USE_LLVM}" STREQUAL "prebuilt") set(MTRT_BUILD_LLVM_FROM_SOURCE OFF) endif() -set(MLIR_TRT_LLVM_COMMIT "729416e586fba71b4f63d71b1b5c765aefbf200b") +set(MLIR_TRT_LLVM_COMMIT "f137c3d592e96330e450a8fd63ef7e8877fc1908") set(mlir_patch_dir "${CMAKE_CURRENT_LIST_DIR}/build_tools/patches/mlir") @@ -43,7 +43,7 @@ else() "${mlir_patch_dir}/0005-mlir-memref-Fix-memref.global-overly-constrained-ver.patch" "${mlir_patch_dir}/0006-mlir-emitc-Fix-two-EmitC-bugs.patch" "${mlir_patch_dir}/0009-mlir-Support-FileLineColRange-in-LLVM-debug-translat.patch" - "${mlir_patch_dir}/0010-MLIR-Fix-LLVMIRTransforms-build-failure-125485.patch" + "${mlir_patch_dir}/0011-MLIR-Fix-bufferization-interface-for-tensor-reshape.patch" # Set the CPM cache key to the Git hash for easy navigation. PRE_ADD_HOOK [[ list(APPEND _vap_UNPARSED_ARGUMENTS @@ -63,6 +63,8 @@ else() ) list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") + set(MLIR_MAIN_SRC_DIR "${LLVM_SOURCE_DIR}/mlir" CACHE STRING "" FORCE) + if(TARGET MLIRPythonExtension.Core) get_property(mlir_core_pybind_capi_embed TARGET MLIRPythonExtension.Core @@ -102,14 +104,12 @@ set(stablehlo_patch_dir "${CMAKE_SOURCE_DIR}/build_tools/patches/stablehlo") nv_register_package( NAME Stablehlo VERSION 1.9.3 - GIT_TAG 459897561d365ef97caba46984847f9184d472ec + GIT_TAG 4bf77d23bd9150782a70d85fda9c12a2dec5328c GIT_REPOSITORY "https://github.com/openxla/stablehlo.git" PATCHES "${stablehlo_patch_dir}/0001-Fix-a-couple-missing-checks-for-static-shapes-in-sta.patch" "${stablehlo_patch_dir}/0002-cmake-Update-usage-of-HandleLLVMOptions-and-LLVM_DEF.patch" - "${stablehlo_patch_dir}/0003-Don-t-insert-unnecessary-arith.index_cast-ops.patch" "${stablehlo_patch_dir}/0004-Fix-ZeroExtent-condition-in-simplification-pattern.patch" - "${stablehlo_patch_dir}/0005-Fix-crash-on-ComplexType-in-PointwiseToLinalgMapConv.patch" "${stablehlo_patch_dir}/0006-Remove-explicit-use-of-LLVMSupport.patch" "${stablehlo_patch_dir}/0007-Fix-circular-dependence-between-StablehloPasses-and-.patch" OPTIONS @@ -123,6 +123,42 @@ nv_register_package( ]] ) +#------------------------------------------------------------------------------------- +# MLIRTensorRTCommon +# +# MLIRTensorRTCommon is a sub-project that contains components used across the +# other sub-projects like MLIRExecutor and MLIRTensorRTDialect. +#------------------------------------------------------------------------------------- + +nv_register_package( + NAME MLIRTensorRTCommon + SOURCE_DIR "${CMAKE_SOURCE_DIR}/common" +) + +# ----------------------------------------------------------------------------- +# NVTX +# ----------------------------------------------------------------------------- + +nv_register_package( + NAME NVTX + GIT_REPOSITORY https://github.com/NVIDIA/NVTX.git + GIT_TAG v3.1.0 + GIT_SHALLOW TRUE + SOURCE_SUBDIR c + EXCLUDE_FROM_ALL TRUE + DOWNLOAD_ONLY TRUE + POST_ADD_HOOK [[ + if(NOT TARGET nvtx3-cpp) + add_library(nvtx3-cpp INTERFACE IMPORTED) + target_include_directories(nvtx3-cpp INTERFACE + "$") + # Ignore some warnings due to NVTX3 code style. + target_compile_options(nvtx3-cpp INTERFACE + -Wno-missing-braces) + endif() + ]] +) + #------------------------------------------------------------------------------------- # MLIR-Executor # @@ -158,26 +194,23 @@ nv_register_package( #------------------------------------------------------------------------------------- # Torch-MLIR #------------------------------------------------------------------------------------- -set(torch_mlir_patch_dir "${CMAKE_SOURCE_DIR}/build_tools/patches/torch_mlir") - +set(MLIR_TRT_TORCH_MLIR_COMMIT "9f2ba5abaa85cefd95cc85579fafd0c53c1101e8") nv_register_package( NAME torch_mlir - GIT_REPOSITORY https://github.com/llvm/torch-mlir.git - GIT_TAG 0bb263e99415d43255350d29263097b4980303bf - PATCHES - "build_tools/patches/torch_mlir/0001-cmake-Allow-finding-Stablehlo-via-find_package.patch" - "build_tools/patches/torch_mlir/0002-Make-compatible-with-more-recent-Stablehlo-version.patch" - "build_tools/patches/torch_mlir/0003-Fix-some-configuration-paths-in-LIT-cfg.patch" - EXCLUDE_FROM_ALL TRUE + URL "https://github.com/llvm/torch-mlir/archive/${MLIR_TRT_TORCH_MLIR_COMMIT}.zip" # We need to specify an existing directory that is not actually a submodule # since GIT_SUBMODULES does not except the empty string due to # https://gitlab.kitware.com/cmake/cmake/-/issues/24578 GIT_SUBMODULES "docs" - OPTIONS - "TORCH_MLIR_OUT_OF_TREE_BUILD ON" - "TORCH_MLIR_ENABLE_STABLEHLO ON" - "TORCH_MLIR_EXTERNAL_STABLEHLO_DIR find_package" - "TORCH_MLIR_ENABLE_TOSA OFF" + DOWNLOAD_ONLY TRUE + + POST_ADD_HOOK [[ + add_subdirectory( + ${CMAKE_SOURCE_DIR}/third_party/torch-mlir-cmake + ${CMAKE_BINARY_DIR}/_deps/torch_mlir-build + EXCLUDE_FROM_ALL + ) + ]] ) #------------------------------------------------------------------------------------- @@ -202,7 +235,7 @@ macro(mtrt_provide_dependency method dep_name) endif() if("${dep_name}" MATCHES - "^(MLIRExecutor|MLIRTensorRTDialect|Stablehlo|torch_mlir)$") + "^(MLIRExecutor|MLIRTensorRTDialect|Stablehlo|torch_mlir|NVTX|MLIRTensorRTCommon)$") nv_add_package("${dep_name}") set("${dep_name}_FOUND" TRUE) endif() @@ -230,6 +263,7 @@ macro(mtrt_provide_dependency method dep_name) find_package(LLVM ${ARGN} BYPASS_PROVIDER) endif() endif() + endmacro() cmake_language( diff --git a/mlir-tensorrt/build_tools/cmake/Dependencies.cmake b/mlir-tensorrt/build_tools/cmake/Dependencies.cmake index f8b2217b2..1eb9efa8b 100644 --- a/mlir-tensorrt/build_tools/cmake/Dependencies.cmake +++ b/mlir-tensorrt/build_tools/cmake/Dependencies.cmake @@ -6,10 +6,16 @@ include(${CMAKE_CURRENT_LIST_DIR}/TensorRTDownloadURL.cmake) # expected version. #------------------------------------------------------------------------------------- macro(get_tensorrt_version nvinfer_version_file out_var) - file(STRINGS "${nvinfer_version_file}" VERSION_STRINGS REGEX "#define NV_TENSORRT_.*") + file(STRINGS "${nvinfer_version_file}" VERSION_STRINGS REGEX "#define (TRT_.+|NV_TENSORRT_.+) [0-9]+") foreach(TYPE MAJOR MINOR PATCH BUILD) - string(REGEX MATCH "NV_TENSORRT_${TYPE} [0-9]+" TRT_TYPE_STRING ${VERSION_STRINGS}) - string(REGEX MATCH "[0-9]+" TRT_${TYPE} ${TRT_TYPE_STRING}) + string(REGEX MATCH "(TRT_${TYPE}_ENTERPRISE|NV_TENSORRT_${TYPE}) [0-9]+" TRT_TYPE_STRING ${VERSION_STRINGS}) + if("${TRT_TYPE_STRING}" STREQUAL "") + message(FATAL_ERROR "Failed to extract TensorRT ${TYPE} version from ${nvinfer_version_file}") + endif() + string(REGEX MATCH "[0-9]+" "TRT_${TYPE}" "${TRT_TYPE_STRING}") + if("TRT_${TYPE}" STREQUAL "") + message(FATAL_ERROR "Failed to extract TensorRT ${TYPE} version from ${nvinfer_version_file}") + endif() endforeach(TYPE) set("${out_var}" "${TRT_MAJOR}.${TRT_MINOR}.${TRT_PATCH}.${TRT_BUILD}") endmacro() @@ -50,7 +56,7 @@ macro(configure_tensorrt_python_plugin_header) if(ARG_INSTALL_DIR) find_file( trt_python_plugin_header - NAMES plugin.h + NAMES NvInferPythonPlugin.h plugin.h HINTS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl PATHS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl REQUIRED @@ -60,7 +66,7 @@ macro(configure_tensorrt_python_plugin_header) else() find_path( trt_python_plugin_header - NAMES plugin.h + NAMES NvInferPythonPlugin.h plugin.h REQUIRED NO_CACHE ) @@ -173,36 +179,6 @@ function(find_tensorrt) ) endfunction() -macro(configure_tensorrt_python_plugin_header) - if(ARG_INSTALL_DIR) - find_file( - trt_python_plugin_header - NAMES plugin.h - HINTS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl - PATHS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl - REQUIRED - NO_CMAKE_PATH NO_DEFAULT_PATH - NO_CACHE - ) - else() - find_path( - trt_python_plugin_header - NAMES plugin.h - REQUIRED - NO_CACHE - ) - endif() - file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/include/nvinfer") - file(COPY_FILE "${trt_python_plugin_header}" - "${CMAKE_BINARY_DIR}/include/nvinfer/trt_plugin_python.h" - ONLY_IF_DIFFERENT - RESULT copy_result - ) - if(copy_result) - message(FATAL_ERROR "failed to copy TensorRT QDP plugin header: ${copy_result}") - endif() -endmacro() - #------------------------------------------------------------------------------------- # Download and add DLPack to the build (header only) #------------------------------------------------------------------------------------- diff --git a/mlir-tensorrt/build_tools/patches/mlir/0005-mlir-memref-Fix-memref.global-overly-constrained-ver.patch b/mlir-tensorrt/build_tools/patches/mlir/0005-mlir-memref-Fix-memref.global-overly-constrained-ver.patch index 5a8047361..96c677373 100644 --- a/mlir-tensorrt/build_tools/patches/mlir/0005-mlir-memref-Fix-memref.global-overly-constrained-ver.patch +++ b/mlir-tensorrt/build_tools/patches/mlir/0005-mlir-memref-Fix-memref.global-overly-constrained-ver.patch @@ -1,18 +1,17 @@ -From f014186374bb3e71d44648781dc03aaefd29f0d5 Mon Sep 17 00:00:00 2001 -From: Christopher Bate -Date: Fri, 10 May 2024 22:39:44 -0600 -Subject: [PATCH 05/10] [mlir][memref] Fix memref.global overly constrained - verifier check +From 07f534dac7f915a496265c14745c0bc643185efe Mon Sep 17 00:00:00 2001 +From: Sagar Shelke +Date: Tue, 1 Jul 2025 00:17:04 +0000 +Subject: [PATCH] Apply patch 0005 --- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp -index 4f75b7618d63..f12f41437759 100644 +index 11597505e788..66ce4b3638b0 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp -@@ -1117,7 +1117,7 @@ struct DimOfMemRefReshape : public OpRewritePattern { +@@ -1124,7 +1124,7 @@ struct DimOfMemRefReshape : public OpRewritePattern { } } // else dim.getIndex is a block argument to reshape->getBlock and // dominates reshape @@ -21,7 +20,7 @@ index 4f75b7618d63..f12f41437759 100644 else if (dim->getBlock() != reshape->getBlock() && !dim.getIndex().getParentRegion()->isProperAncestor( reshape->getParentRegion())) { -@@ -1607,9 +1607,11 @@ LogicalResult GlobalOp::verify() { +@@ -1614,9 +1614,11 @@ LogicalResult GlobalOp::verify() { // Check that the type of the initial value is compatible with the type of // the global variable. if (auto elementsAttr = llvm::dyn_cast(initValue)) { @@ -37,5 +36,5 @@ index 4f75b7618d63..f12f41437759 100644 << tensorType << ", but was of type " << initType; } -- -2.46.0 +2.48.1 diff --git a/mlir-tensorrt/build_tools/patches/mlir/0006-mlir-emitc-Fix-two-EmitC-bugs.patch b/mlir-tensorrt/build_tools/patches/mlir/0006-mlir-emitc-Fix-two-EmitC-bugs.patch index 72196e20b..cfe9a6d17 100644 --- a/mlir-tensorrt/build_tools/patches/mlir/0006-mlir-emitc-Fix-two-EmitC-bugs.patch +++ b/mlir-tensorrt/build_tools/patches/mlir/0006-mlir-emitc-Fix-two-EmitC-bugs.patch @@ -1,7 +1,7 @@ -From 47c84211f72fb407d72e2c8f87019802cda30432 Mon Sep 17 00:00:00 2001 -From: Christopher Bate -Date: Mon, 27 Jan 2025 08:28:33 +0000 -Subject: [PATCH 06/10] [mlir][emitc] Fix two EmitC bugs +From d81aaed8cb0190807dbe378e469fde53101f32eb Mon Sep 17 00:00:00 2001 +From: Sagar Shelke +Date: Tue, 1 Jul 2025 00:19:03 +0000 +Subject: [PATCH] Apply patch 0006 --- .../mlir/Conversion/FuncToEmitC/FuncToEmitC.h | 4 +- @@ -112,10 +112,10 @@ index 0b97f2641ad0..d2f368a7148d 100644 if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp -index 01de0e41f203..4f600f92ba6d 100644 +index b00820ffc542..803c58cc35c6 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp -@@ -273,6 +273,7 @@ private: +@@ -282,6 +282,7 @@ private: ExpressionOp emittedExpression; SmallVector emittedExpressionPrecedence; @@ -123,7 +123,7 @@ index 01de0e41f203..4f600f92ba6d 100644 void pushExpressionPrecedence(int precedence) { emittedExpressionPrecedence.push_back(precedence); } -@@ -670,12 +671,14 @@ static LogicalResult printOperation(CppEmitter &emitter, +@@ -695,12 +696,14 @@ static LogicalResult printOperation(CppEmitter &emitter, if (auto t = dyn_cast(attr)) { // Index attributes are treated specially as operand index. if (t.getType().isIndex()) { @@ -143,5 +143,5 @@ index 01de0e41f203..4f600f92ba6d 100644 } } -- -2.46.0 +2.48.1 diff --git a/mlir-tensorrt/build_tools/patches/mlir/0008-MLIR-Remove-unnecessary-include-from-MathToEmitC.h-t.patch b/mlir-tensorrt/build_tools/patches/mlir/0008-MLIR-Remove-unnecessary-include-from-MathToEmitC.h-t.patch deleted file mode 100644 index 438e39837..000000000 --- a/mlir-tensorrt/build_tools/patches/mlir/0008-MLIR-Remove-unnecessary-include-from-MathToEmitC.h-t.patch +++ /dev/null @@ -1,29 +0,0 @@ -From 75f0d527fe5dd23c9281c7240b0f54556648e2a7 Mon Sep 17 00:00:00 2001 -From: Tomer Solomon -Date: Mon, 3 Feb 2025 11:51:42 +0200 -Subject: [PATCH 08/10] [MLIR] Remove unnecessary include from MathToEmitC.h to - fix build issue (#125466) - -This removes the unnecessary inclusion of mlir/Dialect/EmitC/IR/EmitC.h -from MathToEmitC.h, which caused a build failure due to a missing -EmitCEnums.h.inc. The include was not needed, and removing it resolves -the issue without requiring additional dependencies. ---- - mlir/include/mlir/Conversion/MathToEmitC/MathToEmitC.h | 1 - - 1 file changed, 1 deletion(-) - -diff --git a/mlir/include/mlir/Conversion/MathToEmitC/MathToEmitC.h b/mlir/include/mlir/Conversion/MathToEmitC/MathToEmitC.h -index 0fc33bf790be..c61773026ca5 100644 ---- a/mlir/include/mlir/Conversion/MathToEmitC/MathToEmitC.h -+++ b/mlir/include/mlir/Conversion/MathToEmitC/MathToEmitC.h -@@ -8,7 +8,6 @@ - - #ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H - #define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H --#include "mlir/Dialect/EmitC/IR/EmitC.h" - namespace mlir { - class RewritePatternSet; - namespace emitc { --- -2.46.0 - diff --git a/mlir-tensorrt/build_tools/patches/mlir/0009-mlir-Support-FileLineColRange-in-LLVM-debug-translat.patch b/mlir-tensorrt/build_tools/patches/mlir/0009-mlir-Support-FileLineColRange-in-LLVM-debug-translat.patch index 84fe52af5..1593dfe9d 100644 --- a/mlir-tensorrt/build_tools/patches/mlir/0009-mlir-Support-FileLineColRange-in-LLVM-debug-translat.patch +++ b/mlir-tensorrt/build_tools/patches/mlir/0009-mlir-Support-FileLineColRange-in-LLVM-debug-translat.patch @@ -1,15 +1,14 @@ -From 51c99ccf1a291295aed12a36395760026c268cbb Mon Sep 17 00:00:00 2001 -From: Christopher Bate -Date: Tue, 11 Mar 2025 22:34:24 +0000 -Subject: [PATCH 09/10] [mlir] Support FileLineColRange in LLVM debug - translation +From d679ee8c7978fe63321e90c5c6583b604ca3d1a5 Mon Sep 17 00:00:00 2001 +From: Sagar Shelke +Date: Tue, 1 Jul 2025 00:25:43 +0000 +Subject: [PATCH] Apply patch 0009 --- mlir/lib/Target/LLVMIR/DebugTranslation.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp -index cf734de49acd..c55d9a204468 100644 +index 1d3ed6f3262f..93e1d08faf4f 100644 --- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp @@ -547,6 +547,14 @@ llvm::DILocation *DebugTranslation::translateLoc(Location loc, @@ -28,5 +27,5 @@ index cf734de49acd..c55d9a204468 100644 ArrayRef locations = fusedLoc.getLocations(); -- -2.46.0 +2.48.1 diff --git a/mlir-tensorrt/build_tools/patches/mlir/0010-MLIR-Fix-LLVMIRTransforms-build-failure-125485.patch b/mlir-tensorrt/build_tools/patches/mlir/0010-MLIR-Fix-LLVMIRTransforms-build-failure-125485.patch deleted file mode 100644 index ff9bde258..000000000 --- a/mlir-tensorrt/build_tools/patches/mlir/0010-MLIR-Fix-LLVMIRTransforms-build-failure-125485.patch +++ /dev/null @@ -1,44 +0,0 @@ -From c5bef25c87a0e5a2377e6909b812acc9d026c7a2 Mon Sep 17 00:00:00 2001 -From: Thomas Preud'homme -Date: Mon, 10 Feb 2025 19:37:58 +0000 -Subject: [PATCH 10/10] [MLIR] Fix LLVMIRTransforms build failure (#125485) - -lib/libMLIRLLVMIRTransforms.a fails to build from scratch with the -following error: -In file included from llvm/include/llvm/Frontend/OpenMP/OMPConstants.h:19, - from llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h:19, - from mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h:26, - from mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h:24, - from mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp:17: -llvm/include/llvm/Frontend/OpenMP/OMP.h:16:10: -fatal error: llvm/Frontend/OpenMP/OMP.h.inc: No such file or directory - -Use a forward declaration for OpenMPIRBuilder in ModuleTranslation.h to -avoid pulling OpenMP frontend header that require generated headers. ---- - mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h | 5 +++-- - 1 file changed, 3 insertions(+), 2 deletions(-) - -diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h -index 1b62437761ed..6f4a5e1d347a 100644 ---- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h -+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h -@@ -23,12 +23,13 @@ - #include "mlir/Target/LLVMIR/TypeToLLVM.h" - - #include "llvm/ADT/SetVector.h" --#include "llvm/Frontend/OpenMP/OMPIRBuilder.h" -+#include "llvm/IR/FPEnv.h" - - namespace llvm { - class BasicBlock; --class IRBuilderBase; - class Function; -+class IRBuilderBase; -+class OpenMPIRBuilder; - class Value; - } // namespace llvm - --- -2.46.0 - diff --git a/mlir-tensorrt/build_tools/patches/mlir/0011-MLIR-Fix-bufferization-interface-for-tensor-reshape.patch b/mlir-tensorrt/build_tools/patches/mlir/0011-MLIR-Fix-bufferization-interface-for-tensor-reshape.patch new file mode 100644 index 000000000..69dc4b8bf --- /dev/null +++ b/mlir-tensorrt/build_tools/patches/mlir/0011-MLIR-Fix-bufferization-interface-for-tensor-reshape.patch @@ -0,0 +1,71 @@ +From 3438dfc7ff8863bdd8c34e41d0cade5ca4581891 Mon Sep 17 00:00:00 2001 +From: Christopher Bate +Date: Wed, 12 Mar 2025 22:19:01 -0600 +Subject: [PATCH] [mlir][tensor] Fix bufferization interface for + 'tensor.reshape' (#128590) + +Previously, the BufferizableOpInterface implementation for +'tensor.reshape' +listed the 'shape' operand as an alias for the result tensor, causing +unnecessary conflicts with ops that "write" to the shape operand. +--- + .../BufferizableOpInterfaceImpl.cpp | 4 +++ + .../Dialect/Tensor/one-shot-bufferize.mlir | 27 +++++++++++++++++++ + 2 files changed, 31 insertions(+) + +diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +index a9ba662348a5..4ac6eca58696 100644 +--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp ++++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +@@ -860,6 +860,10 @@ struct ReshapeOpInterface + + AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { ++ // Only the 'source' operand aliases the result. ++ auto reshapeOp = cast(op); ++ if (reshapeOp.getSourceMutable() != opOperand) ++ return {}; + return {{op->getOpResult(0), BufferRelation::Equivalent}}; + } + +diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir +index af4f84640890..2983cd30258a 100644 +--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir ++++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir +@@ -398,6 +398,33 @@ func.func @tensor.reshape() -> tensor<2x2x5xf32> { + + // ----- + ++// CHECK-LABEL: func @tensor_reshape_aliasing ++// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index) ++func.func @tensor_reshape_aliasing(%arg0: index, %arg1: index) -> tensor { ++ %t1_static = arith.constant dense<0.> : tensor<10xf32> ++ // CHECK-DAG: %[[T1:.+]] = memref.cast ++ %t1 = tensor.cast %t1_static : tensor<10xf32> to tensor ++ ++ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index ++ %c0 = arith.constant 0 : index ++ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index ++ %c1 = arith.constant 1 : index ++ ++ // CHECK-DAG: %[[SHAPE:.+]] = memref.alloc() {{.*}} : memref<2xindex> ++ %shape = bufferization.alloc_tensor() : tensor<2xindex> ++ // CHECK: memref.store %[[ARG0]], %[[SHAPE]][%[[C0]]] ++ %shape.0 = tensor.insert %arg0 into %shape[%c0] : tensor<2xindex> ++ // CHECK: memref.store %[[ARG1]], %[[SHAPE]][%[[C1]]] ++ %shape.1 = tensor.insert %arg1 into %shape.0[%c1] : tensor<2xindex> ++ ++ // CHECK: %[[RESHAPED:.+]] = memref.reshape %[[T1]](%[[SHAPE]]) ++ %reshaped = tensor.reshape %t1(%shape.1) : (tensor, tensor<2xindex>) -> tensor ++ // CHECK: return %[[RESHAPED]] ++ return %reshaped : tensor ++} ++ ++// ----- ++ + // CHECK-LABEL: @reshape_with_non_identity_layout( + // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]*]]: memref<2x2xf32, strided<[?, ?], offset: ?>, 3>, + // CHECK-SAME: %[[LAYOUT:[a-zA-Z0-9]*]]: memref<2xi32, strided<[?], offset: ?>>, +-- +2.48.1 + diff --git a/mlir-tensorrt/build_tools/patches/stablehlo/0001-Fix-a-couple-missing-checks-for-static-shapes-in-sta.patch b/mlir-tensorrt/build_tools/patches/stablehlo/0001-Fix-a-couple-missing-checks-for-static-shapes-in-sta.patch index 822bd16c3..398c01489 100644 --- a/mlir-tensorrt/build_tools/patches/stablehlo/0001-Fix-a-couple-missing-checks-for-static-shapes-in-sta.patch +++ b/mlir-tensorrt/build_tools/patches/stablehlo/0001-Fix-a-couple-missing-checks-for-static-shapes-in-sta.patch @@ -1,8 +1,7 @@ -From b387be5903482200f4c36f64f8ed102c288c0c29 Mon Sep 17 00:00:00 2001 -From: Christopher Bate -Date: Wed, 27 Nov 2024 00:10:11 +0000 -Subject: [PATCH 1/7] Fix a couple missing checks for static shapes in - `stablehlo-aggressive-folder` +From 47b6e97db1681d3a7e94d4ceefc6cbe0fd62535e Mon Sep 17 00:00:00 2001 +From: Sagar Shelke +Date: Tue, 1 Jul 2025 21:46:35 +0000 +Subject: [PATCH] Apply patch --- .../stablehlo_aggressive_folder.mlir | 27 +++++++++++++------ @@ -10,12 +9,12 @@ Subject: [PATCH 1/7] Fix a couple missing checks for static shapes in 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir b/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir -index 5b21a10d..c90c89c6 100644 +index 2b778a9b..3137bc35 100644 --- a/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir +++ b/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir @@ -4,14 +4,17 @@ // AddOp - + // CHECK-LABEL: @add_fold_cst -func.func @add_fold_cst() -> (tensor, tensor) { +func.func @add_fold_cst() -> (tensor, tensor, tensor) { @@ -31,11 +30,11 @@ index 5b21a10d..c90c89c6 100644 + %2 = stablehlo.add %cst_2, %cst_2 : (tensor<1xf32>, tensor<1xf32>) -> tensor + return %0, %1, %2 : tensor, tensor, tensor } - + // ----- @@ -106,14 +109,17 @@ func.func @concatenate_fold() -> (tensor<6xi32>, tensor<3xi32>, tensor<3x3xi32>, // MulOp - + // CHECK-LABEL: @mul_fold_cst -func.func @mul_fold_cst() -> (tensor, tensor) { +func.func @mul_fold_cst() -> (tensor, tensor, tensor) { @@ -51,11 +50,11 @@ index 5b21a10d..c90c89c6 100644 + %2 = stablehlo.multiply %cst_2, %cst_2 : (tensor<1xf32>, tensor<1xf32>) -> tensor + return %0, %1, %2 : tensor, tensor, tensor } - + // ----- @@ -122,16 +128,21 @@ func.func @mul_fold_cst() -> (tensor, tensor) { // SubtractOp - + // CHECK-LABEL: @subtract_fold_cst -func.func @subtract_fold_cst() -> (tensor, tensor) { +func.func @subtract_fold_cst() -> (tensor, tensor, tensor) { @@ -77,42 +76,42 @@ index 5b21a10d..c90c89c6 100644 + %2 = stablehlo.subtract %cst_4, %cst_5 : (tensor<1xf32>, tensor<1xf32>) -> tensor + return %0, %1, %2 : tensor, tensor, tensor } - + // ----- diff --git a/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp b/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp -index 2b5198b4..52a28e97 100644 +index 7cd4724d..cddbffc6 100644 --- a/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp +++ b/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp -@@ -257,6 +257,9 @@ struct FoldAddOpPattern final : OpRewritePattern { - +@@ -266,6 +266,9 @@ struct FoldAddOpPattern final : OpRewritePattern { + LogicalResult matchAndRewrite(mlir::stablehlo::AddOp op, PatternRewriter& rewriter) const override { -+ if (failed(validateResultTypeForEval(rewriter, op, op.getType()))) ++ if (failed(validateStaticShapeResult(rewriter, op, op.getType()))) + return failure(); + Value lhs = op.getLhs(); Value rhs = op.getRhs(); - -@@ -549,6 +552,9 @@ struct FoldMulOpPattern final : OpRewritePattern { - + +@@ -569,6 +572,9 @@ struct FoldMulOpPattern final : OpRewritePattern { + LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op, PatternRewriter& rewriter) const override { -+ if (failed(validateResultTypeForEval(rewriter, op, op.getType()))) ++ if (failed(validateStaticShapeResult(rewriter, op, op.getType()))) + return failure(); + - auto elemType = op.getType().getElementType(); - Value lhs = op.getLhs(); - Value rhs = op.getRhs(); -@@ -748,6 +754,9 @@ struct FoldSubtractOpPattern final - + TypedAttr lhsAttr; + matchPattern(op.getLhs(), m_Constant(&lhsAttr)); + +@@ -758,6 +764,9 @@ struct FoldSubtractOpPattern final + LogicalResult matchAndRewrite(mlir::stablehlo::SubtractOp op, PatternRewriter& rewriter) const override { -+ if (failed(validateResultTypeForEval(rewriter, op, op.getType()))) ++ if (failed(validateStaticShapeResult(rewriter, op, op.getType()))) + return failure(); + Value lhs = op.getLhs(); Value rhs = op.getRhs(); - --- -2.46.0 + +-- +2.48.1 diff --git a/mlir-tensorrt/build_tools/patches/stablehlo/0002-cmake-Update-usage-of-HandleLLVMOptions-and-LLVM_DEF.patch b/mlir-tensorrt/build_tools/patches/stablehlo/0002-cmake-Update-usage-of-HandleLLVMOptions-and-LLVM_DEF.patch index 3f36cd4e1..d37c72cd1 100644 --- a/mlir-tensorrt/build_tools/patches/stablehlo/0002-cmake-Update-usage-of-HandleLLVMOptions-and-LLVM_DEF.patch +++ b/mlir-tensorrt/build_tools/patches/stablehlo/0002-cmake-Update-usage-of-HandleLLVMOptions-and-LLVM_DEF.patch @@ -1,37 +1,8 @@ -From 4949f0f91bf256e01f13dae1ddcd2139d2c41d85 Mon Sep 17 00:00:00 2001 -From: Christopher Bate -Date: Sat, 15 Feb 2025 22:02:17 +0000 -Subject: [PATCH 2/7] [cmake] Update usage of `HandleLLVMOptions` and - `LLVM_DEFINITIONS` +From 25883f24b5e027eb0a09c041a6ac6192d8817042 Mon Sep 17 00:00:00 2001 +From: Sagar Shelke +Date: Tue, 1 Jul 2025 21:49:41 +0000 +Subject: [PATCH] Apply patch -This change attempts to resolve issues with use of `HandleLLVMOptions` -and `LLVM_DEFINITIONS`, see -https://github.com/llvm/llvm-project/issues/125779. - -Note that this is a breaking change because it could cause build -breakage for downstream users. As noted in the comments added to the -CMakeLists.txt file, there may not be one perfect CMake incantation -for setting Stablehlo's options that works for all users. - -Since it's easier to *add* compiler options at a specific scope than it is -to alter/remove options that Stablehlo itself is setting, this change -is hoisting responsibility to the user for setting any compiler -options previously provided by the `HandleLLVMOptions` call when -building in embedded mode. - -This means that if user was using -`FetchContent|add_subdirectory|CPMAddPackage` to build Stablehlo -in their project, they should invoke - -``` -find_package(LLVM CONFIG REQUIRED) -separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS}) -add_definitions(${LLVM_DEFINITIONS_LIST}) -include(HandleLLVMOptions) -``` - -in their project at the appropriate scope, or set desired flags in some -other manner. --- CMakeLists.txt | 40 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 4 deletions(-) @@ -105,5 +76,5 @@ index bf5f0172..2a119e01 100644 #------------------------------------------------------------------------------- # Sanitizer configuration -- -2.46.0 +2.48.1 diff --git a/mlir-tensorrt/build_tools/patches/stablehlo/0003-Don-t-insert-unnecessary-arith.index_cast-ops.patch b/mlir-tensorrt/build_tools/patches/stablehlo/0003-Don-t-insert-unnecessary-arith.index_cast-ops.patch deleted file mode 100644 index 964cb9c8e..000000000 --- a/mlir-tensorrt/build_tools/patches/stablehlo/0003-Don-t-insert-unnecessary-arith.index_cast-ops.patch +++ /dev/null @@ -1,53 +0,0 @@ -From 1e5096183747c9d41eec0d624726d25454f10f9c Mon Sep 17 00:00:00 2001 -From: Christopher Bate -Date: Mon, 10 Mar 2025 22:04:05 +0000 -Subject: [PATCH 3/7] Don't insert unnecessary `arith.index_cast` ops - ---- - .../StablehloAggressiveSimplification.cpp | 20 ++----------------- - 1 file changed, 2 insertions(+), 18 deletions(-) - -diff --git a/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp b/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp -index f32f8d66..8028f714 100644 ---- a/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp -+++ b/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp -@@ -394,34 +394,18 @@ struct DynamicIotaOpToBroadcast : public OpRewritePattern { - - auto iotaDimension = static_cast(iota.getIotaDimension()); - -- // Handle case where iota dimension is index, need to convert to/from i64 -- // to interop with slice. These canonicalize away if input is i64. -- auto convertedShape = rewriter.create( -- iota.getLoc(), -- RankedTensorType::get( -- cast(iota.getOutputShape().getType()).getShape(), -- rewriter.getI64Type()), -- iota.getOutputShape()); -- -+ Value convertedShape = iota.getOutputShape(); - auto slicedShape = rewriter.create( - iota.getLoc(), convertedShape, - rewriter.getDenseI64ArrayAttr(iotaDimension), - rewriter.getDenseI64ArrayAttr(iotaDimension + 1), - rewriter.getDenseI64ArrayAttr(1)); - -- auto convertedSlicedShape = rewriter.create( -- iota.getLoc(), -- RankedTensorType::get( -- {1}, -- cast(iota.getOutputShape().getType()).getElementType()), -- slicedShape); -- - auto iotaType = RankedTensorType::get({resultTy.getDimSize(iotaDimension)}, - resultTy.getElementType()); - - auto newIota = rewriter.create( -- iota.getLoc(), iotaType, convertedSlicedShape, -- rewriter.getI64IntegerAttr(0)); -+ iota.getLoc(), iotaType, slicedShape, rewriter.getI64IntegerAttr(0)); - - rewriter.replaceOpWithNewOp( - iota, resultTy, newIota, iota.getOutputShape(), --- -2.46.0 - diff --git a/mlir-tensorrt/build_tools/patches/stablehlo/0004-Fix-ZeroExtent-condition-in-simplification-pattern.patch b/mlir-tensorrt/build_tools/patches/stablehlo/0004-Fix-ZeroExtent-condition-in-simplification-pattern.patch index b1aeb59e2..f28774c3a 100644 --- a/mlir-tensorrt/build_tools/patches/stablehlo/0004-Fix-ZeroExtent-condition-in-simplification-pattern.patch +++ b/mlir-tensorrt/build_tools/patches/stablehlo/0004-Fix-ZeroExtent-condition-in-simplification-pattern.patch @@ -1,18 +1,17 @@ -From 94b386bd28b610a3218508c391acd926412e57f1 Mon Sep 17 00:00:00 2001 -From: Christopher Bate -Date: Mon, 10 Mar 2025 22:51:38 +0000 -Subject: [PATCH 4/7] Fix ZeroExtent condition in simplification pattern +From c404b2e11d49a660c29a7b63b1843baa15be8af6 Mon Sep 17 00:00:00 2001 +From: Sagar Shelke +Date: Tue, 1 Jul 2025 21:53:46 +0000 +Subject: [PATCH] Apply patch -Attribute doesn't have to be a DenseElementsAttr. --- .../optimization/StablehloAggressiveSimplificationPatterns.td | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td b/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td -index 9cbcc07c..60396cc4 100644 +index efae71b2..992ac5de 100644 --- a/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td +++ b/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td -@@ -94,7 +94,8 @@ def SortedDims : AttrConstraint< +@@ -98,7 +98,8 @@ def SortedDims : AttrConstraint< "is sorted dimensions">; def ZeroExtent : AttrConstraint< @@ -23,5 +22,5 @@ index 9cbcc07c..60396cc4 100644 /////////// -- -2.46.0 +2.48.1 diff --git a/mlir-tensorrt/build_tools/patches/stablehlo/0005-Fix-crash-on-ComplexType-in-PointwiseToLinalgMapConv.patch b/mlir-tensorrt/build_tools/patches/stablehlo/0005-Fix-crash-on-ComplexType-in-PointwiseToLinalgMapConv.patch deleted file mode 100644 index 95ba8ba3e..000000000 --- a/mlir-tensorrt/build_tools/patches/stablehlo/0005-Fix-crash-on-ComplexType-in-PointwiseToLinalgMapConv.patch +++ /dev/null @@ -1,95 +0,0 @@ -From 139c779d447d6163c51dbe9d8735b2062025f032 Mon Sep 17 00:00:00 2001 -From: Christopher Bate -Date: Fri, 21 Mar 2025 03:28:26 +0000 -Subject: [PATCH 5/7] Fix crash on ComplexType in PointwiseToLinalgMapConverter - ---- - .../conversions/linalg/tests/pointwise.mlir | 23 ++++++++++++++ - .../transforms/StablehloToLinalgPointwise.cpp | 30 +++++++++++++++---- - 2 files changed, 48 insertions(+), 5 deletions(-) - -diff --git a/stablehlo/conversions/linalg/tests/pointwise.mlir b/stablehlo/conversions/linalg/tests/pointwise.mlir -index 6dc76f24..7a9f71aa 100644 ---- a/stablehlo/conversions/linalg/tests/pointwise.mlir -+++ b/stablehlo/conversions/linalg/tests/pointwise.mlir -@@ -23,6 +23,29 @@ func.func @float_add(%lhs: tensor<2x2xf32>, - - // ----- - -+// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> -+// CHECK-LABEL: func @complex_add_const -+// CHECK-PRIMITIVE-LABEL: func @complex_add_const -+func.func @complex_add_const(%lhs: tensor<2x2xcomplex>, -+ %rhs: tensor<2x2xcomplex>) -+ -> tensor<2x2xcomplex> { -+ -+ // CHECK: %[[CST:.+]] = complex.constant [1.000000e-01 : f32, 2.000000e-01 : f32] : complex -+ // CHECK: linalg.generic -+ // CHECK: ^bb0(%[[IN:.+]]: complex, %[[OUT:.+]]: complex) -+ // CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = complex.add %[[IN]], %[[CST]] -+ // CHECK: linalg.yield %[[RESULT]] -+ -+ // CHECK-PRIMITIVE: linalg.map -+ // CHECK-PRIMITIVE: complex.add -+ %cst = stablehlo.constant dense<(0.1, 0.2)> : tensor<2x2xcomplex> -+ %0 = "stablehlo.add"(%lhs, %cst) {someattr} -+ : (tensor<2x2xcomplex>, tensor<2x2xcomplex>) -> tensor<2x2xcomplex> -+ func.return %0 : tensor<2x2xcomplex> -+} -+ -+// ----- -+ - // CHECK-LABEL: func @float_add_dynamic_encoding - // CHECK-PRIMITIVE-LABEL: func @float_add_dynamic_encoding - func.func @float_add_dynamic_encoding( -diff --git a/stablehlo/conversions/linalg/transforms/StablehloToLinalgPointwise.cpp b/stablehlo/conversions/linalg/transforms/StablehloToLinalgPointwise.cpp -index 707db6a7..301dfdc2 100644 ---- a/stablehlo/conversions/linalg/transforms/StablehloToLinalgPointwise.cpp -+++ b/stablehlo/conversions/linalg/transforms/StablehloToLinalgPointwise.cpp -@@ -114,6 +114,28 @@ FailureOr checkOperandsAndResults( - return PointwiseConversionInfo{maxRank, resultTy}; - } - -+/// If `input` is a splat constant value, materialize the scalar splat -+/// value. Otherwise, return nullopt. -+std::optional materializeSplatScalarConstant(RewriterBase &rewriter, -+ Location loc, Value input) { -+ SplatElementsAttr attr; -+ Type elementType = mlir::getElementTypeOrSelf(input.getType()); -+ if (!matchPattern(input, m_Constant(&attr))) return {}; -+ if (isa(elementType)) { -+ return rewriter -+ .create(loc, elementType, -+ attr.getSplatValue()) -+ .getResult(); -+ } -+ if (isa(elementType)) { -+ return rewriter -+ .create(loc, elementType, -+ attr.getSplatValue()) -+ .getResult(); -+ } -+ return {}; -+} -+ - /// Converts a HLO operation to a linalg.map op that contains the corresponding - /// scalar operations. - template -@@ -160,11 +182,9 @@ struct PointwiseToLinalgMapConverter : OpConversionPattern { - SmallVector mappedInputs; - SmallVector scalarInputs; - for (Value input : adaptor.getOperands()) { -- DenseElementsAttr attr; -- if (matchPattern(input, m_Constant(&attr)) && attr.isSplat()) { -- scalarInputs.push_back(rewriter.create( -- loc, cast(input.getType()).getElementType(), -- attr.getSplatValue())); -+ if (std::optional splatVal = -+ materializeSplatScalarConstant(rewriter, loc, input)) { -+ scalarInputs.push_back(*splatVal); - } else if (getRank(input) == maxRank) { - mappedInputs.push_back(coerceTensorShape( - rewriter, loc, cast>(input), --- -2.46.0 - diff --git a/mlir-tensorrt/build_tools/patches/stablehlo/0006-Remove-explicit-use-of-LLVMSupport.patch b/mlir-tensorrt/build_tools/patches/stablehlo/0006-Remove-explicit-use-of-LLVMSupport.patch index 697b01338..1cbf27597 100644 --- a/mlir-tensorrt/build_tools/patches/stablehlo/0006-Remove-explicit-use-of-LLVMSupport.patch +++ b/mlir-tensorrt/build_tools/patches/stablehlo/0006-Remove-explicit-use-of-LLVMSupport.patch @@ -1,7 +1,7 @@ -From fafd462cdae170e3b615cb559e907b30840d7cf7 Mon Sep 17 00:00:00 2001 -From: Christopher Bate -Date: Fri, 21 Mar 2025 04:00:05 +0000 -Subject: [PATCH 6/7] Remove explicit use of LLVMSupport +From ac6ac8785d9a6429304d1ccd6f9c894373182dc0 Mon Sep 17 00:00:00 2001 +From: Sagar Shelke +Date: Tue, 1 Jul 2025 21:54:20 +0000 +Subject: [PATCH] Apply patch --- stablehlo/transforms/conversions/CMakeLists.txt | 1 - @@ -20,5 +20,5 @@ index e1da2c8b..7aac80f1 100644 MLIRSupport MLIRTransformUtils -- -2.46.0 +2.48.1 diff --git a/mlir-tensorrt/build_tools/patches/stablehlo/0007-Fix-circular-dependence-between-StablehloPasses-and-.patch b/mlir-tensorrt/build_tools/patches/stablehlo/0007-Fix-circular-dependence-between-StablehloPasses-and-.patch index 4f3978f75..f5667e314 100644 --- a/mlir-tensorrt/build_tools/patches/stablehlo/0007-Fix-circular-dependence-between-StablehloPasses-and-.patch +++ b/mlir-tensorrt/build_tools/patches/stablehlo/0007-Fix-circular-dependence-between-StablehloPasses-and-.patch @@ -1,17 +1,15 @@ -From e0b197588de8367b729b726009be028da3ed74a7 Mon Sep 17 00:00:00 2001 -From: Christopher Bate -Date: Sun, 23 Mar 2025 01:54:33 +0000 -Subject: [PATCH 7/7] Fix circular dependence between StablehloPasses and - StablehloOptimizationPasses +From f20e3614c528d44e0d877cf6cb29ffa90a5ceab8 Mon Sep 17 00:00:00 2001 +From: Sagar Shelke +Date: Tue, 1 Jul 2025 21:54:40 +0000 +Subject: [PATCH] Apply patch -Fixes build when BUILD_SHARED_LIBS=ON. --- stablehlo/transforms/CMakeLists.txt | 12 +++++++++++- stablehlo/transforms/optimization/CMakeLists.txt | 1 + 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/stablehlo/transforms/CMakeLists.txt b/stablehlo/transforms/CMakeLists.txt -index 4787369d..50c87304 100644 +index a7fada9b..3f986735 100644 --- a/stablehlo/transforms/CMakeLists.txt +++ b/stablehlo/transforms/CMakeLists.txt @@ -39,6 +39,16 @@ set(LLVM_TARGET_DEFINITIONS VhloToVersionPatterns.td) @@ -31,15 +29,15 @@ index 4787369d..50c87304 100644 add_mlir_dialect_library(StablehloPasses PARTIAL_SOURCES_INTENDED -@@ -59,7 +69,6 @@ add_mlir_dialect_library(StablehloPasses - StablehloRefineShapes.cpp +@@ -60,7 +70,6 @@ add_mlir_dialect_library(StablehloPasses + StablehloWrapInComposite.cpp VhloLegalizeToStablehlo.cpp VhloToVersion.cpp - PassUtils.cpp DEPENDS ChloDecompositionPatternsIncGen -@@ -90,6 +99,7 @@ add_mlir_dialect_library(StablehloPasses +@@ -91,6 +100,7 @@ add_mlir_dialect_library(StablehloPasses StablehloLinalgTransforms StablehloOps StablehloOptimizationPasses @@ -59,5 +57,5 @@ index d43d77be..d063a49d 100644 StablehloTypeInference ) -- -2.46.0 +2.48.1 diff --git a/mlir-tensorrt/build_tools/patches/torch_mlir/0001-cmake-Allow-finding-Stablehlo-via-find_package.patch b/mlir-tensorrt/build_tools/patches/torch_mlir/0001-cmake-Allow-finding-Stablehlo-via-find_package.patch deleted file mode 100644 index 7d5f6edaf..000000000 --- a/mlir-tensorrt/build_tools/patches/torch_mlir/0001-cmake-Allow-finding-Stablehlo-via-find_package.patch +++ /dev/null @@ -1,60 +0,0 @@ -From 1bbf94b5f9d1aa3e0ef2c85e3b1b93dedd11aca1 Mon Sep 17 00:00:00 2001 -From: Christopher Bate -Date: Fri, 14 Feb 2025 00:39:36 +0000 -Subject: [PATCH 1/3] [cmake] Allow finding Stablehlo via 'find_package' - ---- - CMakeLists.txt | 30 ++++++++++++++++++------------ - 1 file changed, 18 insertions(+), 12 deletions(-) - -diff --git a/CMakeLists.txt b/CMakeLists.txt -index c0f94046..2764cd07 100644 ---- a/CMakeLists.txt -+++ b/CMakeLists.txt -@@ -49,10 +49,11 @@ if(TORCH_MLIR_ENABLE_STABLEHLO) - endif() - # It is possible that both stablehlo and torch_mlir projects are used in some compiler project. - # In this case, we don't want to use stablehlo that is downloaded by torch_mlir (in external/stablehlo) --# folder but instead want to use stablehlo that is part of top level compiler project. --# With TORCH_MLIR_USE_EXTERNAL_STABLEHLO enables, it is assumed that top level compiler project makes --# stablehlo targets AND includes available (for example with `add_subdirectory` and `include_directories`). --option(TORCH_MLIR_USE_EXTERNAL_STABLEHLO "Use stablehlo from top level project" OFF) -+# folder but instead stablehlo that is part of top level compiler project. -+# TORCH_MLIR_EXTERNAL_STABLEHLO_DIR represents stablehlo directory (/stablehlo) -+# that is included in torch_mlir. It is assumed that top level compiler project makes -+# stablehlo targets available (for example with `add_subdirectory`) and thus they are not added. -+set(TORCH_MLIR_EXTERNAL_STABLEHLO_DIR "" CACHE STRING "Path to stablehlo dir from super project") - - option(TORCH_MLIR_ENABLE_TOSA "Add TOSA support" ON) - if(TORCH_MLIR_ENABLE_TOSA) -@@ -249,14 +250,19 @@ endif() - # Getting this wrong results in building large parts of the stablehlo - # project that we don't actually depend on. Further some of those parts - # do not even compile on all platforms. --# Only configure StableHLO if it isn't provided from a top-level project --if (TORCH_MLIR_ENABLE_STABLEHLO AND NOT TORCH_MLIR_USE_EXTERNAL_STABLEHLO) -- set(STABLEHLO_BUILD_EMBEDDED ON) -- set(STABLEHLO_ENABLE_BINDINGS_PYTHON ON) -- add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/stablehlo -- ${CMAKE_CURRENT_BINARY_DIR}/stablehlo -- EXCLUDE_FROM_ALL) -- include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/stablehlo) -+if (TORCH_MLIR_ENABLE_STABLEHLO) -+ if (TORCH_MLIR_EXTERNAL_STABLEHLO_DIR STREQUAL "find_package") -+ find_package(Stablehlo REQUIRED) -+ elseif (TORCH_MLIR_EXTERNAL_STABLEHLO_DIR) -+ include_directories(${TORCH_MLIR_EXTERNAL_STABLEHLO_DIR}) -+ else() -+ set(STABLEHLO_BUILD_EMBEDDED ON) -+ set(STABLEHLO_ENABLE_BINDINGS_PYTHON ON) -+ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/stablehlo -+ ${CMAKE_CURRENT_BINARY_DIR}/stablehlo -+ EXCLUDE_FROM_ALL) -+ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/stablehlo) -+ endif() - endif() - - #------------------------------------------------------------------------------- --- -2.46.0 - diff --git a/mlir-tensorrt/build_tools/patches/torch_mlir/0002-Make-compatible-with-more-recent-Stablehlo-version.patch b/mlir-tensorrt/build_tools/patches/torch_mlir/0002-Make-compatible-with-more-recent-Stablehlo-version.patch deleted file mode 100644 index 70b61cab0..000000000 --- a/mlir-tensorrt/build_tools/patches/torch_mlir/0002-Make-compatible-with-more-recent-Stablehlo-version.patch +++ /dev/null @@ -1,41 +0,0 @@ -From bcf9dd3472cb5b45a25843f3956fb92a2b38e9b3 Mon Sep 17 00:00:00 2001 -From: Christopher Bate -Date: Fri, 21 Mar 2025 16:40:57 +0000 -Subject: [PATCH 2/3] Make compatible with more recent Stablehlo version - ---- - lib/InitAll.cpp | 4 +++- - 1 file changed, 3 insertions(+), 1 deletion(-) - -diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp -index d9096929..89773e60 100644 ---- a/lib/InitAll.cpp -+++ b/lib/InitAll.cpp -@@ -20,6 +20,7 @@ - #include "mlir/Dialect/Tensor/IR/Tensor.h" - #include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" - #include "mlir/IR/Dialect.h" -+#include "stablehlo/transforms/optimization/Passes.h" - #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" - #include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h" - #include "torch-mlir/Conversion/Passes.h" -@@ -32,6 +33,7 @@ - - #ifdef TORCH_MLIR_ENABLE_STABLEHLO - #include "stablehlo/conversions/linalg/transforms/Passes.h" -+#include "stablehlo/transforms/optimization/Passes.h" - #include "stablehlo/transforms/Passes.h" - #endif - -@@ -72,7 +74,7 @@ void mlir::torch::registerAllPasses() { - - #ifdef TORCH_MLIR_ENABLE_STABLEHLO - mlir::stablehlo::registerStablehloLegalizeToLinalgPass(); -- mlir::stablehlo::registerStablehloAggressiveSimplificationPass(); -+ mlir::stablehlo::registerOptimizationPasses(); - mlir::stablehlo::registerStablehloRefineShapesPass(); - mlir::stablehlo::registerStablehloConvertToSignlessPass(); - mlir::stablehlo::registerShapeLegalizeToStablehloPass(); --- -2.46.0 - diff --git a/mlir-tensorrt/build_tools/patches/torch_mlir/0003-Fix-some-configuration-paths-in-LIT-cfg.patch b/mlir-tensorrt/build_tools/patches/torch_mlir/0003-Fix-some-configuration-paths-in-LIT-cfg.patch deleted file mode 100644 index b69b7567e..000000000 --- a/mlir-tensorrt/build_tools/patches/torch_mlir/0003-Fix-some-configuration-paths-in-LIT-cfg.patch +++ /dev/null @@ -1,58 +0,0 @@ -From 7f7eff24f303429a8258af53e72f08707ce9de55 Mon Sep 17 00:00:00 2001 -From: Christopher Bate -Date: Fri, 21 Mar 2025 16:41:25 +0000 -Subject: [PATCH 3/3] Fix some configuration paths in LIT cfg - ---- - test/CMakeLists.txt | 8 ++++++++ - test/lit.cfg.py | 2 +- - test/lit.site.cfg.py.in | 1 + - 3 files changed, 10 insertions(+), 1 deletion(-) - -diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt -index dbfa86aa..c84fec61 100644 ---- a/test/CMakeLists.txt -+++ b/test/CMakeLists.txt -@@ -4,6 +4,14 @@ llvm_canonicalize_cmake_booleans( - TORCH_MLIR_ENABLE_STABLEHLO - ) - -+# Set the tools directory variable. -+get_target_property(TORCH_MLIR_BIN_DIR torch-mlir-opt RUNTIME_OUTPUT_DIRECTORY) -+ -+# If the property wasn't set, fall back to default or define your own -+if(NOT TORCH_MLIR_BIN_DIR OR TORCH_MLIR_BIN_DIR STREQUAL "TORCH_MLIR_BIN_DIR-NOTFOUND") -+ set(TORCH_MLIR_BIN_DIR "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}") -+endif() -+ - configure_lit_site_cfg( - ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in - ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py -diff --git a/test/lit.cfg.py b/test/lit.cfg.py -index 4cdd029e..660cb730 100644 ---- a/test/lit.cfg.py -+++ b/test/lit.cfg.py -@@ -57,7 +57,7 @@ config.test_source_root = os.path.dirname(__file__) - - # test_exec_root: The root path where tests should be run. - config.test_exec_root = os.path.join(config.torch_mlir_obj_root, "test") --config.standalone_tools_dir = os.path.join(config.torch_mlir_obj_root, "bin") -+config.standalone_tools_dir = config.torch_mlir_bin_dir - - # Tweak the PATH to include the tools dir. - llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True) -diff --git a/test/lit.site.cfg.py.in b/test/lit.site.cfg.py.in -index 7ace00cb..5ceda4fe 100644 ---- a/test/lit.site.cfg.py.in -+++ b/test/lit.site.cfg.py.in -@@ -4,6 +4,7 @@ import sys - - config.enable_bindings_python = @MLIR_ENABLE_BINDINGS_PYTHON@ - config.torch_mlir_obj_root = "@TORCH_MLIR_BINARY_DIR@" -+config.torch_mlir_bin_dir = "@TORCH_MLIR_BIN_DIR@" - config.torch_mlir_python_packages_dir = "@TORCH_MLIR_PYTHON_PACKAGES_DIR@" - config.torch_mlir_enable_refbackend = @TORCH_MLIR_ENABLE_REFBACKEND@ - config.host_os = "@HOST_OS@" --- -2.46.0 - diff --git a/mlir-tensorrt/common/CMakeLists.txt b/mlir-tensorrt/common/CMakeLists.txt new file mode 100644 index 000000000..e717f4c46 --- /dev/null +++ b/mlir-tensorrt/common/CMakeLists.txt @@ -0,0 +1,38 @@ +cmake_minimum_required(VERSION 3.25) +project(mlir-tensorrt-common LANGUAGES CXX) + +# Depdendencies +find_package(LLVM REQUIRED CONFIG) +find_package(MLIR REQUIRED CONFIG) +include(HandleLLVMOptions) +include_directories(${LLVM_INCLUDE_DIRS}) +include_directories(${MLIR_INCLUDE_DIRS}) + +if(MLIR_TRT_TARGET_TENSORRT) + find_package(TensorRT REQUIRED) +endif() + +find_package(CUDAToolkit REQUIRED) + +set(MLIR_TENSORRT_COMMON_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}) +set(MLIR_TENSORRT_COMMON_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}) + +include_directories(include ${CMAKE_CURRENT_BINARY_DIR}/include) + +add_library(MLIRTensorRTCommonIncludes INTERFACE) +target_include_directories(MLIRTensorRTCommonIncludes INTERFACE + "$" + "$" +) + +add_subdirectory(include/mlir-tensorrt-common) +add_subdirectory(lib) + +install(TARGETS MLIRTensorRTCommonIncludes + EXPORT MLIRTensorRTCommonTargets + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} +) + diff --git a/mlir-tensorrt/common/include/mlir-tensorrt-common/CMakeLists.txt b/mlir-tensorrt/common/include/mlir-tensorrt-common/CMakeLists.txt new file mode 100644 index 000000000..e69de29bb diff --git a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRegistration.h b/mlir-tensorrt/common/include/mlir-tensorrt-common/Conversion/Passes.h similarity index 51% rename from mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRegistration.h rename to mlir-tensorrt/common/include/mlir-tensorrt-common/Conversion/Passes.h index e0e101e84..0c2a6dcaf 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRegistration.h +++ b/mlir-tensorrt/common/include/mlir-tensorrt-common/Conversion/Passes.h @@ -1,6 +1,6 @@ -//===- LuaRegistration.h ----------------------------------------*- C++ -*-===// +//===- Passes.h -------------------------------------------------*- C++ -*-===// // -// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES. +// SPDX-FileCopyrightText: Copyright 2025 NVIDIA CORPORATION & AFFILIATES. // All rights reserved. // SPDX-License-Identifier: Apache-2.0 // @@ -18,22 +18,22 @@ // //===----------------------------------------------------------------------===// /// -/// Registration for the Lua runtime methods. +/// This file contains the declarations for the common conversion passes. /// //===----------------------------------------------------------------------===// +#ifndef MLIR_TENSORRT_COMMON_CONVERSION_PASSES +#define MLIR_TENSORRT_COMMON_CONVERSION_PASSES -#include "mlir-executor/Runtime/API/API.h" +#include "mlir/Pass/Pass.h" +#include -struct lua_State; - -namespace mlirtrt::runtime { -/// Register various external functions with the given Lua state using a -/// directly specified device number, total device count, and a pre-determined -/// NCCL uuid. -void registerLuaRuntimeMethods(lua_State *state, - const RuntimeSessionOptions &options, - PinnedMemoryAllocator *pinnedMemoryAllocator, - AllocTracker *allocTracker, - ResourceTracker *resourceTracker); +//===----------------------------------------------------------------------===// +// Add Tablegen'd pass declarations and registration methods. +//===----------------------------------------------------------------------===// +namespace mlir { +#define GEN_PASS_DECL +#define GEN_PASS_REGISTRATION +#include "mlir-tensorrt-common/Conversion/Passes.h.inc" +} // namespace mlir -} // namespace mlirtrt::runtime +#endif // MLIR_TENSORRT_COMMON_CONVERSION_PASSES diff --git a/mlir-tensorrt/common/include/mlir-tensorrt-common/Conversion/Passes.td b/mlir-tensorrt/common/include/mlir-tensorrt-common/Conversion/Passes.td new file mode 100644 index 000000000..2fcc65a92 --- /dev/null +++ b/mlir-tensorrt/common/include/mlir-tensorrt-common/Conversion/Passes.td @@ -0,0 +1,18 @@ +#ifndef MLIR_TENSORRT_COMMON_CONVERSION_PASSES +#define MLIR_TENSORRT_COMMON_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertToLoops : Pass<"convert-to-loops"> { + let summary = "Convert a LoopLikeOpInterface to loops"; + let description = [{ + This pass converts a LoopLikeOpInterface to loops. + }]; + + let dependentDialects = [ + "::mlir::tensor::TensorDialect", + "::mlir::scf::SCFDialect", + ]; +} + +#endif // MLIR_TENSORRT_COMMON_CONVERSION_PASSES diff --git a/mlir-tensorrt/common/include/mlir-tensorrt-common/Dialect/EmitCExt/IR/DataLayoutImpl.h b/mlir-tensorrt/common/include/mlir-tensorrt-common/Dialect/EmitCExt/IR/DataLayoutImpl.h new file mode 100644 index 000000000..bc1930298 --- /dev/null +++ b/mlir-tensorrt/common/include/mlir-tensorrt-common/Dialect/EmitCExt/IR/DataLayoutImpl.h @@ -0,0 +1,36 @@ +//===- DataLayoutImpl.h -----------------------------------------*- C++ -*-===// +// +// SPDX-FileCopyrightText: Copyright 2025 NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// This file contains the declarations for the DataLayout extensions to +/// the EmitC dialect. +/// TODO: These interfaces should be upstreamed to the EmitC dialect so that +/// external models are not required. +/// +//===----------------------------------------------------------------------===// +#ifndef MLIR_TENSORRT_COMMON_DIALECT_EMITCEXT_IR_DATALAYOUTIMPL_H +#define MLIR_TENSORRT_COMMON_DIALECT_EMITCEXT_IR_DATALAYOUTIMPL_H + +#include "mlir/IR/DialectRegistry.h" + +namespace mlir::emitc_ext { +void registerDataLayoutInterfaceExternalModels(DialectRegistry ®istry); +} + +#endif // MLIR_TENSORRT_COMMON_DIALECT_EMITCEXT_IR_DATALAYOUTIMPL_H \ No newline at end of file diff --git a/mlir-tensorrt/common/include/mlir-tensorrt-common/Dialect/LinalgExt/Transforms/ToLoopsOpInterfaceImpl.h b/mlir-tensorrt/common/include/mlir-tensorrt-common/Dialect/LinalgExt/Transforms/ToLoopsOpInterfaceImpl.h new file mode 100644 index 000000000..1d4864e2a --- /dev/null +++ b/mlir-tensorrt/common/include/mlir-tensorrt-common/Dialect/LinalgExt/Transforms/ToLoopsOpInterfaceImpl.h @@ -0,0 +1,53 @@ +//===- ToLoopsOpInterfaceImpl.h ---------------------------------*- C++ -*-===// +// +// SPDX-FileCopyrightText: Copyright 2025 NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// This file contains the declarations for the ToLoopsOpInterface extensions to +/// the Linalg dialect. +/// +//===----------------------------------------------------------------------===// +#ifndef MLIR_TENSORRT_COMMON_DIALECT_LINALGEXT_TRANSFORMS_TOLOOPSOPINTERFACEIMPL +#define MLIR_TENSORRT_COMMON_DIALECT_LINALGEXT_TRANSFORMS_TOLOOPSOPINTERFACEIMPL + +#include "mlir-tensorrt-common/Interfaces/ToLoopsOpInterface.h" +#include "mlir/IR/DialectRegistry.h" + +namespace mlir::linalg { +class LinalgOp; +} + +namespace mlir::scf { +class ForOp; +} + +namespace mlir::linalg_ext { + +/// Register the ToLoopsOpInterface external models for GenericOp. For other +/// kinds of operations that are LinalgOps, we don't register an external model +/// because there are so many; instead use the below function to perform +/// conversion. +void registerToLoopsOpInterfaceExternalModels(DialectRegistry ®istry); + +/// Convert a LinalgOp (on tensors) to SCF loops. +FailureOr> +convertLinalgOpToLoops(RewriterBase &rewriter, linalg::LinalgOp op); + +} // namespace mlir::linalg_ext + +#endif // MLIR_TENSORRT_COMMON_DIALECT_LINALGEXT_TRANSFORMS_TOLOOPSOPINTERFACEIMPL diff --git a/mlir-tensorrt/common/include/mlir-tensorrt-common/Interfaces/ToLoopsOpInterface.h b/mlir-tensorrt/common/include/mlir-tensorrt-common/Interfaces/ToLoopsOpInterface.h new file mode 100644 index 000000000..741bf2132 --- /dev/null +++ b/mlir-tensorrt/common/include/mlir-tensorrt-common/Interfaces/ToLoopsOpInterface.h @@ -0,0 +1,38 @@ +//===- ToLoopsOpInterface.h -------------------------------*- C++ -*-===// +// +// SPDX-FileCopyrightText: Copyright 2025 NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// This file contains the declarations for the `ToLoopsOpInterface` interface. +/// +//===----------------------------------------------------------------------===// +#ifndef MLIR_TENSORRT_COMMON_INTERFACES_TOLOOPSOPINTERFACE +#define MLIR_TENSORRT_COMMON_INTERFACES_TOLOOPSOPINTERFACE + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +class RewriterBase; +namespace scf { +class ForOp; +} // namespace scf +} // namespace mlir + +#include "mlir-tensorrt-common/Interfaces/ToLoopsOpInterface.h.inc" + +#endif // MLIR_TENSORRT_COMMON_INTERFACES_TOLOOPSOPINTERFACE diff --git a/mlir-tensorrt/common/include/mlir-tensorrt-common/Interfaces/ToLoopsOpInterface.td b/mlir-tensorrt/common/include/mlir-tensorrt-common/Interfaces/ToLoopsOpInterface.td new file mode 100644 index 000000000..3d6f0eb11 --- /dev/null +++ b/mlir-tensorrt/common/include/mlir-tensorrt-common/Interfaces/ToLoopsOpInterface.td @@ -0,0 +1,28 @@ +#ifndef MLIR_TENSORRT_COMMON_INTERFACES_TOLOOPSOPINTERFACE +#define MLIR_TENSORRT_COMMON_INTERFACES_TOLOOPSOPINTERFACE + +include "mlir/IR/OpBase.td" + +def ToLoopsOpInterface : OpInterface<"ToLoopsOpInterface"> { + let description = "Interface for lowering to loops"; + + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod<[{ + Lower the operation to a loop nest. Returns + the outermost loop that should replace the original + op, but does not actually perform the replacement. + }], + "::mlir::FailureOr<::mlir::Operation*>", + "lowerToLoops", + (ins "::mlir::RewriterBase&":$rewriter), + "", + [{ + llvm_unreachable("Not implemented"); + }] + > + ]; +} + +#endif // MLIR_TENSORRT_COMMON_INTERFACES_TOLOOPSOPINTERFACE diff --git a/mlir-tensorrt/common/lib/CMakeLists.txt b/mlir-tensorrt/common/lib/CMakeLists.txt new file mode 100644 index 000000000..a714155fa --- /dev/null +++ b/mlir-tensorrt/common/lib/CMakeLists.txt @@ -0,0 +1,5 @@ +add_subdirectory(Conversion) +add_subdirectory(Dialect) +add_subdirectory(Interfaces) +add_subdirectory(Utils) + diff --git a/mlir-tensorrt/common/lib/Conversion/CMakeLists.txt b/mlir-tensorrt/common/lib/Conversion/CMakeLists.txt new file mode 100644 index 000000000..943496624 --- /dev/null +++ b/mlir-tensorrt/common/lib/Conversion/CMakeLists.txt @@ -0,0 +1,11 @@ +set(LLVM_TARGET_DEFINITIONS + "${MLIR_TENSORRT_COMMON_SOURCE_DIR}/include/mlir-tensorrt-common/Conversion/Passes.td") +set(OUTPUT_DIR + "${MLIR_TENSORRT_COMMON_BINARY_DIR}/include/mlir-tensorrt-common/Conversion") + +cmake_path(SET OUTPUT_DIR NORMALIZE "${OUTPUT_DIR}") +cmake_path(RELATIVE_PATH OUTPUT_DIR BASE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}") +mlir_tablegen("${OUTPUT_DIR}/Passes.h.inc" -gen-pass-decls -name MLIRTensorRTCommonConversion) +add_public_tablegen_target(MLIRTensorRTCommonConversionPassesIncGen) + +add_subdirectory(ToLoops) diff --git a/mlir-tensorrt/common/lib/Conversion/ToLoops/CMakeLists.txt b/mlir-tensorrt/common/lib/Conversion/ToLoops/CMakeLists.txt new file mode 100644 index 000000000..9416ef2d7 --- /dev/null +++ b/mlir-tensorrt/common/lib/Conversion/ToLoops/CMakeLists.txt @@ -0,0 +1,15 @@ +add_mlir_tensorrt_library(MLIRTensorRTCommonConvertToLoops + ConvertToLoops.cpp + + DEPENDS + MLIRTensorRTCommonConversionPassesIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRLinalgDialect + MLIRSCFDialect + MLIRTensorDialect + MLIRTensorRTCommonIncludes + MLIRTensorRTCommonLinalgExtTransforms + MLIRTensorRTCommonToLoopsOpInterface + ) \ No newline at end of file diff --git a/mlir-tensorrt/common/lib/Conversion/ToLoops/ConvertToLoops.cpp b/mlir-tensorrt/common/lib/Conversion/ToLoops/ConvertToLoops.cpp new file mode 100644 index 000000000..e8ccadfed --- /dev/null +++ b/mlir-tensorrt/common/lib/Conversion/ToLoops/ConvertToLoops.cpp @@ -0,0 +1,75 @@ +//===- ConvertToLoops.cpp ------------------------------------------------===// +// +// SPDX-FileCopyrightText: Copyright 2025 NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// This file contains the implementation of the ConvertToLoops pass. +/// +//===----------------------------------------------------------------------===// +#include "mlir-tensorrt-common/Conversion/Passes.h" +#include "mlir-tensorrt-common/Dialect/LinalgExt/Transforms/ToLoopsOpInterfaceImpl.h" +#include "mlir-tensorrt-common/Interfaces/ToLoopsOpInterface.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTTOLOOPS +#include "mlir-tensorrt-common/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { +struct ConvertToLoopsPattern + : public OpInterfaceRewritePattern { + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; + LogicalResult matchAndRewrite(ToLoopsOpInterface op, + PatternRewriter &rewriter) const override { + if (failed(op.lowerToLoops(rewriter))) + return failure(); + return success(); + } +}; + +struct ConvertLinalgOpToLoopsPattern + : public OpInterfaceRewritePattern { + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; + LogicalResult matchAndRewrite(linalg::LinalgOp op, + PatternRewriter &rewriter) const override { + FailureOr> loops = + linalg_ext::convertLinalgOpToLoops(rewriter, op); + if (failed(loops)) + return failure(); + rewriter.replaceOp(op, loops->front()->getResults()); + return success(); + } +}; + +struct ConvertToLoops : public mlir::impl::ConvertToLoopsBase { + void runOnOperation() override { + Operation *op = getOperation(); + RewritePatternSet patterns(op->getContext()); + patterns.add( + op->getContext()); + + walkAndApplyPatterns(op, std::move(patterns)); + } +}; +} // namespace diff --git a/mlir-tensorrt/common/lib/Dialect/CMakeLists.txt b/mlir-tensorrt/common/lib/Dialect/CMakeLists.txt new file mode 100644 index 000000000..49a7e31fa --- /dev/null +++ b/mlir-tensorrt/common/lib/Dialect/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(EmitCExt) +add_subdirectory(LinalgExt) diff --git a/mlir-tensorrt/common/lib/Dialect/EmitCExt/CMakeLists.txt b/mlir-tensorrt/common/lib/Dialect/EmitCExt/CMakeLists.txt new file mode 100644 index 000000000..5df2b41ca --- /dev/null +++ b/mlir-tensorrt/common/lib/Dialect/EmitCExt/CMakeLists.txt @@ -0,0 +1,9 @@ +add_mlir_tensorrt_library( + MLIREmitCExtDataLayoutImpl + DataLayoutImpl.cpp + + LINK_LIBS PUBLIC + MLIREmitCDialect + MLIRDLTIDialect + MLIRTensorRTCommonIncludes + ) diff --git a/mlir-tensorrt/common/lib/Dialect/EmitCExt/DataLayoutImpl.cpp b/mlir-tensorrt/common/lib/Dialect/EmitCExt/DataLayoutImpl.cpp new file mode 100644 index 000000000..1a833d728 --- /dev/null +++ b/mlir-tensorrt/common/lib/Dialect/EmitCExt/DataLayoutImpl.cpp @@ -0,0 +1,65 @@ +//===- DataLayoutImpl.cpp -------------------------------------------------===// +// +// SPDX-FileCopyrightText: Copyright 2025 NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// This file contains the implementation for the DataLayout extensions to +/// the EmitC dialect. +/// +//===----------------------------------------------------------------------===// +#include "mlir-tensorrt-common/Dialect/EmitCExt/IR/DataLayoutImpl.h" +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" + +using namespace mlir; +using namespace mlir::emitc; + +namespace { +/// Add DataLayoutTypeInterface to the `!emitc.size_t` type. We map all queries +/// to the corresponding property of the built-in `index` type. +struct SizeTDataLayoutInterface + : public DataLayoutTypeInterface::ExternalModel { + llvm::TypeSize getTypeSize(Type type, const DataLayout &dataLayout, + DataLayoutEntryListRef params) const { + return dataLayout.getTypeSize(IndexType::get(type.getContext())); + } + llvm::TypeSize getTypeSizeInBits(Type type, const DataLayout &dataLayout, + DataLayoutEntryListRef params) const { + return dataLayout.getTypeSizeInBits(IndexType::get(type.getContext())); + } + uint64_t getABIAlignment(Type type, const DataLayout &dataLayout, + DataLayoutEntryListRef params) const { + return dataLayout.getTypeABIAlignment(IndexType::get(type.getContext())); + } + uint64_t getPreferredAlignment(Type type, const DataLayout &dataLayout, + DataLayoutEntryListRef params) const { + return dataLayout.getTypePreferredAlignment( + IndexType::get(type.getContext())); + } +}; +} // namespace + +namespace mlir::emitc_ext { +void registerDataLayoutInterfaceExternalModels(DialectRegistry ®istry) { + registry.addExtension( + +[](MLIRContext *ctx, emitc::EmitCDialect *emitcdialect) { + emitc::SizeTType::attachInterface(*ctx); + }); +} +} // namespace mlir::emitc_ext \ No newline at end of file diff --git a/mlir-tensorrt/common/lib/Dialect/LinalgExt/CMakeLists.txt b/mlir-tensorrt/common/lib/Dialect/LinalgExt/CMakeLists.txt new file mode 100644 index 000000000..e31af3266 --- /dev/null +++ b/mlir-tensorrt/common/lib/Dialect/LinalgExt/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Transforms) diff --git a/mlir-tensorrt/common/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt b/mlir-tensorrt/common/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt new file mode 100644 index 000000000..99a0f4745 --- /dev/null +++ b/mlir-tensorrt/common/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_tensorrt_library(MLIRTensorRTCommonLinalgExtTransforms + ToLoopsOpInterfaceImpl.cpp + + LINK_LIBS PUBLIC + MLIRAffineDialect + MLIRLinalgDialect + MLIRLinalgTransforms + MLIRSCFDialect + MLIRTensorDialect + MLIRTensorRTCommonIncludes + MLIRTensorRTCommonToLoopsOpInterface + ) diff --git a/mlir-tensorrt/common/lib/Dialect/LinalgExt/Transforms/ToLoopsOpInterfaceImpl.cpp b/mlir-tensorrt/common/lib/Dialect/LinalgExt/Transforms/ToLoopsOpInterfaceImpl.cpp new file mode 100644 index 000000000..c2da43775 --- /dev/null +++ b/mlir-tensorrt/common/lib/Dialect/LinalgExt/Transforms/ToLoopsOpInterfaceImpl.cpp @@ -0,0 +1,194 @@ +//===- ToLoopsOpInterfaceImpl.cpp ----------------------------------------===// +// +// Modified from original LLVM/MLIR code under Linalg dialect transforms. +// Original license: +// "Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception" +// +// Modifiecations: +// - Changed code so that it expects linalg operations to operate on tensor +// instead of buffer types. +// +// Modifications Copyright 2025 NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// This file contains the implementation of the ToLoopsOpInterface extensions +/// to the Linalg dialect. +/// +//===----------------------------------------------------------------------===// +#include "mlir-tensorrt-common/Dialect/LinalgExt/Transforms/ToLoopsOpInterfaceImpl.h" +#include "mlir-tensorrt-common/Interfaces/ToLoopsOpInterface.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/DialectRegistry.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::linalg; +using namespace mlir::linalg_ext; + +/// Make canonical affine applies for the given map and values. +static SmallVector +makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map, + ArrayRef vals) { + if (map.isEmpty()) + return {}; + assert(map.getNumInputs() == vals.size()); + SmallVector res; + res.reserve(map.getNumResults()); + unsigned numDoms = map.getNumDims(); + for (AffineExpr e : map.getResults()) { + auto exprMap = AffineMap::get(numDoms, map.getNumSymbols(), e); + res.push_back(affine::makeComposedFoldedAffineApply(b, loc, exprMap, vals)); + } + return getValueOrCreateConstantIndexOp(b, loc, res); +} + +/// Inline the region of the LinalgOp and emit the "store" (tensor.insert) +/// operations. +static SmallVector +inlineRegionAndEmitStore(OpBuilder &b, Location loc, linalg::LinalgOp op, + ValueRange allIvs, ArrayRef indexedValues, + ArrayRef> indexing, + ValueRange outputBuffers) { + Block &block = *op.getBlock(); + IRMapping map; + map.map(block.getArguments(), indexedValues); + for (auto &op : block.without_terminator()) { + if (auto indexOp = dyn_cast(&op)) { + map.map(op.getResult(0), allIvs[indexOp.getDim()]); + continue; + } + auto *newOp = b.clone(op, map); + map.map(op.getResults(), newOp->getResults()); + } + + linalg::YieldOp terminator = cast(block.getTerminator()); + SmallVector storeValues; + for (OpOperand &operand : terminator->getOpOperands()) { + Value toStore = map.lookupOrDefault(operand.get()); + storeValues.push_back(b.create( + loc, toStore, outputBuffers[operand.getOperandNumber()], + indexing[operand.getOperandNumber()])); + } + return storeValues; +} + +/// Emit the scalar implementation for the LinalgOp operation. +static scf::ValueVector +emitScalarImplementation(OpBuilder &b, Location loc, ValueRange allIvs, + linalg::LinalgOp linalgOp, + ValueRange operandValuesToUse) { + assert(linalgOp.hasPureTensorSemantics() && + "expected linalg op with tensor semantics"); + SmallVector indexedValues; + indexedValues.reserve(linalgOp->getNumOperands()); + + SmallVector allIvsPlusDims(allIvs); + + for (auto [inputOperand, operandValue] : + llvm::zip(linalgOp.getDpsInputOperands(), + operandValuesToUse.take_front(linalgOp.getNumDpsInputs()))) { + if (linalgOp.isScalar(inputOperand)) { + indexedValues.push_back(operandValue); + continue; + } + SmallVector indexing = makeCanonicalAffineApplies( + b, loc, linalgOp.getMatchingIndexingMap(inputOperand), allIvsPlusDims); + indexedValues.push_back( + b.create(loc, operandValue, indexing)); + } + + SmallVector, 2> outputIndexing; + for (auto [outputOperand, outputValue] : + llvm::zip(linalgOp.getDpsInitsMutable(), + operandValuesToUse.take_back(linalgOp.getNumDpsInits()))) { + SmallVector indexing = makeCanonicalAffineApplies( + b, loc, linalgOp.getMatchingIndexingMap(&outputOperand), + allIvsPlusDims); + indexedValues.push_back( + b.create(loc, outputValue, indexing)); + outputIndexing.push_back(indexing); + } + + return inlineRegionAndEmitStore( + b, loc, linalgOp, allIvs, indexedValues, + ArrayRef(outputIndexing).take_back(linalgOp.getNumDpsInits()), + operandValuesToUse.take_back(linalgOp.getNumDpsInits())); +} + +/// Lower the LinalgOp to a 'scf.for' loop nest. +FailureOr> +mlir::linalg_ext::convertLinalgOpToLoops(RewriterBase &rewriter, + linalg::LinalgOp linalgOp) { + // The flattened loopToOperandRangesMaps is expected to be an invertible + // permutation map (which is asserted in the inverse calculation). + if (!linalgOp.hasPureTensorSemantics()) + return emitError(linalgOp.getLoc()) + << "expected linalg op with tensor semantics"; + + SmallVector loopRanges = + linalgOp.createLoopRanges(rewriter, linalgOp.getLoc()); + SmallVector iteratorTypes = + linalgOp.getIteratorTypesArray(); + + // Generate the loop nest using the 'mlir::linalg::GenerateLoopNest' utility. + SmallVector loops; + mlir::linalg::GenerateLoopNest::doit( + rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange ivs, + ValueRange operandValuesToUse) -> scf::ValueVector { + for (Value v : ivs) { + BlockArgument ivVal = cast(v); + loops.push_back(ivVal.getOwner()->getParentOp()); + } + return emitScalarImplementation(b, loc, ivs, linalgOp, + operandValuesToUse); + }); + + return loops; +} + +namespace { +template +struct ToLoopsOpInterfaceImpl + : public ToLoopsOpInterface::ExternalModel, + OpTy> { + FailureOr lowerToLoops(Operation *op, + RewriterBase &rewriter) const { + FailureOr> loops = + convertLinalgOpToLoops(rewriter, cast(op)); + if (failed(loops)) + return failure(); + rewriter.replaceOp(op, loops->front()); + return loops->front(); + } +}; +} // namespace + +void linalg_ext::registerToLoopsOpInterfaceExternalModels( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) { + linalg::GenericOp::attachInterface>(*ctx); + // linalg::MapOp::attachInterface>(*ctx); + }); +} diff --git a/mlir-tensorrt/common/lib/Interfaces/CMakeLists.txt b/mlir-tensorrt/common/lib/Interfaces/CMakeLists.txt new file mode 100644 index 000000000..eaf86b6c2 --- /dev/null +++ b/mlir-tensorrt/common/lib/Interfaces/CMakeLists.txt @@ -0,0 +1,35 @@ + +function(mlir_tensorrt_common_op_interface interface) + set(LLVM_TARGET_DEFINITIONS + "${MLIR_TENSORRT_COMMON_SOURCE_DIR}/include/mlir-tensorrt-common/Interfaces/${interface}.td") + set(OUTPUT_DIR + "${MLIR_TENSORRT_COMMON_BINARY_DIR}/include/mlir-tensorrt-common/Interfaces") + + cmake_path(SET OUTPUT_DIR NORMALIZE "${OUTPUT_DIR}") + cmake_path(RELATIVE_PATH OUTPUT_DIR BASE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}") + mlir_tablegen("${OUTPUT_DIR}/${interface}.h.inc" -gen-op-interface-decls) + mlir_tablegen("${OUTPUT_DIR}/${interface}.cpp.inc" -gen-op-interface-defs) + add_public_tablegen_target(MLIRTensorRTCommon${interface}IncGen) +endfunction() + +function(add_mlir_tensorrt_common_interface_library target) + add_mlir_tensorrt_library("${target}" + PARTIAL_SOURCES_INTENDED + ${ARGN} + ) +endfunction() + +mlir_tensorrt_common_op_interface(ToLoopsOpInterface) + +add_mlir_tensorrt_common_interface_library( + MLIRTensorRTCommonToLoopsOpInterface + + ToLoopsOpInterface.cpp + + DEPENDS + MLIRTensorRTCommonToLoopsOpInterfaceIncGen + + LINK_LIBS PUBLIC + MLIRTensorRTCommonIncludes + MLIRIR + ) \ No newline at end of file diff --git a/mlir-tensorrt/common/lib/Interfaces/ToLoopsOpInterface.cpp b/mlir-tensorrt/common/lib/Interfaces/ToLoopsOpInterface.cpp new file mode 100644 index 000000000..cd72884ed --- /dev/null +++ b/mlir-tensorrt/common/lib/Interfaces/ToLoopsOpInterface.cpp @@ -0,0 +1,13 @@ +//===- ToLoopsOpInterface.cpp ---------------------------------------------===// +// +// Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +// +//===----------------------------------------------------------------------===// +/// +/// This file contains the implementation of the `ToLoopsOpInterface` interface. +/// +//===----------------------------------------------------------------------===// +#include "mlir-tensorrt-common/Interfaces/ToLoopsOpInterface.h" + +// Include generated interface definitions. +#include "mlir-tensorrt-common/Interfaces/ToLoopsOpInterface.cpp.inc" diff --git a/mlir-tensorrt/common/lib/Utils/CMakeLists.txt b/mlir-tensorrt/common/lib/Utils/CMakeLists.txt new file mode 100644 index 000000000..1d03f05ec --- /dev/null +++ b/mlir-tensorrt/common/lib/Utils/CMakeLists.txt @@ -0,0 +1,3 @@ +if(MLIR_TRT_TARGET_TENSORRT) + add_subdirectory(TensorRTDynamicLoader) +endif() diff --git a/mlir-tensorrt/executor/lib/Utils/TensorRTDynamicLoader/CMakeLists.txt b/mlir-tensorrt/common/lib/Utils/TensorRTDynamicLoader/CMakeLists.txt similarity index 64% rename from mlir-tensorrt/executor/lib/Utils/TensorRTDynamicLoader/CMakeLists.txt rename to mlir-tensorrt/common/lib/Utils/TensorRTDynamicLoader/CMakeLists.txt index 1020ccfb7..d9da66f08 100644 --- a/mlir-tensorrt/executor/lib/Utils/TensorRTDynamicLoader/CMakeLists.txt +++ b/mlir-tensorrt/common/lib/Utils/TensorRTDynamicLoader/CMakeLists.txt @@ -1,4 +1,4 @@ -add_mlir_executor_library(MLIRTRTTensorRTDynamicLoader +add_mlir_tensorrt_library(MLIRTRTTensorRTDynamicLoader TensorRTDynamicLoader.cpp LINK_LIBS PRIVATE diff --git a/mlir-tensorrt/executor/lib/Utils/TensorRTDynamicLoader/TensorRTDynamicLoader.cpp b/mlir-tensorrt/common/lib/Utils/TensorRTDynamicLoader/TensorRTDynamicLoader.cpp similarity index 100% rename from mlir-tensorrt/executor/lib/Utils/TensorRTDynamicLoader/TensorRTDynamicLoader.cpp rename to mlir-tensorrt/common/lib/Utils/TensorRTDynamicLoader/TensorRTDynamicLoader.cpp diff --git a/mlir-tensorrt/compiler/CMakeLists.txt b/mlir-tensorrt/compiler/CMakeLists.txt index e039c9b8a..d15e36a12 100644 --- a/mlir-tensorrt/compiler/CMakeLists.txt +++ b/mlir-tensorrt/compiler/CMakeLists.txt @@ -31,6 +31,7 @@ add_mlir_tensorrt_compiler_dependency(MLIRNVVMTarget) add_mlir_tensorrt_compiler_dependency(MLIRPtrDialect) add_mlir_tensorrt_compiler_dependency(MLIRTargetLLVM) add_mlir_tensorrt_compiler_dependency(MLIRTensorTransformOps) +add_mlir_tensorrt_compiler_dependency(MLIREmitCExtDataLayoutImpl) add_subdirectory(include) add_subdirectory(lib) diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Backends/Host/HostBackend.td b/mlir-tensorrt/compiler/include/mlir-tensorrt/Backends/Host/HostBackend.td index 9b3d1d31c..589fc4179 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Backends/Host/HostBackend.td +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Backends/Host/HostBackend.td @@ -5,7 +5,7 @@ include "mlir-tensorrt/Dialect/Plan/IR/PlanDialect.td" include "mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.td" def Plan_HostClusterKindAttr : Plan_Attr<"HostClusterKind", "host_cluster", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let parameters = (ins "int64_t":$benefit); let assemblyFormat = "`<` struct(params) `>`"; } diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Extension.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Extension.h index 670cd0fba..ec2de8bd6 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Extension.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Extension.h @@ -73,11 +73,6 @@ class TaskExtensionBase { template using ListOption = mlir::detail::PassOptions::ListOption; -protected: - /// Whether this extension is disabled. Should default to false and be - /// associated with a flag `--disable-[name]-extension`. - bool disabled{false}; - private: mlir::TypeID typeID; diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/OptionsProviders.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/OptionsProviders.h index 72e828b90..2f6a9a8bb 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/OptionsProviders.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/OptionsProviders.h @@ -21,8 +21,8 @@ /// Data structures and functions for manipulating compiler options. /// //===----------------------------------------------------------------------===// -#ifndef MLIR_TENSORRT_COMPILER_OPTIONS -#define MLIR_TENSORRT_COMPILER_OPTIONS +#ifndef MLIR_TENSORRT_COMPILER_OPTIONSPROVIDERS +#define MLIR_TENSORRT_COMPILER_OPTIONSPROVIDERS #include "mlir-executor/Support/DeviceInfo.h" #include "mlir/Pass/PassManager.h" @@ -210,7 +210,7 @@ struct DeviceOptions : public OptionsProvider { private: /// Stores host device info. This is populated by the callback of - /// `shouldInfoFromHost`. If present, then it will also override the other + /// `shouldInferFromHost`. If present, then it will also override the other /// options in their callbacks. std::optional hostDeviceInfo{}; }; @@ -316,7 +316,6 @@ class CompilationTaskOptionsBase llvm::cl::desc("entrypoint function name")}; protected: - std::vector> extensions; std::unique_ptr debugOptions{nullptr}; }; @@ -363,4 +362,4 @@ class CompilationTaskOptions : public CompilationTaskOptionsBase { } // namespace mlirtrt::compiler -#endif // MLIR_TENSORRT_COMPILER_OPTIONS +#endif // MLIR_TENSORRT_COMPILER_OPTIONSPROVIDERS diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StablehloToExecutable/StablehloToExecutable.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StablehloToExecutable/StablehloToExecutable.h index b0c606f4e..60646b20e 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StablehloToExecutable/StablehloToExecutable.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StablehloToExecutable/StablehloToExecutable.h @@ -138,13 +138,6 @@ class StablehloToExecutableTask static void populatePassManager(mlir::OpPassManager &pm, const StablehloToExecutableOptions &options); - - /// Compile a StableHLO module into a MLIR-TensorRT Runtime executable. - /// This is the "functional" entrypoint that will allocate a new PassManager - /// for a single run. - static mlirtrt::StatusOr> - compileStableHLOToExecutable(CompilerClient &client, mlir::ModuleOp module, - const StablehloToExecutableOptions &options); }; /// Register the task/options with the client's registry. diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StablehloToExecutable/TensorRTExtension.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StablehloToExecutable/TensorRTExtension.h index c829f6c88..c2647ecb8 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StablehloToExecutable/TensorRTExtension.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StablehloToExecutable/TensorRTExtension.h @@ -60,8 +60,8 @@ class StablehloToExecutableTensorRTExtension this->workspaceMemoryPoolLimit = options.workspaceMemoryPoolLimit; } - Option disable{this->ctx, "disable-tensorrt-extension", - llvm::cl::init(false)}; + Option disabled{this->ctx, "disable-tensorrt-extension", + llvm::cl::init(false)}; Option format{ this->ctx, "tensorrt-target", diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/Passes.td b/mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/Passes.td index b7d3d014c..8d2963f50 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/Passes.td +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/Passes.td @@ -42,7 +42,9 @@ def ConvertStablehloToTensorRTPass : Pass<"convert-stablehlo-to-tensorrt"> { Option<"convertConditionals", "convert-conditionals", "bool", "true", "convert conditionals to TensorRT's conditional layer">, Option<"trtMajorVersion", "trt-major-version", "int64_t", "10", - "target TensorRT version for conversion"> + "target TensorRT version for conversion">, + Option<"preferEinsum", "prefer-einsum", "bool", "false", + "prefer converting to 'tensorrt.einsum' over 'tensorrt.matrix_multiply'"> ]; } #endif // MLIR_TENSORRT_ENABLE_HLO @@ -321,7 +323,10 @@ def ConvertStablehloToScfPass : Pass<"convert-stablehlo-to-scf"> { }]; let dependentDialects = [ "::mlir::tensor::TensorDialect", - "::mlir::scf::SCFDialect" + "::mlir::scf::SCFDialect", + "::mlir::tensor::TensorDialect", + "::mlir::arith::ArithDialect", + "::mlir::math::MathDialect" ]; } diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/StablehloToTensorRT/StablehloToTensorRT.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/StablehloToTensorRT/StablehloToTensorRT.h index a4e74b2a6..62793b6e9 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/StablehloToTensorRT/StablehloToTensorRT.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/StablehloToTensorRT/StablehloToTensorRT.h @@ -31,11 +31,20 @@ namespace mlir { class ConversionTarget; -// Collection of rewrite patterns for lowering of Stable HLO to TensorRT -// dialect. +/// Populate patterns for converting Stablehlo reduction and contraction ops to +/// TensorRT. +void populateStablehloReductionAndContractionToTensorRtConversionPattern( + TensorRTTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit = 1, PatternBenefit dotToEinsumBenefit = 0); + +/// Collection of rewrite patterns for lowering of Stable HLO to TensorRT +/// dialect. +/// The `preferEinsum` parameter controls whether `tensorrt.einsum` is used +/// as the primary method for converting `stablehlo.dot_general` or only for +/// fallback when conversion to `tensorrt.matrix_multiply` is not possible. void populateStablehloToTensorRtConversionPattern( TensorRTTypeConverter &typeConverter, RewritePatternSet &patterns, - ShapeInfoCallbacks shapeInfoCallbacks = {}); + ShapeInfoCallbacks shapeInfoCallbacks = {}, bool preferEinsum = false); /// Populate patterns for convert Chlo ops to TensorRT ops. void populateChloToTensorRtLegalityAndPatterns( diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/TensorRTCommon/ConvertToTensorRTCommon.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/TensorRTCommon/ConvertToTensorRTCommon.h index b39064885..05e344676 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/TensorRTCommon/ConvertToTensorRTCommon.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/TensorRTCommon/ConvertToTensorRTCommon.h @@ -208,26 +208,6 @@ class ConvertOpToTensorRTPattern : public ConvertToTensorRTPattern { : ConvertToTensorRTPattern(typeConverter, SourceOp::getOperationName(), benefit, context) {} - /// Wrappers around the ConversionPattern methods that pass the derived op - /// type. - LogicalResult match(Operation *op) const final { - return match(cast(op)); - } - void rewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - if constexpr (SourceOp::hasProperties()) - return rewrite(cast(op), - OpAdaptor(operands, op->getAttrDictionary(), - cast(op).getProperties()), - rewriter); - rewrite(cast(op), OpAdaptor(operands, op->getAttrDictionary()), - rewriter); - } - void rewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - auto sourceOp = cast(op); - rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter); - } LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { @@ -248,32 +228,10 @@ class ConvertOpToTensorRTPattern : public ConvertToTensorRTPattern { rewriter); } - /// Rewrite and Match methods that operate on the SourceOp type. These must be - /// overridden by the derived pattern class. - virtual LogicalResult match(SourceOp op) const { - (void)op; - llvm_unreachable("must override match or matchAndRewrite"); - } - virtual void rewrite(SourceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - (void)op; - (void)adaptor; - (void)rewriter; - llvm_unreachable("must override matchAndRewrite or a rewrite method"); - } - virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - SmallVector oneToOneOperands = - getOneToOneAdaptorOperands(adaptor.getOperands()); - rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); - } virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - if (failed(match(op))) - return failure(); - rewrite(op, adaptor, rewriter); - return success(); + llvm_unreachable("must override matchAndRewrite"); } virtual LogicalResult matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/Plan.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/Plan.h index ff11571a2..ea212808f 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/Plan.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/Plan.h @@ -26,6 +26,7 @@ #include "mlir-tensorrt-dialect/Interface/TensorKindOpInterface.h" #include "mlir-tensorrt/Compiler/Extension.h" +#include "mlir-tensorrt/Dialect/Plan/IR/PlanEnums.h" #include "mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.h" #include "mlir-tensorrt/Interfaces/InferTensorValueRangeInterface.h" #include "mlir/Bytecode/BytecodeOpInterface.h" @@ -134,11 +135,6 @@ class PlanDialectExtension }; } // namespace mlir::plan -//===----------------------------------------------------------------------===// -// Plan Enums -//===----------------------------------------------------------------------===// -#include "mlir-tensorrt/Dialect/Plan/IR/PlanEnums.h.inc" - //===----------------------------------------------------------------------===// // Plan Attributes //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanDialect.td b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanDialect.td index 7027bb49c..9902b377b 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanDialect.td +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanDialect.td @@ -91,6 +91,14 @@ def Plan_Dialect : Dialect { return "plan.shape_profile"; } + /// Return the name of the attribute used to encode memory space + /// constraints. It should appear in function attributes or in + /// function arg/result attribute dictionaries. + static StringRef getMemorySpaceConstraintAttrName() { + return "plan.memory_space"; + } + + private: ::llvm::StringMap attrParsingHooks; ::llvm::DenseMap<::mlir::TypeID, AttrPrintingHook> attrPrintingHooks; diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanEnums.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanEnums.h new file mode 100644 index 000000000..4a0519028 --- /dev/null +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanEnums.h @@ -0,0 +1,29 @@ +//===- PlanEnums.h ----------------------------------------------*- C++ -*-===// +// +// SPDX-FileCopyrightText: Copyright 2025 NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// Plan dialect enum declarations. +/// +//===----------------------------------------------------------------------===// +#ifndef MLIR_TENSORRT_DIALECT_PLAN_IR_PLANENUMS +#define MLIR_TENSORRT_DIALECT_PLAN_IR_PLANENUMS + +#include "mlir-tensorrt/Dialect/Plan/IR/PlanEnums.h.inc" + +#endif // MLIR_TENSORRT_DIALECT_PLAN_IR_PLANENUMS diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.h index 5f7afdc4a..ef383ffd2 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.h @@ -25,6 +25,7 @@ #define MLIR_TENSORRT_DIALECT_PLAN_IR_PLANINTERFACES #include "mlir-executor/Transforms/Clustering/Clustering.h" +#include "mlir-tensorrt/Dialect/Plan/IR/PlanEnums.h" #include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.td b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.td index cbfd98954..9d9487b86 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.td +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.td @@ -67,6 +67,16 @@ def ClusterKindAttrInterface : AttrInterface<"ClusterKindAttrInterface"> { /*body=*/"", /*defaultImplementation=*/"" >, + InterfaceMethod< + /*desc*/[{ + Returns the default memory space used for this cluster kind. + }], + /*retTy*/"::mlir::plan::MemorySpace", + "getDefaultMemorySpace", + (ins), + "", + "return ::mlir::plan::MemorySpace::device;" + >, InterfaceMethod< /*desc=*/[{ Return true if the cluster requires closure prior to diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/Transforms/Passes.td b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/Transforms/Passes.td index e7538edea..b97325194 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/Transforms/Passes.td +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/Transforms/Passes.td @@ -449,6 +449,83 @@ def PlanAssignMemorySpacesPass : Pass<"plan-assign-memory-spaces", ]; } +//===----------------------------------------------------------------------===// +// PlanOptimizeMemorySpacesPass +//===----------------------------------------------------------------------===// + +def PlanOptimizeMemorySpacesPass : Pass<"plan-optimize-memory-spaces", + "::mlir::func::FuncOp"> { + let summary = "optimizes memory spaces encodings to tensor types"; + + let description = [{ + This pass applies a set of transformations that attempt to optimize the + memory space encodings of tensor types in terms of host vs. device + placement. This includes changes such as (but not limited to): + + - Removing redundant memory space changes. + - Hoisting memory space changes out of loops. + - Ensuring operations that require certain operands to live in specific + memory spaces (host vs. device) have such constraints met. + + Note that this pass only deals with 'host' and 'device' memory spaces. The + current contract is that use of other specialized memory spaces (e.g. + `host_pinned`) is done via follow-on specialized optimization passes. + }]; + + let dependentDialects = [ + "::mlir::plan::PlanDialect", + "::mlir::bufferization::BufferizationDialect", + "::mlir::tensor::TensorDialect" + ]; +} + +//===----------------------------------------------------------------------===// +// PlanPromoteHostTensorsToHostPinnedPass +//===----------------------------------------------------------------------===// + +def PlanPromoteHostTensorsToHostPinnedPass + : Pass<"plan-promote-host-tensors-to-host-pinned", "::mlir::func::FuncOp"> { + let summary = "promotes host tensors to host pinned tensors"; + + let description = [{ + This pass finds host tensors which are ideal candidates for promotion to the + 'host-pinned' memory space. This pass must be run after the + `plan-optimize-memory-spaces` pass. + }]; + + let dependentDialects = [ + "::mlir::plan::PlanDialect", + "::mlir::bufferization::BufferizationDialect", + "::mlir::tensor::TensorDialect" + ]; +} + +//===----------------------------------------------------------------------===// +// PlanMaterializeExplicitTransfersPass +//===----------------------------------------------------------------------===// + +def PlanMaterializeExplicitTransfersPass + : Pass<"plan-materialize-explicit-transfers"> { + let summary = "Turn `tensor.cast` that cast between memory spaces into " + "explicit transfers using bufferization ops."; + + let description = [{ + This pass materializes explicit transfers between memory spaces by + lowering `tensor.cast` operations that change the memory space specified + by the tensor encoding attributes of the operand/result types. + + The transfers are materialized as explicit `bufferization.alloc_tensor` + and `bufferization.materialize_in_destination` operations to perform the + copy (the more concise `bufferization.alloc_tensor` with `copy` operand + currently cannot change between memory spaces). + }]; + + let dependentDialects = [ + "::mlir::bufferization::BufferizationDialect", + "::mlir::tensor::TensorDialect", + ]; +} + //===----------------------------------------------------------------------===// // PlanAllocTensorsPass //===----------------------------------------------------------------------===// @@ -554,6 +631,30 @@ def PlanRemoveEquivalentBufferResultsPass : Pass<"plan-remove-equivalent-buffer- }]; } +//===----------------------------------------------------------------------===// +// PlanBufferResultsToOutParamsPass +//===----------------------------------------------------------------------===// + +def PlanBufferResultsToOutParamsPass : Pass<"plan-buffer-results-to-out-params", + "::mlir::ModuleOp"> { + let summary = "Convert buffer results to out params"; + + let description = [{ + This pass converts function memref results to out params. There is a similar + upstream pass, but our version is more advanced and can handle promoting + a set of memref results. + }]; + + let dependentDialects = [ + "::mlir::memref::MemRefDialect" + ]; + + let options = [ + Option<"ignorePublicFunctions", "ignore-public-functions", "bool", + "false", "do not apply the transformation on public functions"> + ]; +} + //===----------------------------------------------------------------------===// // PlanOwnershipBasedBufferDeallocationPass //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/InitAllDialects.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/InitAllDialects.h index 17cc3125a..6621a452c 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/InitAllDialects.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/InitAllDialects.h @@ -25,6 +25,8 @@ #define MLIR_TENSORRT_INIT_ALL_DIALECTS #include "mlir-executor/Executor/IR/Executor.h" +#include "mlir-tensorrt-common/Dialect/EmitCExt/IR/DataLayoutImpl.h" +#include "mlir-tensorrt-common/Dialect/LinalgExt/Transforms/ToLoopsOpInterfaceImpl.h" #include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h" #include "mlir-tensorrt-dialect/TensorRT/Target/TensorRTEncodingImpl.h" #include "mlir-tensorrt/Backends/Host/HostBackend.h" @@ -158,11 +160,13 @@ inline void registerAllDialects(mlir::DialectRegistry ®istry) { mlir::cf::registerBufferDeallocationOpInterfaceExternalModels(registry); mlir::cf::registerBufferizableOpInterfaceExternalModels(registry); mlir::cuda::registerBufferizableOpInterfaceExternalModels(registry); + mlir::emitc_ext::registerDataLayoutInterfaceExternalModels(registry); mlir::linalg::registerBufferizableOpInterfaceExternalModels(registry); mlir::linalg::registerRuntimeVerifiableOpInterfaceExternalModels(registry); mlir::linalg::registerSubsetOpInterfaceExternalModels(registry); mlir::linalg::registerTilingInterfaceExternalModels(registry); mlir::linalg::registerValueBoundsOpInterfaceExternalModels(registry); + mlir::linalg_ext::registerToLoopsOpInterfaceExternalModels(registry); mlir::LLVM::registerInlinerInterface(registry); mlir::memref::registerAllocationOpInterfaceExternalModels(registry); mlir::memref::registerBufferViewFlowOpInterfaceExternalModels(registry); diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/InitAllPasses.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/InitAllPasses.h index 8f2d088d4..36ef996a3 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/InitAllPasses.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/InitAllPasses.h @@ -23,6 +23,7 @@ #define REGISTRATION_REGISTERMLIRTENSORRTPASSES_H #include "mlir-executor/InitAllPasses.h" +#include "mlir-tensorrt-common/Conversion/Passes.h" #include "mlir-tensorrt-dialect/TensorRT/Transforms/Passes.h" #include "mlir-tensorrt/Conversion/Passes.h" #include "mlir-tensorrt/Dialect/Plan/Transforms/Passes.h" @@ -52,8 +53,8 @@ inline void registerAllPasses() { mlir::emitc::registerEmitCPasses(); mlir::plan::registerPlanDialectPipelines(); mlir::plan::registerPlanPasses(); - mlir::registerConvertAffineToStandard(); - mlir::registerConvertPDLToPDLInterp(); + mlir::registerLowerAffinePass(); + mlir::registerConvertPDLToPDLInterpPass(); mlir::registerMLIRTensorRTConversionPasses(); mlir::registerMLIRTensorRTGenericTransformsPasses(); mlir::registerTransformsPasses(); @@ -61,6 +62,7 @@ inline void registerAllPasses() { mlir::registerConvertCUDAToExecutorPass(); mlir::bufferization::registerBufferizationPasses(); mlir::executor::registerAllPasses(); + mlir::registerMLIRTensorRTCommonConversionPasses(); IF_MLIR_TRT_ENABLE_HLO({ mlirtrt::compiler::registerStablehloToExecutablePasses(); diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Transforms/Transforms.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Transforms/Transforms.h index 72535eb1f..67b421d30 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Transforms/Transforms.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Transforms/Transforms.h @@ -25,14 +25,33 @@ #ifndef MLIR_TENSORRT_TRANSFORMS_TRANSFORMS_H #define MLIR_TENSORRT_TRANSFORMS_TRANSFORMS_H +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include namespace mlir { class ModuleOp; class RewriterBase; +class RewritePatternSet; /// Remove any operations nested below `op` that have the "IsolatedFromAbove" /// and "SymbolTable" attribute. void dropNestedModules(RewriterBase &rewriter, ModuleOp op); +using ShouldScalarizeWhileBeforeArgFunc = + std::function; +using ShouldScalarizeWhileAfterArgFunc = + std::function; + +/// Populates the patterns to detensorize scf.while ops. The provided functions +/// are used to control whether the arguments in each region are a candidate for +/// scalarization. They will currently only receive arguments that are tensor +/// types with a single element. +void populateSCFDetensorizeWhilePatterns( + RewritePatternSet &patterns, + ShouldScalarizeWhileBeforeArgFunc shouldScalarizeBeforeArg, + ShouldScalarizeWhileAfterArgFunc shouldScalarizeAfterArg, + PatternBenefit benefit = 1); + } // namespace mlir #endif // MLIR_TENSORRT_TRANSFORMS_TRANSFORMS_H diff --git a/mlir-tensorrt/compiler/lib/Backends/Host/HostBackend.cpp b/mlir-tensorrt/compiler/lib/Backends/Host/HostBackend.cpp index f804ebca4..c0263d669 100644 --- a/mlir-tensorrt/compiler/lib/Backends/Host/HostBackend.cpp +++ b/mlir-tensorrt/compiler/lib/Backends/Host/HostBackend.cpp @@ -204,6 +204,10 @@ bool HostClusterKindAttr::supportsInputKind(InputKind inputKind) const { return inputKind == InputKind::Stablehlo; } +MemorySpace HostClusterKindAttr::getDefaultMemorySpace() const { + return MemorySpace::host; +} + //===----------------------------------------------------------------------===// // Extension Registration //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/compiler/lib/CAPI/Compiler/Registration/RegisterAllDialects.cpp b/mlir-tensorrt/compiler/lib/CAPI/Compiler/Registration/RegisterAllDialects.cpp index 9608ca4e0..f6f263dff 100644 --- a/mlir-tensorrt/compiler/lib/CAPI/Compiler/Registration/RegisterAllDialects.cpp +++ b/mlir-tensorrt/compiler/lib/CAPI/Compiler/Registration/RegisterAllDialects.cpp @@ -25,9 +25,23 @@ #include "mlir-tensorrt-c/Compiler/Registration/RegisterAllDialects.h" #include "mlir-tensorrt/Compiler/StablehloToExecutable/StablehloToExecutable.h" #include "mlir-tensorrt/Compiler/TensorRTToExecutable/TensorRTToExecutable.h" +#include "mlir-tensorrt/Features.h" #include "mlir-tensorrt/InitAllDialects.h" #include "mlir-tensorrt/InitAllExtensions.h" #include "mlir-tensorrt/InitAllPasses.h" + +#ifdef MLIR_TRT_ENABLE_TORCH +#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" +#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h" +#include "torch-mlir/Conversion/Passes.h" +#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" +#include "torch-mlir/RefBackend/Passes.h" +#endif + #include "mlir/CAPI/IR.h" void mtrtCompilerRegisterDialects(MlirDialectRegistry registry) { @@ -35,7 +49,15 @@ void mtrtCompilerRegisterDialects(MlirDialectRegistry registry) { mlirtrt::compiler::registerAllExtensions(*unwrap(registry)); } -void mtrtCompilerRegisterPasses() { mlirtrt::compiler::registerAllPasses(); } +void mtrtCompilerRegisterPasses() { + mlirtrt::compiler::registerAllPasses(); + IF_MLIR_TRT_ENABLE_TORCH({ + mlir::torch::registerTorchPasses(); + mlir::torch::registerTorchConversionPasses(); + mlir::torch::registerConversionPasses(); + mlir::torch::TMTensor::registerPasses(); + }); +} void mtrtCompilerRegisterTasks() { mlirtrt::compiler::registerStableHloToExecutableTask(); diff --git a/mlir-tensorrt/compiler/lib/Compiler/OptionsProviders.cpp b/mlir-tensorrt/compiler/lib/Compiler/OptionsProviders.cpp index 497970259..8941eeb9e 100644 --- a/mlir-tensorrt/compiler/lib/Compiler/OptionsProviders.cpp +++ b/mlir-tensorrt/compiler/lib/Compiler/OptionsProviders.cpp @@ -154,10 +154,7 @@ std::optional CompilationTaskOptionsBase::getHash() const { llvm::raw_svector_ostream os(str); this->print(os); } - auto val = llvm::hash_value(str); - for (const auto &ext : extensions) - val = llvm::hash_combine(val, *ext->getHash()); - return val; + return llvm::hash_value(str); } mlir::LogicalResult diff --git a/mlir-tensorrt/compiler/lib/Compiler/StablehloToExecutable/Passes.cpp b/mlir-tensorrt/compiler/lib/Compiler/StablehloToExecutable/Passes.cpp index 5be59a7e7..02d774de4 100644 --- a/mlir-tensorrt/compiler/lib/Compiler/StablehloToExecutable/Passes.cpp +++ b/mlir-tensorrt/compiler/lib/Compiler/StablehloToExecutable/Passes.cpp @@ -28,8 +28,10 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/DialectResourceBlobManager.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/PassOptions.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "stablehlo/dialect/StablehloOps.h" #ifdef MLIR_TRT_ENABLE_HLO @@ -109,6 +111,27 @@ class ProcessHostClustersPass // ConvertStablehloConstantToArithPass //===----------------------------------------------------------------------===// +static FailureOr +handleStablehloConstantAttr(Location loc, ElementsAttr elAttr) { + Type elementType = elAttr.getElementType(); + if (auto integerType = dyn_cast(elementType)) { + if (integerType.isSignless()) + return elAttr; + Type signlessType = + IntegerType::get(elAttr.getContext(), integerType.getWidth()); + if (auto denseElementsAttr = dyn_cast(elAttr)) + return ElementsAttr(denseElementsAttr.bitcast(signlessType)); + if (auto denseResourceElementsAttr = + dyn_cast(elAttr)) { + auto handle = denseResourceElementsAttr.getRawHandle(); + return ElementsAttr(DenseResourceElementsAttr::get( + elAttr.getShapedType().clone(signlessType), handle)); + } + return emitError(loc, "unsupported constant attribute kind"); + } + return elAttr; +} + class ConvertStablehloConstantToArithPass : public compiler::impl::ConvertStablehloConstantsToArithPassBase< ConvertStablehloConstantToArithPass> { @@ -117,23 +140,27 @@ class ConvertStablehloConstantToArithPass void runOnOperation() override { func::FuncOp func = getOperation(); - - // Apply other preparation and simplification patterns. - RewritePatternSet patterns(func->getContext()); - // Convert `stablehlo.constant` to `arith.constant`. - patterns.add( - +[](stablehlo::ConstantOp constOp, - PatternRewriter &rewriter) -> LogicalResult { - rewriter.replaceOpWithNewOp(constOp, - constOp.getValue()); - return success(); + IRRewriter rewriter(func->getContext()); + auto walkResult = + func.walk([&](stablehlo::ConstantOp constOp) { + FailureOr elAttr = + handleStablehloConstantAttr(constOp.getLoc(), constOp.getValue()); + if (failed(elAttr)) + return WalkResult::interrupt(); + Type newType = elAttr->getType(); + rewriter.setInsertionPoint(constOp); + auto newConstOp = rewriter.create( + constOp.getLoc(), newType, *elAttr); + if (newType == constOp.getType()) { + rewriter.replaceOp(constOp, newConstOp); + } else { + rewriter.replaceOpWithNewOp( + constOp, constOp.getType(), newConstOp.getResult()); + } + return WalkResult::advance(); }); - - if (failed(applyPatternsGreedily(func, std::move(patterns)))) { - emitError(func.getLoc()) - << "failed to apply patterns in " << getArgument(); + if (walkResult.wasInterrupted()) return signalPassFailure(); - } } void getDependentDialects(DialectRegistry ®istry) const override { diff --git a/mlir-tensorrt/compiler/lib/Compiler/StablehloToExecutable/StableHloInputPipelines.cpp b/mlir-tensorrt/compiler/lib/Compiler/StablehloToExecutable/StableHloInputPipelines.cpp index d12308f10..e67b5c368 100644 --- a/mlir-tensorrt/compiler/lib/Compiler/StablehloToExecutable/StableHloInputPipelines.cpp +++ b/mlir-tensorrt/compiler/lib/Compiler/StablehloToExecutable/StableHloInputPipelines.cpp @@ -34,6 +34,9 @@ using namespace mlirtrt::compiler; static void buildStableHloSimplificationPipeline( OpPassManager &pm, const mlir::ConvertChloToStableHloExtPassOptions &chloToStablehloOptions) { + pm.addNestedPass( + stablehlo::createStablehloLegalizeCompositeToCallPass()); + pm.addPass(createInlinerPass()); // Some match-and-raise patterns should be performed before canonicalization, // since the pattern is based on specific frontend patterns (e.g. JAX). pm.addPass(stablehlo_ext::createExpandTuplesPass()); diff --git a/mlir-tensorrt/compiler/lib/Compiler/StablehloToExecutable/StablehloToExecutable.cpp b/mlir-tensorrt/compiler/lib/Compiler/StablehloToExecutable/StablehloToExecutable.cpp index aea81e136..cbf5c9abf 100644 --- a/mlir-tensorrt/compiler/lib/Compiler/StablehloToExecutable/StablehloToExecutable.cpp +++ b/mlir-tensorrt/compiler/lib/Compiler/StablehloToExecutable/StablehloToExecutable.cpp @@ -204,7 +204,7 @@ void StablehloToExecutableTask::populatePassManager( // For EmitC lowering, we rely on preserving control flow. Otherwise the C // code could be very unreadable. if (hostTarget != HostTarget::EmitC) - pm.addPass(createConvertSCFToCFPass()); + pm.addPass(mlir::createSCFToControlFlowPass()); pm.addPass(memref::createFoldMemRefAliasOpsPass()); pm.addPass(memref::createExpandOpsPass()); @@ -240,65 +240,6 @@ void StablehloToExecutableTask::populatePassManager( } } -mlirtrt::StatusOr> -StablehloToExecutableTask::compileStableHLOToExecutable( - CompilerClient &client, mlir::ModuleOp module, - const StablehloToExecutableOptions &options) { - if (client.getContext() != module->getContext()) - return getInternalErrorStatus("CompilerClient has a MLIRContext that is " - "different from the ModuleOp's MLIRContext"); - - LLVM_DEBUG({ - DBGS() << "compiling with options:\n"; - options.print(llvm::dbgs()); - llvm::dbgs() << "\n"; - }); - -#ifndef NDEBUG - if (options.get().enableLLVMDebugFlag) { - SmallVector debugTypeLiterals = - llvm::map_to_vector(options.get().llvmDebugTypes, - [](const std::string &x) { return x.c_str(); }); - llvm::setCurrentDebugTypes(debugTypeLiterals.data(), - debugTypeLiterals.size()); - llvm::DebugFlag = true; - } -#endif - - std::string result; - llvm::raw_string_ostream ss(result); - options.print(ss); - ss.flush(); - StatusOr runner = - client.getCompilationTask( - llvm::StringRef(result).drop_front(1).drop_back(1), - /*enableDebugOptions=*/false); - if (!runner.isOk()) - return runner.getStatus(); - - // Setup pass manager - if (failed((*runner)->run(module))) - return getInternalErrorStatus( - "failed to run compilation on module with symbol name: {0}", - module.getName() ? *module.getName() : "no-symbol-name"); - - // Translate to Runtime Executable - FailureOr> exeStorage = - mlir::translateToRuntimeExecutable(module); - if (failed(exeStorage)) - return getStatusWithMsg(StatusCode::InternalError, - "failed to translate compiled MLIR module to a " - "MLIR-TensorRT runtime Executable"); - -#ifndef NDEBUG - // Turn debugging back off if we turned it on. - if (options.get().enableLLVMDebugFlag) - llvm::DebugFlag = false; -#endif - - return std::make_unique(std::move(*exeStorage)); -} - void mlirtrt::compiler::registerStableHloToExecutableTask() { registerCompilationTask( "stablehlo-to-executable", diff --git a/mlir-tensorrt/compiler/lib/Conversion/HostToEmitC/CMakeLists.txt b/mlir-tensorrt/compiler/lib/Conversion/HostToEmitC/CMakeLists.txt index 0f88e4210..b264c5a94 100644 --- a/mlir-tensorrt/compiler/lib/Conversion/HostToEmitC/CMakeLists.txt +++ b/mlir-tensorrt/compiler/lib/Conversion/HostToEmitC/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_tensorrt_library(MLIRTensorRTHostToEmitC MLIRTensorRTTensorRTToEmitC MLIRTensorRTExecutorDialect MLIRFuncToEmitC + MLIRFuncTransforms MLIRTensorRTCUDADialect MLIRTensorRTTensorRTRuntimeDialect MLIRTensorRTLLVMConversionUtils diff --git a/mlir-tensorrt/compiler/lib/Conversion/HostToEmitC/HostToEmitC.cpp b/mlir-tensorrt/compiler/lib/Conversion/HostToEmitC/HostToEmitC.cpp index 2681522a5..2aa1fd55a 100644 --- a/mlir-tensorrt/compiler/lib/Conversion/HostToEmitC/HostToEmitC.cpp +++ b/mlir-tensorrt/compiler/lib/Conversion/HostToEmitC/HostToEmitC.cpp @@ -27,11 +27,13 @@ #include "mlir-tensorrt/Dialect/CUDA/IR/CUDADialect.h" #include "mlir-tensorrt/Dialect/Plan/IR/Plan.h" #include "mlir-tensorrt/Dialect/TensorRTRuntime/IR/TensorRTRuntime.h" +#include "mlir/Analysis/DataLayoutAnalysis.h" #include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" #include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h" #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" @@ -935,15 +937,20 @@ namespace { /// organized by dialect below. template struct EmitCConversionPattern : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + + EmitCConversionPattern(const TypeConverter &typeConverter, + const DataLayout &dataLayout, MLIRContext *ctx, + PatternBenefit benefit = PatternBenefit(10)) + : OpConversionPattern(typeConverter, ctx, benefit), + dataLayout(dataLayout) {} MLIRContext *ctx{this->getContext()}; EmitCCallBuilders builders{ctx}; Type voidPtrType{builders.voidPtrType}; - - Type i8Type{IntegerType::get(ctx, 8)}; - Type i32Type{IntegerType::get(ctx, 32)}; - Type i64Type{IntegerType::get(ctx, 64)}; + IntegerType i8Type{IntegerType::get(ctx, 8)}; + IntegerType i32Type{IntegerType::get(ctx, 32)}; + IntegerType i64Type{IntegerType::get(ctx, 64)}; + const mlir::DataLayout &dataLayout; emitc::PointerType getPointerType(Type elementType) const { return emitc::PointerType::get(elementType); @@ -999,8 +1006,6 @@ struct EmitCConversionPattern : OpConversionPattern { "mtrt::make_unranked_descriptor", {rankVal, rankedDesc}) .getResult(0); } - - mlir::DataLayout dataLayout; }; //===----------------------------------------------------------------------===// @@ -1070,8 +1075,6 @@ struct TRTEnqueueConverter : EmitCConversionPattern { rewriter.eraseOp(op); return success(); } - - DataLayout dataLayout; }; //===----------------------------------------------------------------------===// @@ -1282,8 +1285,6 @@ struct CUDAAllocConverter : public EmitCConversionPattern { return success(); } - - mlir::DataLayout dataLayout; }; struct CudaDeallocConverter : public EmitCConversionPattern { @@ -1334,13 +1335,14 @@ struct CudaCopyConverter : public EmitCConversionPattern { EmitCMemRefDescriptor src(adaptor.getSource()); EmitCMemRefDescriptor dest(adaptor.getTarget()); Location loc = op.getLoc(); - Value srcStart = src.getMemRefBufferStart(rewriter, loc, dataLayout, + Value srcStart = src.getMemRefBufferStart(rewriter, loc, this->dataLayout, srcType.getElementType()); - Value destStart = dest.getMemRefBufferStart(rewriter, loc, dataLayout, + Value destStart = dest.getMemRefBufferStart(rewriter, loc, this->dataLayout, dstType.getElementType()); if (!isCopyStrided(srcType, dstType)) { - Value totalSize = src.getSizeInBytes(rewriter, loc, dataLayout, srcType); + Value totalSize = + src.getSizeInBytes(rewriter, loc, this->dataLayout, srcType); this->callOpaque(rewriter, loc, Type{}, "mtrt::cuda_copy", {adaptor.getStream(), srcStart, destStart, totalSize}); rewriter.eraseOp(op); @@ -1367,8 +1369,6 @@ struct CudaCopyConverter : public EmitCConversionPattern { rewriter.eraseOp(op); return success(); } - - DataLayout dataLayout; }; //===----------------------------------------------------------------------===// @@ -1557,6 +1557,22 @@ struct MemRefAllocOpLowering : public EmitCConversionPattern { } }; +struct MemrefCastOpLowering : public EmitCConversionPattern { + using EmitCConversionPattern::EmitCConversionPattern; + LogicalResult + matchAndRewrite(memref::CastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MemRefType sourceType = dyn_cast(op.getSource().getType()); + MemRefType targetType = dyn_cast(op.getResult().getType()); + if (!sourceType || !targetType) + return failure(); + if (sourceType.getRank() != targetType.getRank()) + return failure(); + rewriter.replaceOp(op, adaptor.getSource()); + return success(); + } +}; + /// Convert `memref.dealloc` to MTRT C++ API calls. struct MemRefDeallocLowering : public EmitCConversionPattern { @@ -1620,7 +1636,6 @@ struct MemRefLoadOpLowering : public EmitCConversionPattern { rewriter.replaceOpWithNewOp(op, elementType, lval); return success(); } - DataLayout dataLayout; }; // Convert memref.load to C++. @@ -1648,7 +1663,6 @@ struct MemRefStoreOpLowering : public EmitCConversionPattern { rewriter.eraseOp(op); return success(); } - DataLayout dataLayout; }; /// Unpack the pointer returned by a memref.extract_aligned_pointer_as_index. @@ -1683,8 +1697,9 @@ class MemRefExtractAlignedPointerAsIndexConverter //===----------------------------------------------------------------------===// /// Populate EmitC type conversions and op conversion patterns. -static void populateEmitCConversionPatterns(TypeConverter &typeConverter, - RewritePatternSet &patterns) { +static void populateEmitCConversionPatternsAndLegality( + const DataLayout &dataLayout, TypeConverter &typeConverter, + ConversionTarget &target, RewritePatternSet &patterns) { MLIRContext *ctx = patterns.getContext(); Type cuEngineType = emitc::OpaqueType::get(ctx, "nvinfer1::ICudaEngine"); Type cuEnginePtrType = emitc::PointerType::get(cuEngineType); @@ -1725,6 +1740,19 @@ static void populateEmitCConversionPatterns(TypeConverter &typeConverter, return emitc::OpaqueType::get(t.getContext(), name); }); + // Setup legality constraints. + target.addLegalOp(); + target.addLegalDialect(); + target.addIllegalDialect(); + target.addDynamicallyLegalOp([&typeConverter](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()); + }); + target.addDynamicallyLegalOp([&typeConverter](func::CallOp op) { + return typeConverter.isLegal(op->getResultTypes()) && + typeConverter.isLegal(op->getOperandTypes()); + }); + // clang-format off patterns.add< CUDAAllocConverter, @@ -1738,6 +1766,7 @@ static void populateEmitCConversionPatterns(TypeConverter &typeConverter, ExecutorPrintConverter, ExtractStridedMetadataOpLowering, MemRefAllocOpLowering, + MemrefCastOpLowering, MemRefDimOpLowering, MemRefDeallocLowering, MemRefExtractAlignedPointerAsIndexConverter, @@ -1745,8 +1774,13 @@ static void populateEmitCConversionPatterns(TypeConverter &typeConverter, MemRefReinterpretCastOpLowering, MemRefStoreOpLowering, TRTEnqueueConverter - >(typeConverter, patterns.getContext()); + >(typeConverter, dataLayout, patterns.getContext()); // clang-format on + mlir::populateSCFToEmitCConversionPatterns(patterns, typeConverter); + mlir::populateArithToEmitCPatterns(typeConverter, patterns); + mlir::populateFuncTypeConversionPatterns(typeConverter, patterns); + mlir::populateFunctionOpInterfaceTypeConversionPattern( + patterns, typeConverter); } namespace { @@ -1755,29 +1789,6 @@ class HostToEmitCPass public: using Base::Base; - // Create the rewrite pattern set using all loaded dialects. - LogicalResult initialize(MLIRContext *context) final { - auto target = std::make_shared(*context); - auto typeConverter = std::make_shared(); - target->addLegalOp(); - target->addLegalDialect(); - - target->addIllegalDialect(); - - RewritePatternSet patterns_(context); - populateEmitCConversionPatterns(*typeConverter, patterns_); - mlir::populateSCFToEmitCConversionPatterns(patterns_, *typeConverter); - mlir::populateFuncToEmitCPatterns(*typeConverter, patterns_); - mlir::populateArithToEmitCPatterns(*typeConverter, patterns_); - - this->patterns = - std::make_shared(std::move(patterns_)); - this->target = target; - this->typeConverter = typeConverter; - return success(); - } - void runOnOperation() override { ModuleOp moduleOp = getOperation(); @@ -1802,7 +1813,17 @@ class HostToEmitCPass if (failed(converter.convert())) return signalPassFailure(); - if (failed(applyPartialConversion(moduleOp, *target, *patterns))) { + const DataLayoutAnalysis &dataLayoutAnalysis = + getAnalysis(); + const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(moduleOp); + + TypeConverter typeConverter; + RewritePatternSet patterns(&getContext()); + ConversionTarget target(getContext()); + populateEmitCConversionPatternsAndLegality(dataLayout, typeConverter, + target, patterns); + + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { emitError(getOperation()->getLoc()) << "failed to apply conversion in " << getArgument(); return signalPassFailure(); @@ -1846,8 +1867,8 @@ class HostToEmitCPass //===----------------------------------------------------------------------===// // cleanup //===----------------------------------------------------------------------===// - RewritePatternSet patterns(moduleOp->getContext()); - patterns.add( + RewritePatternSet cleanupPatterns(moduleOp->getContext()); + cleanupPatterns.add( +[](emitc::CastOp op, PatternRewriter &rewriter) -> LogicalResult { // Eliminate useless casts. if (op.getType() == op.getOperand().getType()) { @@ -1884,11 +1905,7 @@ class HostToEmitCPass return failure(); }); - (void)mlir::applyPatternsGreedily(moduleOp, std::move(patterns)); + (void)mlir::applyPatternsGreedily(moduleOp, std::move(cleanupPatterns)); } - - std::shared_ptr patterns; - std::shared_ptr target; - std::shared_ptr typeConverter; }; } // namespace diff --git a/mlir-tensorrt/compiler/lib/Conversion/StablehloToScf/CMakeLists.txt b/mlir-tensorrt/compiler/lib/Conversion/StablehloToScf/CMakeLists.txt index b77e9d3e7..836de4556 100644 --- a/mlir-tensorrt/compiler/lib/Conversion/StablehloToScf/CMakeLists.txt +++ b/mlir-tensorrt/compiler/lib/Conversion/StablehloToScf/CMakeLists.txt @@ -5,12 +5,18 @@ add_mlir_tensorrt_library(MLIRTensorRTStablehloToSCF MLIRTensorRTConversionPassIncGen LINK_LIBS PUBLIC + MLIRArithDialect MLIRDialectUtils MLIRIR + MLIRMathDialect MLIRPass MLIRRewrite MLIRSCFDialect + MLIRSCFTransforms MLIRTensorDialect + MLIRTensorDialect + MLIRTensorRTSCFDetensorizeLoops MLIRTransformUtils + StablehloLinalgTransforms StablehloOps ) \ No newline at end of file diff --git a/mlir-tensorrt/compiler/lib/Conversion/StablehloToScf/StablehloToScf.cpp b/mlir-tensorrt/compiler/lib/Conversion/StablehloToScf/StablehloToScf.cpp index bc4edf08d..c24c7ea69 100644 --- a/mlir-tensorrt/compiler/lib/Conversion/StablehloToScf/StablehloToScf.cpp +++ b/mlir-tensorrt/compiler/lib/Conversion/StablehloToScf/StablehloToScf.cpp @@ -15,13 +15,22 @@ /// //===----------------------------------------------------------------------===// +#include "mlir-tensorrt/Transforms/Transforms.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Matchers.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/OneToNTypeConversion.h" +#include "stablehlo/conversions/linalg/transforms/MapStablehloToScalarOp.h" #include "stablehlo/dialect/StablehloOps.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" namespace mlir { #define GEN_PASS_DEF_CONVERTSTABLEHLOTOSCFPASS @@ -50,12 +59,10 @@ static void inlineStablehloRegionIntoSCFRegion(PatternRewriter &rewriter, /// Extracts a scalar from tensor with a single element. static Value extractScalarFromTensorValue(OpBuilder &b, Value tensor) { Location loc = tensor.getLoc(); - // If ranked tensor, first collapse shape. - if (cast(tensor.getType()).getRank() != 0) - tensor = b.create( - loc, tensor, SmallVector()); - - return b.create(loc, tensor, ValueRange()); + RankedTensorType rtt = cast(tensor.getType()); + SmallVector zeros(rtt.getRank(), + b.create(loc, 0)); + return b.create(loc, tensor, zeros); } namespace { @@ -185,7 +192,324 @@ struct ConvertCaseOp : public OpConversionPattern { return success(); } }; +} // namespace + +//===----------------------------------------------------------------------===// +// Code after this point is not part of the original MHLO pass. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// These patterns are meant to perform canonicalization and uplift of +// scf.while to scf.for after the conversion from stablehlo to scf. +//===----------------------------------------------------------------------===// + +/// Scalarize a `stablehlo.compare` op. +static Value scalarizeStablehloCompareOp(stablehlo::CompareOp op, + PatternRewriter &rewriter) { + auto scalarOperands = llvm::map_to_vector(op.getOperands(), [&](Value v) { + return extractScalarFromTensorValue(rewriter, v); + }); + return stablehlo::StablehloOpToStdScalarOp::mapOp( + op, op.getType().getElementType(), scalarOperands, &rewriter); +} + +namespace { + +/// Scalarize a `stablehlo.compare` op used by a `tensor.extract` op. +struct ScalarizeStablehloCompareUsedByExtractPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensor::ExtractOp op, + PatternRewriter &rewriter) const override { + auto compareOp = op.getTensor().getDefiningOp(); + if (!compareOp || !compareOp.getType().hasStaticShape() || + compareOp.getType().getNumElements() != 1 || + !compareOp.getType().getElementType().isSignlessIntOrIndex()) + return failure(); + rewriter.setInsertionPoint(compareOp); + Value scalarCompare = scalarizeStablehloCompareOp(compareOp, rewriter); + rewriter.replaceOp(op, scalarCompare); + return success(); + } +}; +} // namespace + +static bool isScalarizable(Type type) { + if (auto rtt = dyn_cast(type)) + return rtt.hasStaticShape() && rtt.getNumElements() == 1; + return false; +} + +static FailureOr convertToScalar(Operation *op, + PatternRewriter &rewriter) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + RankedTensorType rtt = cast(op->getResult(0).getType()); + SmallVector scalarOperands; + for (Value operand : op->getOperands()) + scalarOperands.push_back(extractScalarFromTensorValue(rewriter, operand)); + return llvm::TypeSwitch>(op) + .Case< + // clang-format off + stablehlo::AbsOp, + stablehlo::AddOp, + stablehlo::AndOp, + stablehlo::Atan2Op, + stablehlo::BitcastConvertOp, + stablehlo::CbrtOp, + stablehlo::CeilOp, + stablehlo::ClampOp, + stablehlo::ClzOp, + stablehlo::CompareOp, + stablehlo::ComplexOp, + stablehlo::ConvertOp, + stablehlo::CosineOp, + stablehlo::DivOp, + stablehlo::ExpOp, + stablehlo::Expm1Op, + stablehlo::FloorOp, + stablehlo::ImagOp, + stablehlo::IsFiniteOp, + stablehlo::Log1pOp, + stablehlo::LogOp, + stablehlo::LogisticOp, + stablehlo::MaxOp, + stablehlo::MinOp, + stablehlo::MulOp, + stablehlo::NegOp, + stablehlo::NotOp, + stablehlo::OrOp, + stablehlo::PopulationCountOp, + stablehlo::PowOp, + stablehlo::RealOp, + stablehlo::ReducePrecisionOp, + stablehlo::RemOp, + stablehlo::RoundNearestEvenOp, + stablehlo::RoundOp, + stablehlo::RsqrtOp, + stablehlo::SelectOp, + stablehlo::ShiftLeftOp, + stablehlo::ShiftRightArithmeticOp, + stablehlo::ShiftRightLogicalOp, + stablehlo::SignOp, + stablehlo::SineOp, + stablehlo::SqrtOp, + stablehlo::SubtractOp, + stablehlo::TanhOp, + stablehlo::XorOp + // clang-format on + >([&](auto op) -> FailureOr { + return stablehlo::StablehloOpToStdScalarOp::mapOp( + op, rtt.getElementType(), scalarOperands, &rewriter); + }) + .Default([](auto op) -> FailureOr { return failure(); }); +} + +namespace { +/// Scalarize operations which feed into the condition argument of +/// `scf.condition`. +struct ScalarizeWhileConditionProducers + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(scf::ConditionOp op, + PatternRewriter &rewriter) const override { + auto scfWhile = op->getParentOfType(); + if (!scfWhile || scfWhile.getBefore() != op->getParentRegion()) + return rewriter.notifyMatchFailure( + op, "op is not in the before region of a scf.while op"); + + Region &beforeRegion = scfWhile.getBefore(); + BackwardSliceOptions options{}; + options.inclusive = false; + options.omitUsesFromAbove = true; + options.omitBlockArguments = true; + options.filter = [&](Operation *op) { + return beforeRegion.isAncestor(op->getParentRegion()) && + (llvm::isa_and_present( + op->getDialect()) || + llvm::isa(op)); + }; + + SetVector producers; + getBackwardSlice(op.getCondition(), &producers, options); + + bool changed = false; + for (Operation *producer : producers) { + if (!isa_and_present( + producer->getDialect()) || + !producer->hasTrait() || + producer->getNumResults() != 1) + continue; + if (!isScalarizable(producer->getResult(0).getType())) + continue; + if (!llvm::all_of(producer->getOperandTypes(), isScalarizable)) + continue; + FailureOr scalarized = convertToScalar(producer, rewriter); + if (failed(scalarized)) + continue; + rewriter.setInsertionPointAfterValue(*scalarized); + rewriter.replaceOpWithNewOp( + producer, producer->getResult(0).getType(), scalarized.value()); + changed = true; + } + return success(changed); + } +}; +} // namespace + +/// Check if the add op is a valid induction variable increment. +static bool matchInductionVariableIncrement(stablehlo::AddOp op, + scf::WhileOp parentWhile) { + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + if (matchPattern(lhs, m_Constant()) || matchPattern(rhs, m_Constant())) + return true; + Region *whileRegion = parentWhile->getParentRegion(); + return lhs.getParentRegion()->isAncestor(whileRegion) || + rhs.getParentRegion()->isAncestor(whileRegion); +} + +namespace { +/// Scalarize any `stablehlo.add` operations in the 'after' region of +/// a scf.while op. +struct ScalarizeStablehloAddOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(stablehlo::AddOp op, + PatternRewriter &rewriter) const override { + if (!op->hasOneUse()) + return rewriter.notifyMatchFailure( + op, "op has more than one use, cannot scalarize"); + auto extractUser = dyn_cast(*op->user_begin()); + if (!extractUser || !extractUser->hasOneUse() || + !isa(*extractUser->user_begin())) + return rewriter.notifyMatchFailure( + op, "op result is not extracted and yielded from region"); + + auto scfWhile = extractUser->getParentOfType(); + if (!scfWhile || scfWhile.getAfter() != op->getParentRegion()) + return rewriter.notifyMatchFailure( + op, "op is not in the after region of a scf.while op"); + + // One operand must be a constant or defined above in order to be + // considered as the loop step. + if (!matchInductionVariableIncrement(op, scfWhile)) + return rewriter.notifyMatchFailure( + op, "op is not a valid induction variable increment"); + + // Find a block argument that has been scalarized. + auto findBlockArgument = [](Value v) -> BlockArgument { + Value source{}; + if (matchPattern(v, + m_Op(matchers::m_Any(&source)))) + return dyn_cast(source); + return {}; + }; + BlockArgument arg = findBlockArgument(op.getLhs()); + if (!arg) + arg = findBlockArgument(op.getRhs()); + if (!arg || arg.getParentRegion() != scfWhile.getAfter()) + return rewriter.notifyMatchFailure( + op, "could not find block argument in after region"); + + // Check that the corresponding block argument in the `before` region feeds + // into a comparison. + Region &before = scfWhile.getBefore(); + if (arg.getArgNumber() >= before.getNumArguments() || + before.getArgument(arg.getArgNumber()).getType() != arg.getType()) + return rewriter.notifyMatchFailure( + op, "could not find block argument in before region"); + auto beforeArg = before.getArgument(arg.getArgNumber()); + if (!llvm::all_of(beforeArg.getUsers(), + llvm::IsaPred)) + return rewriter.notifyMatchFailure( + op, "block argument is not consumed by a comparison op"); + + // Check that the before region has a block argument in the same position + // and is consumed by a comparison op. + RankedTensorType rtt = op.getType(); + Type elementType = rtt.getElementType(); + if (!rtt.hasStaticShape() || rtt.getNumElements() != 1 || + !elementType.isSignlessIntOrIndex()) + return rewriter.notifyMatchFailure(op, "op is not a scalar add op"); + + auto scalarOperands = llvm::map_to_vector(op.getOperands(), [&](Value v) { + return extractScalarFromTensorValue(rewriter, v); + }); + + auto scalarAdd = + stablehlo::StablehloOpToStdScalarOp::mapOp( + op, elementType, scalarOperands, &rewriter); + auto fromElements = + rewriter.create(op.getLoc(), rtt, scalarAdd); + rewriter.replaceOp(op, fromElements); + return success(); + } +}; +} // namespace + +/// This is used by the SCF while detensorization patterns to determine whether +/// a block argument of the 'before' region should be scalarized. We want to +/// scalarize the block argument corresponding to the induction variable of the +/// for loop. It will have a user like `stablehlo.compare` or `tensor.extract`. +static bool shouldScalarizeWhileBeforeArg(BlockArgument arg, Value initOperand, + Value yieldOperand) { + return cast(arg.getType()) + .getElementType() + .isSignlessIntOrIndex() && + llvm::count_if(arg.getUsers(), + llvm::IsaPred) >= 1; +} + +/// This is used by the SCF while detensorization patterns to determine whether +/// a block argument of the 'after' region should be scalarized. We want to +/// scalarize the block argument corresponding to the induction variable of the +/// for loop. It will have a user like `stablehlo.add` or `tensor.extract`. +static bool shouldScalarizeWhileAfterArg(BlockArgument arg, Value condOperand, + Value result) { + RankedTensorType rtt = cast(arg.getType()); + auto whileOp = arg.getParentRegion()->getParentOfType(); + Region &before = whileOp.getBefore(); + if (before.getNumArguments() <= arg.getArgNumber() || + before.getArgument(arg.getArgNumber()).getType() != + rtt.getElementType() || + !llvm::all_of(before.getArgument(arg.getArgNumber()).getUsers(), + llvm::IsaPred)) + return false; + + auto condProducer = condOperand.getDefiningOp(); + if (!condProducer || condProducer.getElements().size() != 1 || + !isa(condProducer.getElements().front())) + return false; + return rtt.getElementType().isSignlessIntOrIndex() && + llvm::count_if(arg.getUsers(), + llvm::IsaPred) >= 1; +} + +/// Populates the patterns to uplift scf.while to scf.for. This requires +/// detensorization as well as the upstream uplift patterns. +static LogicalResult applyWhileToForUpliftPatterns(Operation *op) { + RewritePatternSet patterns(op->getContext()); + scf::populateUpliftWhileToForPatterns(patterns); + scf::WhileOp::getCanonicalizationPatterns(patterns, op->getContext()); + scf::IfOp::getCanonicalizationPatterns(patterns, op->getContext()); + scf::populateSCFForLoopCanonicalizationPatterns(patterns); + tensor::FromElementsOp::getCanonicalizationPatterns(patterns, + patterns.getContext()); + tensor::ExtractOp::getCanonicalizationPatterns(patterns, + patterns.getContext()); + populateSCFDetensorizeWhilePatterns(patterns, shouldScalarizeWhileBeforeArg, + shouldScalarizeWhileAfterArg, + /*benefit=*/10); + patterns.add(op->getContext()); + return applyPatternsGreedily(op, std::move(patterns)); +} + +namespace { struct StablehloToScfPass : public impl::ConvertStablehloToScfPassBase { public: @@ -204,7 +528,13 @@ struct StablehloToScfPass std::move(patterns)))) { emitError(getOperation()->getLoc()) << "failed to apply patterns in " << getArgument(); - signalPassFailure(); + return signalPassFailure(); + } + if (failed(applyWhileToForUpliftPatterns(getOperation()))) { + emitError(getOperation()->getLoc()) + << "failed to apply while-to-for uplift patterns in " + << getArgument(); + return signalPassFailure(); } } }; diff --git a/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/CMakeLists.txt b/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/CMakeLists.txt index c41f91c0e..a8468b020 100644 --- a/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/CMakeLists.txt +++ b/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/CMakeLists.txt @@ -1,7 +1,8 @@ add_mlir_tensorrt_library(MLIRTensorRTStablehloToTensorRT - StablehloToTensorRT.cpp ControlFlowOps.cpp ChloToTensorRT.cpp + ReductionConversions.cpp + StablehloToTensorRT.cpp DEPENDS MLIRTensorRTConversionPassIncGen diff --git a/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/Matchers.h b/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/Matchers.h index 9d23d82fd..5ee4979e4 100644 --- a/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/Matchers.h +++ b/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/Matchers.h @@ -257,26 +257,6 @@ class ConvertHloOpToTensorRTPattern : public ConvertToTensorRTPattern { : ConvertToTensorRTPattern(typeConverter, SourceOp::getOperationName(), benefit, context) {} - /// Wrappers around the ConversionPattern methods that pass the derived op - /// type. - LogicalResult match(Operation *op) const final { - return match(cast(op)); - } - void rewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - if constexpr (SourceOp::hasProperties()) - return rewrite(cast(op), - OpAdaptor(operands, op->getAttrDictionary(), - cast(op).getProperties()), - rewriter); - rewrite(cast(op), OpAdaptor(operands, op->getAttrDictionary()), - rewriter); - } - void rewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - auto sourceOp = cast(op); - rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter); - } LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { @@ -313,32 +293,10 @@ class ConvertHloOpToTensorRTPattern : public ConvertToTensorRTPattern { rewriter); } - /// Rewrite and Match methods that operate on the SourceOp type. These must be - /// overridden by the derived pattern class. - virtual LogicalResult match(SourceOp op) const { - (void)op; - llvm_unreachable("must override match or matchAndRewrite"); - } - virtual void rewrite(SourceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - (void)op; - (void)adaptor; - (void)rewriter; - llvm_unreachable("must override matchAndRewrite or a rewrite method"); - } - virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - SmallVector oneToOneOperands = - getOneToOneAdaptorOperands(adaptor.getOperands()); - rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); - } virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - if (failed(match(op))) - return failure(); - rewrite(op, adaptor, rewriter); - return success(); + llvm_unreachable("must override matchAndRewrite"); } virtual LogicalResult matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, diff --git a/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/ReductionConversions.cpp b/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/ReductionConversions.cpp new file mode 100644 index 000000000..c96446910 --- /dev/null +++ b/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/ReductionConversions.cpp @@ -0,0 +1,529 @@ +//===- ReductionConversions.cpp -------------------------------------------===// +// +// SPDX-FileCopyrightText: Copyright 2025 NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// Implementation of pass to convert StableHLO reduction and contraction ops to +/// TensorRT dialect ops. +/// +//===----------------------------------------------------------------------===// +#include "Matchers.h" +#include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h" +#include "mlir-tensorrt-dialect/Utils/ShapeUtils.h" +#include "mlir-tensorrt/Conversion/Patterns.h" +#include "mlir-tensorrt/Conversion/StablehloToTensorRT/StablehloToTensorRT.h" +#include "mlir-tensorrt/Conversion/TensorRTCommon/ConvertToTensorRTCommon.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Transforms/DialectConversion.h" +#include "stablehlo/dialect/StablehloOps.h" + +#define DEBUG_TYPE "stablehlo-to-tensorrt" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") + +using namespace mlir; +using mlir::tensorrt::TensorValue; + +/// Drop the unit dimension at `dimToDrop` from each of `values`. +static SmallVector +createRankReducedResults(TensorRTConversionPatternRewriter &rewriter, + Location loc, ResultRange values, int64_t dimToDrop, + int64_t trtMajorVersion) { + assert(!values.empty()); + SmallVector result; + result.reserve(values.size()); + for (Value v : values) { + auto inputType = dyn_cast(v.getType()); + assert((!inputType || inputType.getDimSize(dimToDrop) == 1) && + "expected value to have unit dim to drop"); + auto rtt = RankedTensorType::Builder(inputType); + rtt.dropDim(dimToDrop); + Value collapsed = + rewriter.checkAndCreate(loc, Type(rtt), v); + result.push_back(collapsed); + } + return result; +} + +template +static LogicalResult matchAndReplaceStablehloArgMinMax( + stablehlo::ReduceOp op, TensorRTConversionPatternRewriter &rewriter, + Value operand, ArrayRef reductionDims, int64_t trtMajorVersion) { + if (!matchPattern(op, + matchers::detail::StablehloArgMinMaxReduceMatcher())) + return failure(); + auto argMinOrMaxOp = rewriter.checkAndCreate( + op.getLoc(), + /*input=*/operand, /*axis=*/reductionDims.front()); + // Rank reduce the results. + if (!argMinOrMaxOp) + return failure(); + SmallVector replacements = createRankReducedResults( + rewriter, op.getLoc(), argMinOrMaxOp.getResults(), reductionDims.front(), + trtMajorVersion); + rewriter.replaceOp(op, replacements); + return success(); +} + +/// Given a stablehlo reduction operation, convert to a `tensorrt.reduce` +/// operation if it is a simple reduction (e.g. sum, mul, max/min) that be +/// converted 1-1. Caller must do the replacement, this just creates the new +/// operation and returns the new value. +static FailureOr +convertSimpleReductions(TensorRTConversionPatternRewriter &rewriter, + stablehlo::ReduceOp op, ArrayRef reductionDim, + Value input, Value init, int64_t trtMajorVersion) { + // TODO: verify the init is the neutral value based on the op below. + if (!matchPattern(init, m_Constant())) + return failure(); + + Block *reduceBody = &op.getBody().front(); + auto termOp = cast(reduceBody->getTerminator()); + if (termOp->getNumOperands() != 1 || reduceBody->getNumArguments() != 2) + return failure(); + + Location loc = op.getLoc(); + Value retValue = termOp.getOperands()[0]; + auto bbLhs = matchers::m_Val(reduceBody->getArgument(0)); + auto bbRhs = matchers::m_Val(reduceBody->getArgument(1)); + + tensorrt::ReduceOperation reductionOp; + if (matchPattern(retValue, m_Op(bbLhs, bbRhs))) + reductionOp = tensorrt::ReduceOperation::kSUM; + else if (matchPattern(retValue, m_Op(bbLhs, bbRhs))) + reductionOp = tensorrt::ReduceOperation::kPROD; + else if (matchPattern(retValue, m_Op(bbLhs, bbRhs))) + reductionOp = tensorrt::ReduceOperation::kMIN; + else if (matchPattern(retValue, m_Op(bbLhs, bbRhs))) + reductionOp = tensorrt::ReduceOperation::kMAX; + else + return failure(); + + auto reduceOp = rewriter.checkAndCreate( + loc, op.getType(0), input, + /*reduceDims=*/ + reductionDim, + /*keepdims=*/false, reductionOp); + if (!reduceOp) + return failure(); + return reduceOp.getResult(); +} + +static FailureOr convertBooleanReductions(RewriterBase &rewriter, + stablehlo::ReduceOp op, + ArrayRef reductionDim, + Value input, Value init) { + Location loc = op.getLoc(); + // Create an int32 tensor types equivalent to the boolean tensor types. + auto originalInputType = cast(input.getType()); + auto originalResultType = cast(op->getResultTypes()[0]); + if (!originalResultType.getElementType().isInteger(1) || + !originalInputType.getElementType().isInteger(1)) + return failure(); + + RankedTensorType integerInputType = + RankedTensorType::Builder(originalInputType) + .setElementType(rewriter.getI32Type()); + RankedTensorType integerResultType = + RankedTensorType::Builder(originalResultType) + .setElementType(rewriter.getI32Type()); + + // Create the new reduction type. + Block *reduceBody = &op.getBody().front(); + auto termOp = cast(reduceBody->getTerminator()); + if (termOp->getNumOperands() != 1 || reduceBody->getNumArguments() != 2) + return failure(); + Value retValue = termOp.getOperands()[0]; + auto bbLhs = matchers::m_Val(reduceBody->getArgument(0)); + auto bbRhs = matchers::m_Val(reduceBody->getArgument(1)); + tensorrt::ReduceOperation reductionOpType; + if (matchPattern(retValue, m_Op(bbLhs, bbRhs))) + reductionOpType = tensorrt::ReduceOperation::kSUM; + else if (matchPattern(retValue, m_Op(bbLhs, bbRhs))) + reductionOpType = tensorrt::ReduceOperation::kPROD; + else + return failure(); + + // Cast i1 to i32. + Value i32Input = + rewriter.create(loc, integerInputType, input); + + auto reduceOp = rewriter.create( + loc, integerResultType, i32Input, + /*reduceDims=*/SmallVector{reductionDim}, + /*keepdims=*/false, reductionOpType); + // Cast i32 to i1. + return rewriter + .create(loc, originalResultType, + reduceOp.getResult()) + .getResult(); +} + +namespace { + +// Converts a `stablehlo.reduce` operation to a `tensorrt.reduce` operation. +struct ConvertReduceOp + : public ConvertHloOpToTensorRTPattern { + using ConvertHloOpToTensorRTPattern::ConvertHloOpToTensorRTPattern; + + LogicalResult + matchAndRewrite(stablehlo::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + TensorRTConversionPatternRewriter trtRewriter(rewriter, + targetTrtMajorVersion); + + Value operand = adaptor.getInputs().front(); + SmallVector reductionDims = llvm::to_vector(op.getDimensions()); + // Try to match and handle the ArgMin/ArgMax cases. + if (succeeded(matchAndReplaceStablehloArgMinMax< + tensorrt::ArgMaxOp, stablehlo::ComparisonDirection::GE>( + op, trtRewriter, operand, reductionDims, targetTrtMajorVersion))) + return success(); + if (succeeded(matchAndReplaceStablehloArgMinMax< + tensorrt::ArgMinOp, stablehlo::ComparisonDirection::LE>( + op, trtRewriter, operand, reductionDims, targetTrtMajorVersion))) + return success(); + + // Try to match the simpler reductions across a single input. + if (op.getInputs().size() != 1) + return rewriter.notifyMatchFailure(op, + "number of reduction inputs not 1"); + Value init = adaptor.getInitValues().front(); + + FailureOr replacement = + convertBooleanReductions(rewriter, op, reductionDims, operand, init); + if (succeeded(replacement)) { + trtRewriter.replaceOp(op, *replacement); + return success(); + } + + replacement = convertSimpleReductions(trtRewriter, op, reductionDims, + operand, init, targetTrtMajorVersion); + if (failed(replacement)) + return rewriter.notifyMatchFailure( + op, "could not do simple reduction transform"); + trtRewriter.replaceOp(op, *replacement); + return success(); + } +}; + +/// Convert `stablehlo.dot` to `tensorrt.matrix_multiply`. +/// TODO: clean since `dot` op is removed from stable hlo in the favor of +/// `dot_general`. +struct ConvertDot : public ConvertHloOpToTensorRTPattern { + using ConvertHloOpToTensorRTPattern< + stablehlo::DotOp>::ConvertHloOpToTensorRTPattern; + LogicalResult + matchAndRewrite(stablehlo::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + TensorRTConversionPatternRewriter trtRewriter(rewriter, + targetTrtMajorVersion); + + TensorType resultType = op.getType(); + tensorrt::MatrixOperation qualifierLhs = tensorrt::MatrixOperation::kNONE; + tensorrt::MatrixOperation qualifierRhs = tensorrt::MatrixOperation::kNONE; + auto lhsType = cast(adaptor.getLhs().getType()); + auto rhsType = cast(adaptor.getRhs().getType()); + if (lhsType.getRank() == 1) + qualifierLhs = tensorrt::MatrixOperation::kVECTOR; + if (rhsType.getRank() == 1) + qualifierRhs = tensorrt::MatrixOperation::kVECTOR; + + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + auto replacement = trtRewriter.checkAndCreate( + op->getLoc(), resultType, lhs, rhs, qualifierLhs, qualifierRhs); + if (!replacement) + return failure(); + + return replaceWithCast(trtRewriter, op, replacement.getResult()); + } +}; + +struct EinsumHelper { + + EinsumHelper(stablehlo::DotGeneralOp op) + : dimNums(op.getDotDimensionNumbers()), op(op) {} + + FailureOr getEquation() { + const int64_t lhsRank = op.getLhs().getType().getRank(); + const int64_t rhsRank = op.getRhs().getType().getRank(); + FailureOr batchLetters = getBatchDimLetters(); + FailureOr contractionLetters = getContractionDimLetters(); + FailureOr lhsResultDimLetters = + getResultDimLetters(dimNums.getLhsBatchingDimensions(), + dimNums.getLhsContractingDimensions(), lhsRank); + FailureOr rhsResultDimLetters = + getResultDimLetters(dimNums.getRhsBatchingDimensions(), + dimNums.getRhsContractingDimensions(), rhsRank); + if (failed(batchLetters) || failed(contractionLetters) || + failed(lhsResultDimLetters) || failed(rhsResultDimLetters)) + return failure(); + + std::string equation; + emitOperandTerms(lhsRank, *batchLetters, *contractionLetters, + *lhsResultDimLetters, dimNums.getLhsBatchingDimensions(), + dimNums.getLhsContractingDimensions(), equation); + equation += ","; + emitOperandTerms(rhsRank, *batchLetters, *contractionLetters, + *rhsResultDimLetters, dimNums.getRhsBatchingDimensions(), + dimNums.getRhsContractingDimensions(), equation); + equation += "->"; + equation += *batchLetters + *lhsResultDimLetters + *rhsResultDimLetters; + return equation; + } + +private: + stablehlo::DotDimensionNumbersAttr dimNums; + stablehlo::DotGeneralOp op; + + static constexpr StringRef kTermPool = "abcdefghijklmnopqrstuvwxyz"; + + LogicalResult appendTerm(std::string &result) { + if (termPos >= kTermPool.size()) + return failure(); + result += kTermPool[termPos++]; + return success(); + } + + FailureOr getBatchDimLetters() { + std::string result = ""; + ArrayRef lhsBatchDims = dimNums.getLhsBatchingDimensions(); + for (int64_t _ : lhsBatchDims) { + if (failed(appendTerm(result))) + return failure(); + } + return result; + } + + FailureOr getContractionDimLetters() { + std::string result = ""; + ArrayRef rhsContractingDims = + dimNums.getRhsContractingDimensions(); + for (int64_t _ : rhsContractingDims) { + if (failed(appendTerm(result))) + return failure(); + } + return result; + } + + FailureOr getResultDimLetters(ArrayRef batchDims, + ArrayRef contractionDims, + int64_t rank) { + std::string result; + for (int64_t dim : llvm::seq(0, rank)) { + if (llvm::is_contained(batchDims, dim) || + llvm::is_contained(contractionDims, dim)) + continue; + if (failed(appendTerm(result))) + return failure(); + } + return result; + } + + void emitOperandTerms(int64_t rank, StringRef batchDimTerms, + StringRef contractionDimTerms, StringRef resultDimTerms, + ArrayRef batchDims, + ArrayRef contractionDims, + std::string &result) { + for (int64_t idx : llvm::seq(0, rank)) { + if (llvm::is_contained(batchDims, idx)) { + result += batchDimTerms.front(); + batchDimTerms = batchDimTerms.drop_front(); + continue; + } + if (llvm::is_contained(contractionDims, idx)) { + result += contractionDimTerms.front(); + contractionDimTerms = contractionDimTerms.drop_front(); + continue; + } + result += resultDimTerms.front(); + resultDimTerms = resultDimTerms.drop_front(); + } + assert(batchDimTerms.empty() && "expected all batch dim terms to be used"); + assert(contractionDimTerms.empty() && + "expected all contraction dim terms to be used"); + assert(resultDimTerms.empty() && + "expected all result dim terms to be used"); + } + + std::string batchDimLetters; + unsigned termPos = 0; +}; + +/// Convert `stablehlo.dot_general` to `tensorrt.einsum`. +struct ConvertDotGeneralToEinsum + : public ConvertHloOpToTensorRTPattern { + using ConvertHloOpToTensorRTPattern< + stablehlo::DotGeneralOp>::ConvertHloOpToTensorRTPattern; + LogicalResult + matchAndRewrite(stablehlo::DotGeneralOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + TensorRTConversionPatternRewriter trtRewriter(rewriter, + targetTrtMajorVersion); + + TensorType resultType = op.getType(); + // Determine the TRT equivalent qualifier. + auto lhs = cast(adaptor.getLhs()); + auto rhs = cast(adaptor.getRhs()); + TensorType lhsType = lhs.getType(); + + if (lhsType.getElementType().isInteger(32)) + return failure(); + + // 'stablehlo.dot_general' allows for promotion of the result element + // type. We treat this as equivalent to compute/accumulator element type + // being equal to the result type. In TensorRT, we have limited control + // over the accumulator element type, but you're supposed to be able to + // specify it using cast operaitons on the operands. + Type computeElementType = resultType.getElementType(); + if (computeElementType != lhsType.getElementType()) { + FailureOr castedLhs = + this->castTensor(trtRewriter, computeElementType, lhs); + FailureOr castedRhs = + this->castTensor(trtRewriter, computeElementType, rhs); + if (failed(castedLhs) || failed(castedRhs)) + return failure(); + lhs = std::move(*castedLhs); + rhs = std::move(*castedRhs); + } + + EinsumHelper helper(op); + FailureOr equation = helper.getEquation(); + if (failed(equation)) + return failure(); + + tensorrt::EinsumOp replacement = + trtRewriter.checkAndCreate( + op.getLoc(), resultType, ValueRange{lhs, rhs}, + trtRewriter.getStringAttr(*equation)); + if (!replacement) + return failure(); + return replaceWithCast(trtRewriter, op, replacement.getResult()); + } +}; + +/// Convert `stablehlo.dot_general` to `tensorrt.matrix_multiply`. +struct ConvertDotGeneralToMatrixMultiply + : public ConvertHloOpToTensorRTPattern { + using ConvertHloOpToTensorRTPattern< + stablehlo::DotGeneralOp>::ConvertHloOpToTensorRTPattern; + LogicalResult + matchAndRewrite(stablehlo::DotGeneralOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + TensorRTConversionPatternRewriter trtRewriter(rewriter, + targetTrtMajorVersion); + + stablehlo::DotDimensionNumbersAttr dimNums = op.getDotDimensionNumbers(); + ArrayRef lhsBatchDims = dimNums.getLhsBatchingDimensions(); + const int64_t numBatchDims = lhsBatchDims.size(); + const int64_t numContractionDims = + dimNums.getRhsContractingDimensions().size(); + + TensorType resultType = op.getType(); + // Determine the TRT equivalent qualifier. + tensorrt::MatrixOperation qualifierLhs = tensorrt::MatrixOperation::kNONE; + tensorrt::MatrixOperation qualifierRhs = tensorrt::MatrixOperation::kNONE; + auto lhs = cast(adaptor.getLhs()); + auto rhs = cast(adaptor.getRhs()); + TensorType lhsType = lhs.getType(); + TensorType rhsType = rhs.getType(); + + if (lhsType.getElementType().isInteger(32)) + return failure(); + + // 'stablehlo.dot_general' allows for promotion of the result element type. + // We treat this as equivalent to compute/accumulator element type being + // equal to the result type. In TensorRT, we have limited control over the + // accumulator element type, but you're supposed to be able to specify it + // using cast operaitons on the operands. + Type computeElementType = resultType.getElementType(); + if (computeElementType != lhsType.getElementType()) { + FailureOr castedLhs = + this->castTensor(trtRewriter, computeElementType, lhs); + FailureOr castedRhs = + this->castTensor(trtRewriter, computeElementType, rhs); + if (failed(castedLhs) || failed(castedRhs)) + return failure(); + lhs = std::move(*castedLhs); + rhs = std::move(*castedRhs); + } + + // We don't handle multiple contraction dims. + if (numContractionDims != 1) + return failure(); + + // We don't handle multiple outer product dimensions. + if (rhsType.getRank() > numBatchDims + numContractionDims + 1 || + lhsType.getRank() > numBatchDims + numContractionDims + 1) + return failure(); + + if (lhsType.getRank() == numBatchDims + numContractionDims + 1) { + if (dimNums.getLhsContractingDimensions().front() == + lhsType.getRank() - 1) + qualifierLhs = tensorrt::MatrixOperation::kNONE; + else if (dimNums.getLhsContractingDimensions().front() == + lhsType.getRank() - 2) + qualifierLhs = tensorrt::MatrixOperation::kTRANSPOSE; + else + return failure(); + // No explicit outer product dimension + } else if (lhsType.getRank() == numBatchDims + numContractionDims) { + qualifierLhs = tensorrt::MatrixOperation::kVECTOR; + } else { + return failure(); + } + + if (rhsType.getRank() == numBatchDims + numContractionDims + 1) { + if (dimNums.getRhsContractingDimensions().front() == + rhsType.getRank() - 1) + qualifierRhs = tensorrt::MatrixOperation::kTRANSPOSE; + else if (dimNums.getRhsContractingDimensions().front() == + rhsType.getRank() - 2) + qualifierRhs = tensorrt::MatrixOperation::kNONE; + else + return failure(); + } else if (rhsType.getRank() == numBatchDims + numContractionDims) { + qualifierRhs = tensorrt::MatrixOperation::kVECTOR; + } else { + return failure(); + } + auto replacement = trtRewriter.checkAndCreate( + op->getLoc(), resultType, lhs, rhs, qualifierLhs, qualifierRhs); + if (!replacement) + return failure(); + return replaceWithCast(trtRewriter, op, replacement.getResult()); + } +}; + +} // namespace + +void mlir::populateStablehloReductionAndContractionToTensorRtConversionPattern( + TensorRTTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit, PatternBenefit dotToEinsumBenefit) { + // clang-format off + patterns.add< + ConvertDot, + ConvertDotGeneralToMatrixMultiply, + ConvertReduceOp + >(typeConverter, patterns.getContext(), benefit); + patterns.add< + ConvertDotGeneralToEinsum + >(typeConverter, patterns.getContext(), dotToEinsumBenefit); + // clang-format on +} diff --git a/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/StablehloToTensorRT.cpp b/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/StablehloToTensorRT.cpp index 5ece68149..24ac1d2bd 100644 --- a/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/StablehloToTensorRT.cpp +++ b/mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/StablehloToTensorRT.cpp @@ -352,315 +352,6 @@ struct SortToTopK : public ConvertHloOpToTensorRTPattern { }; } // namespace -/// Given a stablehlo reduction operation, convert to a `tensorrt.reduce` -/// operation if it is a simple reduction (e.g. sum, mul, max/min) that be -/// converted 1-1. Caller must do the replacement, this just creates the new -/// operation and returns the new value. -static FailureOr -convertSimpleReductions(TensorRTConversionPatternRewriter &rewriter, - stablehlo::ReduceOp op, ArrayRef reductionDim, - Value input, Value init, int64_t trtMajorVersion) { - // TODO: verify the init is the neutral value based on the op below. - if (!matchPattern(init, m_Constant())) - return failure(); - - Block *reduceBody = &op.getBody().front(); - auto termOp = cast(reduceBody->getTerminator()); - if (termOp->getNumOperands() != 1 || reduceBody->getNumArguments() != 2) - return failure(); - - Location loc = op.getLoc(); - Value retValue = termOp.getOperands()[0]; - auto bbLhs = matchers::m_Val(reduceBody->getArgument(0)); - auto bbRhs = matchers::m_Val(reduceBody->getArgument(1)); - - tensorrt::ReduceOperation reductionOp; - if (matchPattern(retValue, m_Op(bbLhs, bbRhs))) - reductionOp = tensorrt::ReduceOperation::kSUM; - else if (matchPattern(retValue, m_Op(bbLhs, bbRhs))) - reductionOp = tensorrt::ReduceOperation::kPROD; - else if (matchPattern(retValue, m_Op(bbLhs, bbRhs))) - reductionOp = tensorrt::ReduceOperation::kMIN; - else if (matchPattern(retValue, m_Op(bbLhs, bbRhs))) - reductionOp = tensorrt::ReduceOperation::kMAX; - else - return failure(); - - auto reduceOp = rewriter.checkAndCreate( - loc, op.getType(0), input, - /*reduceDims=*/ - reductionDim, - /*keepdims=*/false, reductionOp); - if (!reduceOp) - return failure(); - return reduceOp.getResult(); -} - -static FailureOr convertBooleanReductions(RewriterBase &rewriter, - stablehlo::ReduceOp op, - ArrayRef reductionDim, - Value input, Value init) { - Location loc = op.getLoc(); - // Create an int32 tensor types equivalent to the boolean tensor types. - auto originalInputType = cast(input.getType()); - auto originalResultType = cast(op->getResultTypes()[0]); - if (!originalResultType.getElementType().isInteger(1) || - !originalInputType.getElementType().isInteger(1)) - return failure(); - - RankedTensorType integerInputType = - RankedTensorType::Builder(originalInputType) - .setElementType(rewriter.getI32Type()); - RankedTensorType integerResultType = - RankedTensorType::Builder(originalResultType) - .setElementType(rewriter.getI32Type()); - - // Create the new reduction type. - Block *reduceBody = &op.getBody().front(); - auto termOp = cast(reduceBody->getTerminator()); - if (termOp->getNumOperands() != 1 || reduceBody->getNumArguments() != 2) - return failure(); - Value retValue = termOp.getOperands()[0]; - auto bbLhs = matchers::m_Val(reduceBody->getArgument(0)); - auto bbRhs = matchers::m_Val(reduceBody->getArgument(1)); - tensorrt::ReduceOperation reductionOpType; - if (matchPattern(retValue, m_Op(bbLhs, bbRhs))) - reductionOpType = tensorrt::ReduceOperation::kSUM; - else if (matchPattern(retValue, m_Op(bbLhs, bbRhs))) - reductionOpType = tensorrt::ReduceOperation::kPROD; - else - return failure(); - - // Cast i1 to i32. - Value i32Input = - rewriter.create(loc, integerInputType, input); - - auto reduceOp = rewriter.create( - loc, integerResultType, i32Input, - /*reduceDims=*/SmallVector{reductionDim}, - /*keepdims=*/false, reductionOpType); - // Cast i32 to i1. - return rewriter - .create(loc, originalResultType, - reduceOp.getResult()) - .getResult(); -} - -/// Drop the unit dimension at `dimToDrop` from each of `values`. -static SmallVector -createRankReducedResults(TensorRTConversionPatternRewriter &rewriter, - Location loc, ResultRange values, int64_t dimToDrop, - int64_t trtMajorVersion) { - assert(!values.empty()); - SmallVector result; - result.reserve(values.size()); - for (Value v : values) { - auto inputType = dyn_cast(v.getType()); - assert((!inputType || inputType.getDimSize(dimToDrop) == 1) && - "expected value to have unit dim to drop"); - auto rtt = RankedTensorType::Builder(inputType); - rtt.dropDim(dimToDrop); - Value collapsed = - rewriter.checkAndCreate(loc, Type(rtt), v); - result.push_back(collapsed); - } - return result; -} - -template -static LogicalResult matchAndReplaceStablehloArgMinMax( - stablehlo::ReduceOp op, TensorRTConversionPatternRewriter &rewriter, - Value operand, ArrayRef reductionDims, int64_t trtMajorVersion) { - if (!matchPattern(op, - matchers::detail::StablehloArgMinMaxReduceMatcher())) - return failure(); - auto argMinOrMaxOp = rewriter.checkAndCreate( - op.getLoc(), - /*input=*/operand, /*axis=*/reductionDims.front()); - // Rank reduce the results. - if (!argMinOrMaxOp) - return failure(); - SmallVector replacements = createRankReducedResults( - rewriter, op.getLoc(), argMinOrMaxOp.getResults(), reductionDims.front(), - trtMajorVersion); - rewriter.replaceOp(op, replacements); - return success(); -} - -namespace { -// Converts a `stablehlo.reduce` operation to a `tensorrt.reduce` operation. -struct ConvertReduceOp - : public ConvertHloOpToTensorRTPattern { - using ConvertHloOpToTensorRTPattern::ConvertHloOpToTensorRTPattern; - - LogicalResult - matchAndRewrite(stablehlo::ReduceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - TensorRTConversionPatternRewriter trtRewriter(rewriter, - targetTrtMajorVersion); - - Value operand = adaptor.getInputs().front(); - SmallVector reductionDims = llvm::to_vector(op.getDimensions()); - // Try to match and handle the ArgMin/ArgMax cases. - if (succeeded(matchAndReplaceStablehloArgMinMax< - tensorrt::ArgMaxOp, stablehlo::ComparisonDirection::GE>( - op, trtRewriter, operand, reductionDims, targetTrtMajorVersion))) - return success(); - if (succeeded(matchAndReplaceStablehloArgMinMax< - tensorrt::ArgMinOp, stablehlo::ComparisonDirection::LE>( - op, trtRewriter, operand, reductionDims, targetTrtMajorVersion))) - return success(); - - // Try to match the simpler reductions across a single input. - if (op.getInputs().size() != 1) - return rewriter.notifyMatchFailure(op, - "number of reduction inputs not 1"); - Value init = adaptor.getInitValues().front(); - - FailureOr replacement = - convertBooleanReductions(rewriter, op, reductionDims, operand, init); - if (succeeded(replacement)) { - trtRewriter.replaceOp(op, *replacement); - return success(); - } - - replacement = convertSimpleReductions(trtRewriter, op, reductionDims, - operand, init, targetTrtMajorVersion); - if (failed(replacement)) - return rewriter.notifyMatchFailure( - op, "could not do simple reduction transform"); - trtRewriter.replaceOp(op, *replacement); - return success(); - } -}; - -/// Convert `stablehlo.dot` to `tensorrt.matrix_multiply`. -/// TODO: clean since `dot` op is removed from stable hlo in the favor of -/// `dot_general`. -struct ConvertDot : public ConvertHloOpToTensorRTPattern { - using ConvertHloOpToTensorRTPattern< - stablehlo::DotOp>::ConvertHloOpToTensorRTPattern; - LogicalResult - matchAndRewrite(stablehlo::DotOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - TensorRTConversionPatternRewriter trtRewriter(rewriter, - targetTrtMajorVersion); - - TensorType resultType = op.getType(); - tensorrt::MatrixOperation qualifierLhs = tensorrt::MatrixOperation::kNONE; - tensorrt::MatrixOperation qualifierRhs = tensorrt::MatrixOperation::kNONE; - auto lhsType = cast(adaptor.getLhs().getType()); - auto rhsType = cast(adaptor.getRhs().getType()); - if (lhsType.getRank() == 1) - qualifierLhs = tensorrt::MatrixOperation::kVECTOR; - if (rhsType.getRank() == 1) - qualifierRhs = tensorrt::MatrixOperation::kVECTOR; - - Value lhs = adaptor.getLhs(); - Value rhs = adaptor.getRhs(); - auto replacement = trtRewriter.checkAndCreate( - op->getLoc(), resultType, lhs, rhs, qualifierLhs, qualifierRhs); - if (!replacement) - return failure(); - - return replaceWithCast(trtRewriter, op, replacement.getResult()); - } -}; - -/// Convert `stablehlo.dot_general` to `tensorrt.matrix_multiply`. -struct ConvertDotGeneral - : public ConvertHloOpToTensorRTPattern { - using ConvertHloOpToTensorRTPattern< - stablehlo::DotGeneralOp>::ConvertHloOpToTensorRTPattern; - LogicalResult - matchAndRewrite(stablehlo::DotGeneralOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - TensorRTConversionPatternRewriter trtRewriter(rewriter, - targetTrtMajorVersion); - - stablehlo::DotDimensionNumbersAttr dimNums = op.getDotDimensionNumbers(); - ArrayRef lhsBatchDims = dimNums.getLhsBatchingDimensions(); - const int64_t numBatchDims = lhsBatchDims.size(); - const int64_t numContractionDims = - dimNums.getRhsContractingDimensions().size(); - - TensorType resultType = op.getType(); - // Determine the TRT equivalent qualifier. - tensorrt::MatrixOperation qualifierLhs = tensorrt::MatrixOperation::kNONE; - tensorrt::MatrixOperation qualifierRhs = tensorrt::MatrixOperation::kNONE; - auto lhs = cast(adaptor.getLhs()); - auto rhs = cast(adaptor.getRhs()); - TensorType lhsType = lhs.getType(); - TensorType rhsType = rhs.getType(); - - if (lhsType.getElementType().isInteger(32)) - return failure(); - - // 'stablehlo.dot_general' allows for promotion of the result element type. - // We treat this as equivalent to compute/accumulator element type being - // equal to the result type. In TensorRT, we have limited control over the - // accumulator element type, but you're supposed to be able to specify it - // using cast operaitons on the operands. - Type computeElementType = resultType.getElementType(); - if (computeElementType != lhsType.getElementType()) { - FailureOr castedLhs = - this->castTensor(trtRewriter, computeElementType, lhs); - FailureOr castedRhs = - this->castTensor(trtRewriter, computeElementType, rhs); - if (failed(castedLhs) || failed(castedRhs)) - return failure(); - lhs = std::move(*castedLhs); - rhs = std::move(*castedRhs); - } - - // We don't handle multiple contraction dims. - if (numContractionDims != 1) - return failure(); - - // We don't handle multiple outer product dimensions. - if (rhsType.getRank() > numBatchDims + numContractionDims + 1 || - lhsType.getRank() > numBatchDims + numContractionDims + 1) - return failure(); - - if (lhsType.getRank() == numBatchDims + numContractionDims + 1) { - if (dimNums.getLhsContractingDimensions().front() == - lhsType.getRank() - 1) - qualifierLhs = tensorrt::MatrixOperation::kNONE; - else if (dimNums.getLhsContractingDimensions().front() == - lhsType.getRank() - 2) - qualifierLhs = tensorrt::MatrixOperation::kTRANSPOSE; - else - return failure(); - // No explicit outer product dimension - } else if (lhsType.getRank() == numBatchDims + numContractionDims) { - qualifierLhs = tensorrt::MatrixOperation::kVECTOR; - } else { - return failure(); - } - - if (rhsType.getRank() == numBatchDims + numContractionDims + 1) { - if (dimNums.getRhsContractingDimensions().front() == - rhsType.getRank() - 1) - qualifierRhs = tensorrt::MatrixOperation::kTRANSPOSE; - else if (dimNums.getRhsContractingDimensions().front() == - rhsType.getRank() - 2) - qualifierRhs = tensorrt::MatrixOperation::kNONE; - else - return failure(); - } else if (rhsType.getRank() == numBatchDims + numContractionDims) { - qualifierRhs = tensorrt::MatrixOperation::kVECTOR; - } else { - return failure(); - } - auto replacement = trtRewriter.checkAndCreate( - op->getLoc(), resultType, lhs, rhs, qualifierLhs, qualifierRhs); - if (!replacement) - return failure(); - return replaceWithCast(trtRewriter, op, replacement.getResult()); - } -}; -} // namespace - /// Given an expression try to find a single-character string from `termPool` /// that is not used in `expression`. Returns the index of the unused character. static FailureOr getUnusedTerm(StringRef expression, @@ -4372,7 +4063,9 @@ class ConvertStablehloToTensorRtPass return op.isPrivate() && op->hasAttr("plan.decomposition"); }); - populateStablehloToTensorRtConversionPattern(typeConverter, patterns); + populateStablehloToTensorRtConversionPattern( + typeConverter, patterns, {}, /*preferEinsum=*/preferEinsum); + populateStablehloControlFlowToTensorRtPatterns( typeConverter, patterns, convertLoops, convertConditionals); populateChloToTensorRtLegalityAndPatterns(typeConverter, target, @@ -4391,14 +4084,14 @@ class ConvertStablehloToTensorRtPass void mlir::populateStablehloToTensorRtConversionPattern( TensorRTTypeConverter &typeConverter, RewritePatternSet &patterns, - ShapeInfoCallbacks shapeInfoCallbacks) { + ShapeInfoCallbacks shapeInfoCallbacks, bool preferEinsum) { // Add larger patterns with a higher // benefit so that they run first. patterns.add( typeConverter, patterns.getContext(), PatternBenefit(100)); patterns.add< // Contraction Operations - ConvertDot, ConvertDotGeneral, ConvertEinsum, + ConvertEinsum, // Shape related operations ReshapeConverter, ConvertBroadcastInDim, ConvertDynamicBroadcastInDim, ConvertBroadcast, ConvertIota, ConvertDynamicIota, @@ -4445,9 +4138,9 @@ void mlir::populateStablehloToTensorRtConversionPattern( MAKE_UNARY_OP_CONVERTER(RoundNearestEvenOp, kROUND), #undef MAKE_UNARY_OP_CONVERTER - ConvertRemainder, ConvertReduceOp, TorchIndexSelectConverter, - ReverseConverter, PadConverter, DynamicPadConverter, CompareConverter, - ClampConverter, GetDimensionSizeConverter, UniformQuantizeConverter, + ConvertRemainder, TorchIndexSelectConverter, ReverseConverter, + PadConverter, DynamicPadConverter, CompareConverter, ClampConverter, + GetDimensionSizeConverter, UniformQuantizeConverter, UniformDequantizeConverter, HloUnaryOpToActivationConverter, @@ -4466,6 +4159,10 @@ void mlir::populateStablehloToTensorRtConversionPattern( >(typeConverter, patterns.getContext(), PatternBenefit(1)); // clang-format on + populateStablehloReductionAndContractionToTensorRtConversionPattern( + typeConverter, patterns, PatternBenefit(1), + /*dotToEinsumBenefit=*/PatternBenefit(preferEinsum ? 2 : 0)); + if (!shapeInfoCallbacks.isElementValueEqualToConstant) shapeInfoCallbacks.isElementValueEqualToConstant = [](TensorElementValue elementValue, diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/AllocTensors.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/AllocTensors.cpp index 96098513c..ec1062e2a 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/AllocTensors.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/AllocTensors.cpp @@ -208,8 +208,11 @@ struct RewriteFromElements : public OpRewritePattern { std::optional originalMemorySpace{}; if (auto constraint = - dyn_cast_or_null(op.getType().getEncoding())) + dyn_cast_or_null(op.getType().getEncoding())) { + if (constraint.isHostVisible()) + return failure(); originalMemorySpace = constraint.getValue(); + } // Create a host allocation and insert the elements. MemorySpace memorySpace = MemorySpace::host_pinned; @@ -318,9 +321,14 @@ struct TensorDeviceExtractRewriter if (lattice->getValue().isHostOnly()) return rewriter.notifyMatchFailure(op, "lattice value is host-only"); - Value source = op.getTensor(); + if (auto constraint = dyn_cast_if_present( + op.getTensor().getType().getEncoding())) { + if (constraint.isHostVisible()) + return rewriter.notifyMatchFailure( + op, "source tensor already is in a host-visible space"); + } - if (failed(replaceHostUsesWithHostAlloc(rewriter, source))) + if (failed(replaceHostUsesWithHostAlloc(rewriter, op.getTensor()))) return failure(); return success(); @@ -846,12 +854,24 @@ static LogicalResult rewriteFuncToDestinationPassingStyle( // Our action now depends on what kind of equivalent value we found. Operation *equivalentOp = equivalentValues.front().getDefiningOp(); + // Check if the user of `toReplace` is a `tensor.reshape` operation. + // Since `toReplace` is bufferizes to the equivalent of `tensor.reshape`, + // we can just try to replace the reshape instead. + if (equivalentOp->hasOneUse()) { + if (auto reshapeOp = + dyn_cast(*equivalentOp->user_begin())) { + if (reshapeOp.getSource() == equivalentOp->getResult(0)) { + equivalentOp = reshapeOp; + } + } + } + // A reshape or cast may be required if the equivalent value has a different // type than the new function argument. rewriter.setInsertionPointAfter(equivalentOp); FailureOr> reshaped = maybeReshapeOrCast( - rewriter, equivalentValues.front().getLoc(), replacement, - cast>(equivalentValues.front())); + rewriter, equivalentOp->getLoc(), replacement, + cast>(equivalentOp->getResult(0))); if (failed(reshaped)) return failure(); replacement = *reshaped; @@ -867,7 +887,6 @@ static LogicalResult rewriteFuncToDestinationPassingStyle( rewriter.replaceAllOpUsesWith(equivalentOp, replacement); continue; } - rewriter.setInsertionPoint(allocOp); rewriter.replaceOpWithNewOp( allocOp, allocOp.getCopy(), replacement); continue; @@ -878,6 +897,10 @@ static LogicalResult rewriteFuncToDestinationPassingStyle( rewriter.replaceAllUsesExcept(constOp, matOp.getResult(), matOp); continue; } + if (auto reshapeOp = dyn_cast(equivalentOp)) { + rewriter.replaceAllOpUsesWith(equivalentOp, replacement); + continue; + } llvm_unreachable("unexpected leaf operation kind"); } return success(); diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/AssignMemorySpaces.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/AssignMemorySpaces.cpp index b20e39c77..a5c72a39e 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/AssignMemorySpaces.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/AssignMemorySpaces.cpp @@ -23,6 +23,7 @@ //===----------------------------------------------------------------------===// #include "mlir-tensorrt-dialect/Analysis/TensorKindAnalysis.h" #include "mlir-tensorrt/Dialect/Plan/IR/Plan.h" +#include "mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.h" #include "mlir-tensorrt/Dialect/Plan/Transforms/Passes.h" #include "mlir-tensorrt/Utils/ModuleUtils.h" #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" @@ -36,6 +37,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/DialectResourceBlobManager.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -61,6 +63,9 @@ class GenericConvertSpace : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { + if (isa(op)) + return failure(); + SmallVector resultTypes; if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes))) return failure(); @@ -81,6 +86,67 @@ class GenericConvertSpace : public ConversionPattern { } }; +/// Apply special conversion logic for `bufferization.alloc_tensor` operations. +/// It has a `memory_space` attribute that acts as a constraint. +/// memory space of the allocated tensor. +class ConvertAllocTensorPattern + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto originalType = dyn_cast(op.getType()); + if (!originalType) + return failure(); + + auto resultConstraint = + dyn_cast_or_null(originalType.getEncoding()); + auto opMemorySpaceConstraint = + dyn_cast_if_present(op.getMemorySpaceAttr()); + + auto expectedResultType = dyn_cast_if_present( + getTypeConverter()->convertType(op.getType())); + if (!expectedResultType) + return failure(); + + MemorySpaceAttr constraint = + opMemorySpaceConstraint ? opMemorySpaceConstraint : resultConstraint; + + RankedTensorType constraintedType = + constraint ? originalType.cloneWithEncoding(constraint) + : expectedResultType; + + if (adaptor.getCopy()) { + auto castToConstraint = rewriter.create( + op.getLoc(), constraintedType, adaptor.getCopy()); + auto castOp = rewriter.create( + op.getLoc(), expectedResultType, castToConstraint); + rewriter.replaceOp(op, castOp); + return success(); + } + + rewriter.modifyOpInPlace(op, [&]() { + op.getResult().setType(constraintedType); + op.setMemorySpaceAttr(constraintedType.getEncoding()); + op.getCopyMutable().clear(); + }); + rewriter.setInsertionPointAfter(op); + + auto newAllocOp = rewriter.create( + op.getLoc(), constraintedType, + /*dynamic_dimensions=*/adaptor.getDynamicSizes(), + /*copy=*/Value{}, + /*size_hint=*/Value{}, + /*memory_space=*/constraintedType.getEncoding()); + auto castOp = rewriter.create( + op.getLoc(), expectedResultType, newAllocOp.getResult()); + rewriter.replaceOp(op, castOp); + return success(); + } +}; + // A pattern that converts the type of the attribute used as an operand for // arith.constant class ConvertConstantPattern : public OpConversionPattern { @@ -112,189 +178,285 @@ class ConvertConstantPattern : public OpConversionPattern { }; } // namespace -/// Return true if the op is likely in a compute region, like the region of -/// `stablehlo.reduce` or `linalg.generic`. -static bool inComputeRegion(Operation *op) { - Operation *parent = op->getParentOp(); - while (parent) { - if (isa(parent)) - return false; - if (!isa(parent)) - return true; - parent = parent->getParentOp(); - } - return false; -} - namespace { -/// Use an explicit 'host_pinned' staging tensor to materialie the -/// 'from_elements' before creating explicitly moving it to the 'device' space. -/// Other optimization patterns below help avoid the host-device transfer when -/// possible. -struct FixUpFromElements : public OpRewritePattern { - FixUpFromElements(MLIRContext *ctx, const DataFlowSolver &solver, - PatternBenefit benefit = 1) - : OpRewritePattern(ctx, benefit), solver(solver) {} - - LogicalResult matchAndRewrite(tensor::FromElementsOp op, - PatternRewriter &rewriter) const override { - auto space = dyn_cast_or_null(op.getType().getEncoding()); - if (!space) - return failure(); - if (space.getValue() != plan::MemorySpace::device) - return failure(); - const TensorKindLattice *lattice = - solver.lookupState(op.getResult()); - if (!lattice || lattice->getValue().isUninitialized() || - !lattice->getValue().isHostVisible()) - return failure(); +/// A type converter that adds a MemorySpaceAttr the the encoding of tensor +/// types. TensorTypes are legal only if they have the required encoding. +class TensorEncodingConverter : public TypeConverter { +public: + TensorEncodingConverter(MLIRContext &context, plan::MemorySpace encoding) + : requiredMemorySpace{plan::MemorySpaceAttr::get(&context, encoding)} { + addConversion([&](Type type) -> std::optional { return type; }); + addConversion([&](RankedTensorType type) -> std::optional { + return type.cloneWithEncoding(requiredMemorySpace); + }); + addSourceMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) -> Value { + return builder.create(loc, resultType, inputs.front()); + }); + addTargetMaterialization([&](OpBuilder &builder, TypeRange resultTypes, + ValueRange inputs, + Location loc) -> SmallVector { + return { + builder + .create(loc, resultTypes.front(), inputs.front()) + .getResult()}; + }); + } - RankedTensorType originalType = op.getType(); - RankedTensorType newType = RankedTensorType::get( - originalType.getShape(), originalType.getElementType(), - MemorySpaceAttr::get(originalType.getContext(), - plan::MemorySpace::host_pinned)); - auto newOp = rewriter.create(op.getLoc(), newType, - op.getElements()); - Value deviceTensor = rewriter.create( - op.getLoc(), originalType.getShape(), originalType.getElementType(), - originalType.getEncoding()); - Value rematDevReplacement = - rewriter - .create( - op.getLoc(), originalType, newOp.getResult(), deviceTensor) - .getResult(); - rewriter.replaceOp(op, rematDevReplacement); - return success(); + /// Convert a function signature, accounting for constraints specified in the + /// arg/result attributes. + FunctionType convertFuncSignature(func::FuncOp func) const { + FunctionType funcType = func.getFunctionType(); + SmallVector newInputs, newResults; + for (unsigned i = 0, e = funcType.getNumInputs(); i != e; ++i) + newInputs.push_back(convertFuncSignatureElement(funcType.getInput(i), + func.getArgAttrDict(i))); + for (unsigned i = 0, e = funcType.getNumResults(); i != e; ++i) + newResults.push_back(convertFuncSignatureElement( + funcType.getResult(i), func.getResultAttrDict(i))); + return FunctionType::get(func.getContext(), newInputs, newResults); + } + +private: + /// Convert a single element (arg or result type) of a function signature. The + /// `dict` should contain the arg/result attributes or nullptr if not present. + Type convertFuncSignatureElement(Type type, DictionaryAttr dict) const { + if (auto rtt = dyn_cast(type)) { + if (auto constraint = + dict ? dyn_cast_if_present(dict.get( + PlanDialect::getMemorySpaceConstraintAttrName())) + : nullptr) + return rtt.cloneWithEncoding(constraint); + if (auto existing = + dyn_cast_if_present(rtt.getEncoding())) + return rtt; + } + return convertType(type); } - const DataFlowSolver &solver; + plan::MemorySpaceAttr requiredMemorySpace; }; +} // namespace -static bool isHostVisible(TypedValue v) { - auto space = dyn_cast_or_null(v.getType().getEncoding()); - if (!space) - return false; - switch (space.getValue()) { - case plan::MemorySpace::host: - case plan::MemorySpace::host_pinned: - case plan::MemorySpace::unified: - return true; - default: - return false; +/// Convert the block arguments for a single block where the RankedTensorTypes +/// may have received an updated encoding. +static void applySignatureConversion(RewriterBase &rewriter, Block *block, + const TensorEncodingConverter &converter, + TypeRange convertedTypes) { + OpBuilder::InsertionGuard g(rewriter); + assert(convertedTypes.size() == block->getNumArguments() && + "convertedTypes size mismatch"); + for (BlockArgument arg : block->getArguments()) { + Type origType = arg.getType(); + if (origType == convertedTypes[arg.getArgNumber()]) + continue; + auto castOp = rewriter.create( + arg.getLoc(), convertedTypes[arg.getArgNumber()], arg); + rewriter.replaceAllUsesExcept(arg, castOp, castOp); + arg.setType(convertedTypes[arg.getArgNumber()]); } } -/// For any 'shape' parameter of a 'tensor.reshape', get the shape by skipping -/// past any unnecessary explicit host-device transfers. -struct ReshapeAbsorbDeviceCast : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(tensor::ReshapeOp op, - PatternRewriter &rewriter) const override { - if (isHostVisible(op.getShape())) - return failure(); - auto matOp = - op.getShape() - .getDefiningOp(); - if (!matOp) - return failure(); - auto source = dyn_cast>(matOp.getSource()); - if (!source || !isHostVisible(source)) - return failure(); - rewriter.modifyOpInPlace(op, - [&]() { op.getShapeMutable().assign(source); }); +/// Convert the block arguments for all Blocks in a function body where the +/// RankedTensorTypes may have received an updated encoding. +static LogicalResult +convertFuncRegionTypes(RewriterBase &rewriter, func::FuncOp funcOp, + const TensorEncodingConverter &converter, + FunctionType newType) { + if (funcOp.isDeclaration()) return success(); - } -}; -/// Rewrite `memref.load` that acts on device memory to first copy the buffer to -/// the host and load from the host buffer. -struct TensorDeviceExtractRewriter - : public OpRewritePattern { + Region *region = &funcOp.getBody(); - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::ExtractOp op, - PatternRewriter &rewriter) const override { - auto source = op.getTensor(); - if (isHostVisible(source)) + // Convert the arguments of each non-entry block within the region. + for (Block &block : + llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) { + rewriter.setInsertionPointToStart(&block); + // Compute the signature for the block with the provided converter. + std::optional conversion = + converter.convertBlockSignature(&block); + if (!conversion) return failure(); + // Convert the block with the computed signature. + applySignatureConversion(rewriter, &block, converter, + conversion->getConvertedTypes()); + } - if (inComputeRegion(op)) - return failure(); + rewriter.setInsertionPointToStart(&funcOp.getBody().front()); + applySignatureConversion(rewriter, &funcOp.getBody().front(), converter, + newType.getInputs()); - rewriter.setInsertionPointAfterValue(source); - Value hostTensor = rewriter.create( - op.getLoc(), - RankedTensorType::get( - source.getType().getShape(), source.getType().getElementType(), - plan::MemorySpaceAttr::get(op->getContext(), - plan::MemorySpace::host_pinned)), - source); - - rewriter.replaceUsesWithIf(op.getTensor(), hostTensor, [&](OpOperand &use) { - return isa(use.getOwner()); + return success(); +} + +/// Convert the operands and results of a function's callers after the `func` +/// has been updated to a new function signature type. The only types that can +/// change are RankedTensorTypes where the encoding has been updated. Therefore, +/// we only insert `tensor.cast` operations to cast the values back to their +/// original types. +struct LogicalResult convertFuncUsers(RewriterBase &rewriter, func::FuncOp func, + const SymbolUserMap &userMap) { + OpBuilder::InsertionGuard g(rewriter); + FunctionType funcType = func.getFunctionType(); + auto handleValue = [&](Value value, Type desiredType) -> Value { + if (value.getType() == desiredType) + return value; + return rewriter.create(value.getLoc(), desiredType, value); + }; + for (Operation *user : userMap.getUsers(func)) { + auto call = dyn_cast(user); + if (!call) + continue; + rewriter.setInsertionPoint(call); + SmallVector newOperands; + for (auto [newType, arg] : + llvm::zip_equal(funcType.getInputs(), call.getOperands())) + newOperands.push_back(handleValue(arg, newType)); + + rewriter.setInsertionPointAfter(call); + SmallVector replacements; + for (auto [newType, result] : + llvm::zip_equal(funcType.getResults(), call.getResults())) + replacements.push_back(handleValue(result, newType)); + + rewriter.modifyOpInPlace(call, [&]() { + call.getOperandsMutable().assign(newOperands); + for (auto [oldResult, replacement, newType] : llvm::zip_equal( + call.getResults(), replacements, funcType.getResults())) { + if (oldResult.getType() != newType) { + oldResult.setType(newType); + rewriter.replaceAllUsesExcept(oldResult, replacement, + replacement.getDefiningOp()); + } + } }); + } + return success(); +} +/// Conver the signature, block arguments, terminator operands, and caller +/// operands/results of a particular function by updating the types in place to +/// include the required memory space encodings. `tensor.cast` operations are +/// inserted to cast values back to their original types. +static LogicalResult +convertFuncOpTypes(func::FuncOp funcOp, + const TensorEncodingConverter &typeConverter, + RewriterBase &rewriter, const SymbolUserMap &userMap) { + FunctionType type = funcOp.getFunctionType(); + FunctionType newType = typeConverter.convertFuncSignature(funcOp); + if (type == newType) return success(); + if (failed(convertFuncRegionTypes(rewriter, funcOp, typeConverter, newType))) + return failure(); + rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); }); + + if (!funcOp.isDeclaration()) { + funcOp.walk([&](func::ReturnOp op) { + rewriter.setInsertionPoint(op); + SmallVector newTermOperands; + bool changed = false; + for (auto [newType, arg] : + llvm::zip_equal(newType.getResults(), op.getOperands())) { + if (arg.getType() == newType) { + newTermOperands.push_back(arg); + continue; + } + changed = true; + auto cast = rewriter.create(arg.getLoc(), newType, arg); + newTermOperands.push_back(cast); + } + if (!changed) + return; + rewriter.modifyOpInPlace( + op, [&]() { op.getOperandsMutable().assign(newTermOperands); }); + }); } -}; -/// Remap relevant analysis state of type T from `original` to `replacement`. -template -static void remapLatticeState(DataFlowSolver &solver, Value original, - Value replacement) { - if constexpr (!std::is_same_v) { - if (const T *lattice = solver.lookupState(original)) { - T *latticeReplacement = solver.getOrCreateState(replacement); - latticeReplacement->getValue() = lattice->getValue(); - } - } else { - // do nothing for liveness analysis for the moment except create the state - if (const auto *oldState = - solver.lookupState(original)) { - dataflow::Executable *newState = solver.getOrCreateState(replacement); - // Set to live if old state is live. We ignore change status. - if (oldState->isLive()) - (void)newState->setToLive(); - } - } + return convertFuncUsers(rewriter, funcOp, userMap); } -/// A rewrite listener that transfers replacements to updates to the solver -/// state. -class SolverStateListener : public RewriterBase::Listener { -public: - SolverStateListener(DataFlowSolver &solver) - : RewriterBase::Listener(), solver(solver) {} - -private: - void notifyOperationReplaced(Operation *op, - ValueRange replacements) override { - for (auto [original, replacement] : - llvm::zip_equal(op->getResults(), replacements)) { - remapLatticeState(solver, original, replacement); - remapLatticeState>( - solver, original, replacement); - remapLatticeState(solver, original, replacement); - } - solver.eraseState(solver.getProgramPointAfter(op)); - } - void notifyOperationReplaced(Operation *op, Operation *replacement) override { - notifyOperationReplaced(op, replacement->getResults()); - } +/// Get the default memory space for a particular function. +static plan::MemorySpace getFuncitonDefaultEncoding(func::FuncOp func) { + // The `plan.memory_space` attribute takes precedence over the cluster kind + // default memory space. + if (auto constraintOverride = func->getAttrOfType( + plan::PlanDialect::getMemorySpaceConstraintAttrName())) + return constraintOverride.getValue(); + if (auto clusterKindAttr = func->getAttrOfType( + plan::PlanDialect::kFuncTargetKind)) + return clusterKindAttr.getDefaultMemorySpace(); + return plan::MemorySpace::device; +} - void notifyOperationErased(Operation *op) override { - solver.eraseState(solver.getProgramPointAfter(op)); - for (Value res : op->getResults()) - solver.eraseState(res); +/// Convert the signatures of functions and their callers by adding the +/// appropriate memory space attribute to all tensor types. +static LogicalResult +assignMemorySpacesToFunctionBoundaries(IRRewriter &rewriter, ModuleOp module) { + SymbolTableCollection symbolTables; + SymbolUserMap symbolUserMap(symbolTables, module); + for (auto func : module.getOps()) { + plan::MemorySpace defaultEncoding = getFuncitonDefaultEncoding(func); + TensorEncodingConverter converter(*func.getContext(), defaultEncoding); + if (failed(convertFuncOpTypes(func, converter, rewriter, symbolUserMap))) + return failure(); } + return success(); +} - DataFlowSolver &solver; -}; +static bool hasMemorySpaceEncoding(Type type) { + auto tensorType = dyn_cast(type); + if (!tensorType) + return false; + return dyn_cast_if_present(tensorType.getEncoding()) != + nullptr; +} -} // namespace +static LogicalResult applyConversionToFunction(func::FuncOp func) { + MLIRContext *context = func.getContext(); + auto defaultEncoding = [&]() { + if (auto constraintOverride = func->getAttrOfType( + plan::PlanDialect::getMemorySpaceConstraintAttrName())) + return constraintOverride.getValue(); + if (auto clusterKindAttr = func->getAttrOfType( + plan::PlanDialect::kFuncTargetKind)) + return clusterKindAttr.getDefaultMemorySpace(); + return plan::MemorySpace::device; + }(); + TensorEncodingConverter converter(*context, defaultEncoding); + + // Ops are legal if they are in a nested module or if their operand and + // result types are legal. + ConversionTarget target(*context); + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + return converter.isLegal(op->getOperandTypes()) && + converter.isLegal(op->getResultTypes()); + }); + target.addDynamicallyLegalOp([&](arith::ConstantOp op) { + return converter.isLegal(op.getType()) && + converter.isLegal(op.getValue().getType()); + }); + target.addDynamicallyLegalOp([&](tensor::CastOp op) { + return hasMemorySpaceEncoding(op.getType()) && + hasMemorySpaceEncoding(op.getOperand().getType()); + }); + target.addDynamicallyLegalOp( + [&](bufferization::AllocTensorOp op) { + if (op.getCopy()) + return false; + return hasMemorySpaceEncoding(op.getType()); + }); + target.addLegalDialect(); + + RewritePatternSet patterns(context); + patterns.add(converter, context); + scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns, + target); + if (failed(applyFullConversion(func, target, std::move(patterns)))) + return emitError(func.getLoc(), "failed to assign memory spaces"); + return success(); +} namespace { struct AssignMemorySpacesPass @@ -303,96 +465,19 @@ struct AssignMemorySpacesPass void runOnOperation() override { MLIRContext *context = &getContext(); - ConversionTarget target(*context); - - TypeConverter converter; - converter.addConversion( - [&](Type type) -> std::optional { return type; }); - - // The default tensor type converter just adds the 'device' memory type - // info. - auto deviceEncoding = - plan::MemorySpaceAttr::get(context, plan::MemorySpace::device); - converter.addConversion([&](RankedTensorType type) -> std::optional { - if (type.getEncoding()) - return type; - return RankedTensorType::get(type.getShape(), type.getElementType(), - deviceEncoding); - }); - // Ops are legal if they are in a nested module or if their operand and - // result types are legal. - target.markUnknownOpDynamicallyLegal([&](Operation *op) { - if (op->getParentWithTrait() != getOperation()) - return true; - return converter.isLegal(op->getOperandTypes()) && - converter.isLegal(op->getResultTypes()); - }); - target.addDynamicallyLegalOp([&](func::FuncOp op) { - if (op->getParentWithTrait() != getOperation()) - return true; - return converter.isSignatureLegal(op.getFunctionType()); - }); - target.markOpRecursivelyLegal( - [&](func::FuncOp op) -> std::optional { - if (op->getParentWithTrait() != getOperation()) - return true; - return false; - }); - target.addDynamicallyLegalOp([&](arith::ConstantOp op) { - if (op->getParentWithTrait() != getOperation()) - return true; - return converter.isLegal(op.getType()) && - converter.isLegal(op.getValue().getType()); - }); - - RewritePatternSet patterns(&getContext()); - patterns.add(converter, - context); - - // FuncOp is special as it has type encoding via attributes. - populateFunctionOpInterfaceTypeConversionPattern(patterns, - converter); - scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns, - target); + IRRewriter rewriter(context); - auto module = getOperation(); - if (failed(applyFullConversion(module, target, std::move(patterns)))) { - emitError(module.getLoc(), "failed to assign memory spaces"); + /// Update all function signatures and their callers to include the required + /// memory space encodings. + if (failed( + assignMemorySpacesToFunctionBoundaries(rewriter, getOperation()))) return signalPassFailure(); - } - // Perform some minor optimizations involving tensor.from_elements. - { - SymbolTableCollection symbolTables; - DataFlowSolver solver(DataFlowConfig().setInterprocedural(false)); - solver.load(); - solver.load(); - solver.load(symbolTables); - - if (failed(solver.initializeAndRun(getOperation()))) { - emitError(getOperation().getLoc()) - << "failed to run TensorKindAnalysis"; + for (auto func : + llvm::make_early_inc_range(getOperation().getOps())) { + if (failed(applyConversionToFunction(func))) return signalPassFailure(); - } - - SolverStateListener solverAwareListener(solver); - GreedyRewriteConfig config; - config.listener = &solverAwareListener; - FrozenRewritePatternSet patterns = [&]() { - RewritePatternSet patterns_(&getContext()); - patterns_.insert(&getContext(), solver); - patterns_.insert(&getContext()); - patterns_.insert(&getContext()); - return patterns_; - }(); - for (FunctionOpInterface func : - getOperation().getOps()) { - if (failed(applyPatternsGreedily(func, patterns))) { - emitError(func.getLoc()) << "failed to run " << getArgument(); - return signalPassFailure(); - } - } } } }; diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/CMakeLists.txt b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/CMakeLists.txt index b9705d975..87f664670 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/CMakeLists.txt +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/CMakeLists.txt @@ -6,17 +6,21 @@ add_mlir_tensorrt_library(MLIRTensorRTPlanTransforms CreateClosedRegions.cpp CreateShapeFuncs.cpp EliminateShapeOps.cpp + MaterializeExplicitTransfers.cpp MaterializeShapeCalculations.cpp MaterializeShapeCalculationsStablehlo.cpp + ModuleBufferization/BufferResultsToOutParams.cpp ModuleBufferization/ModuleBufferization.cpp ModuleBufferization/ModuleBufferizationAnalysis.cpp ModuleBufferization/ModuleBufferizationUtils.cpp ModuleBufferization/RemoveEquivalentBufferResults.cpp + OptimizeMemorySpaces.cpp OutlineClusters.cpp OutlineConstantFoldableSubgraphs.cpp Passes.cpp PopulateFunctionBoundsAttributes.cpp PostClusteringValidation.cpp + PromoteHostTensorsToHostPinned.cpp RefineTypes.cpp DEPENDS diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/CreateShapeFuncs.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/CreateShapeFuncs.cpp index 3c22041d4..4a6f0d958 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/CreateShapeFuncs.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/CreateShapeFuncs.cpp @@ -524,17 +524,13 @@ static FailureOr createAggregateShapeFunc( fromElementsOperands)); } - // Make sure to mark that the shape function arg and results as host tensors. - // The TensorKindAnalysis currently doesn't inspect encoding attributes. - for (unsigned i = 0; i < aggregateShapeFunc.getNumResults(); i++) - aggregateShapeFunc.setResultAttr(i, getHostTensorArgAttrName(), - rewriter.getUnitAttr()); - for (unsigned i = 0; i < aggregateShapeFunc.getNumArguments(); i++) - aggregateShapeFunc.setArgAttr(i, getHostTensorArgAttrName(), - rewriter.getUnitAttr()); - rewriter.create(func.getLoc(), shapeFuncReturns); + // Mark the function as having default memory space of 'host' + aggregateShapeFunc->setAttr( + PlanDialect::getMemorySpaceConstraintAttrName(), + plan::MemorySpaceAttr::get(rewriter.getContext(), MemorySpace::host)); + return aggregateShapeFunc; } diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/MaterializeExplicitTransfers.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/MaterializeExplicitTransfers.cpp new file mode 100644 index 000000000..9f737a631 --- /dev/null +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/MaterializeExplicitTransfers.cpp @@ -0,0 +1,181 @@ +//===- MaterializeExplicitTransfers.cpp -----------------------------------===// +// +// SPDX-FileCopyrightText: Copyright 2025 NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// Implementation of the `plan-materialize-explicit-transfers` pass. +/// +//===----------------------------------------------------------------------===// +#include "mlir-tensorrt/Dialect/Plan/Transforms/Passes.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir::plan { +#define GEN_PASS_DEF_PLANMATERIALIZEEXPLICITTRANSFERSPASS +#include "mlir-tensorrt/Dialect/Plan/Transforms/Passes.h.inc" +} // namespace mlir::plan + +using namespace mlir; +using namespace mlir::plan; + +/// Get the dynamic dimensions for a shaped value. +static FailureOr> getDynamicDims(OpBuilder &b, Location loc, + Value tensor) { + if (!llvm::isa(tensor.getType())) + return failure(); + + RankedTensorType tensorType = llvm::cast(tensor.getType()); + SmallVector dynamicSizes; + // Compute the dynamic part of the shape. + // First try to query the shape via ReifyRankedShapedTypeOpInterface. + if (llvm::isa(tensor)) { + ReifiedRankedShapedTypeDims resultDims; + if (succeeded(reifyResultShapes(b, tensor.getDefiningOp(), resultDims))) { + const SmallVector &shape = + resultDims[llvm::cast(tensor).getResultNumber()]; + for (const auto &dim : enumerate(tensorType.getShape())) + if (ShapedType::isDynamic(dim.value())) + dynamicSizes.push_back(cast(shape[dim.index()])); + return dynamicSizes; + } + } + + // If the shape could not be reified, create DimOps. + bufferization::populateDynamicDimSizes(b, loc, tensor, dynamicSizes); + return dynamicSizes; +} + +namespace { + +struct RewriteAllocTensorWithCopyToCastPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(bufferization::AllocTensorOp op, + PatternRewriter &rewriter) const override { + auto copySource = op.getCopy(); + if (!copySource) + return failure(); + auto copySourceType = dyn_cast(copySource.getType()); + if (!copySourceType) + return failure(); + + auto memorySpace = + llvm::dyn_cast_or_null(op.getMemorySpaceAttr()); + if (!memorySpace) + return failure(); + + auto newType = op.getType().cloneWithEncoding(memorySpace); + + auto castOp = + rewriter.create(op.getLoc(), newType, copySource); + rewriter.replaceOpWithNewOp(op, newType, castOp); + return success(); + } +}; + +struct TensorCastToAllocAndCopyPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensor::CastOp op, + PatternRewriter &rewriter) const override { + + auto sourceType = cast(op.getOperand().getType()); + auto targetType = cast(op.getType()); + // Source/target could be unranked for `tensor.cast`. + if (!sourceType || !targetType) + return failure(); + + auto sourceSpace = + dyn_cast_or_null(sourceType.getEncoding()); + auto targetSpace = + dyn_cast_or_null(targetType.getEncoding()); + + if (!sourceSpace || !targetSpace || sourceSpace == targetSpace) + return rewriter.notifyMatchFailure( + op, "skipping no space encoding or same space"); + + FailureOr> dynamicDims = + getDynamicDims(rewriter, op.getLoc(), op.getOperand()); + if (failed(dynamicDims)) + return failure(); + + auto allocOp = rewriter.create( + op.getLoc(), targetType, *dynamicDims, /*copy=*/Value{}, + /*size_hint=*/Value{}, + /*memory_space=*/targetSpace); + + rewriter.replaceOpWithNewOp( + op, targetType, op.getOperand(), allocOp.getResult()); + + return success(); + } +}; + +/// Remove redundant explicit `bufferization.materialize_in_dest` ops. +struct RemoveRedundantMaterializeInDestPattern + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(bufferization::MaterializeInDestinationOp op, + PatternRewriter &rewriter) const override { + auto producer = + op.getDest().getDefiningOp(); + if (!producer) + return failure(); + + if (producer.getDest().getType() != op.getDest().getType()) + return failure(); + + rewriter.modifyOpInPlace( + op, [&]() { op.getDestMutable().assign(producer.getDest()); }); + return success(); + } +}; + +class MaterializeExplicitTransfersPass + : public plan::impl::PlanMaterializeExplicitTransfersPassBase< + MaterializeExplicitTransfersPass> { + void runOnOperation() override { + Operation *op = getOperation(); + + // Eliminate `bufferization.alloc_tensor` ops with `copy` argument and + // simplify `tensor.cast` ops. + { + RewritePatternSet patterns(op->getContext()); + patterns.add(op->getContext()); + tensor::CastOp::getCanonicalizationPatterns(patterns, op->getContext()); + if (failed(applyPatternsGreedily(op, std::move(patterns)))) { + emitError(op->getLoc()) + << "failed to apply patterns in " << getArgument(); + return; + } + } + + RewritePatternSet patterns(op->getContext()); + patterns.add(op->getContext()); + + if (failed(applyPatternsGreedily(op, std::move(patterns)))) { + emitError(op->getLoc()) + << "failed to apply patterns in " << getArgument(); + return; + } + } +}; +} // namespace diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/ModuleBufferization/BufferResultsToOutParams.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/ModuleBufferization/BufferResultsToOutParams.cpp new file mode 100644 index 000000000..9a418f25e --- /dev/null +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/ModuleBufferization/BufferResultsToOutParams.cpp @@ -0,0 +1,517 @@ +//===- BufferResultsToOutParams.cpp ---------------------------------------===// +// +// Modified from upstream 'BufferResultsToOutParams.cpp', part of the LLVM +// Project, under the Apache License v2.0 with LLVM Exceptions. See +// https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// Changes: Copyright 2025 NVIDIA CORPORATION & AFFILIATES. +// SPDX-FileCopyrightText: All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +/// +/// Implementation of the Plan buffer results to out params pass. +/// +//===----------------------------------------------------------------------===// +#include "mlir-tensorrt/Dialect/Plan/IR/Plan.h" +#include "mlir-tensorrt/Dialect/Plan/Transforms/Passes.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" +#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" + +namespace mlir::plan { +#define GEN_PASS_DEF_PLANBUFFERRESULTSTOOUTPARAMSPASS +#include "mlir-tensorrt/Dialect/Plan/Transforms/Passes.h.inc" +} // namespace mlir::plan + +using namespace mlir; +using namespace mlir::bufferization; + +/// Visit all the return ops in the function. If the visitTerminator returns +/// failure, then the traversal is interrupted. +static LogicalResult +visitReturnOps(func::FuncOp func, + function_ref visitTerminator) { + for (Block &block : func.getBody()) { + if (auto terminator = dyn_cast(block.getTerminator())) { + if (failed(visitTerminator(terminator))) + return failure(); + } + } + return success(); +} + +/// Return a vector which maps results to block argument indices if, for each +/// returned value index, all `return` ops in the func uniformly return the same +/// block argument. +static SmallVector> +getReturnedBlockArgs(func::FuncOp func) { + SmallVector> returnsBlockArgs(func.getNumResults(), + std::nullopt); + for (unsigned i = 0, e = func.getNumResults(); i < e; ++i) { + if (failed(visitReturnOps(func, [&](func::ReturnOp op) -> LogicalResult { + auto blockArg = dyn_cast(op.getOperand(i)); + if (!blockArg || blockArg.getOwner()->getParentOp() != func) + return success(); + std::optional &blockArgIdx = returnsBlockArgs[i]; + if (!blockArgIdx) { + blockArgIdx = blockArg.getArgNumber(); + return success(); + } + if (*blockArgIdx != blockArg.getArgNumber()) + return failure(); + return success(); + }))) { + returnsBlockArgs[i] = std::nullopt; + continue; + } + } + return returnsBlockArgs; +} + +/// Given a memref value, return the "base" value by skipping over all +/// ViewLikeOpInterface ops (if any) in the reverse use-def chain. +static Value getViewBase(Value value) { + while (auto viewLikeOp = value.getDefiningOp()) + value = viewLikeOp.getViewSource(); + return value; +} + +/// Return "true" if the given values are guaranteed to be different (and +/// non-aliasing) allocations based on the fact that one value is the result +/// of an allocation and the other value is a block argument of a parent block. +/// Note: This is a best-effort analysis that will eventually be replaced by a +/// proper "is same allocation" analysis. This function may return "false" even +/// though the two values are distinct allocations. +static bool distinctAllocAndBlockArgument(Value v1, Value v2) { + Value v1Base = getViewBase(v1); + Value v2Base = getViewBase(v2); + auto areDistinct = [](Value v1, Value v2) { + if (Operation *op = v1.getDefiningOp()) + if (hasEffect(op, v1)) + if (auto bbArg = dyn_cast(v2)) + if (bbArg.getOwner()->findAncestorOpInBlock(*op)) + return true; + return false; + }; + return areDistinct(v1Base, v2Base) || areDistinct(v2Base, v1Base); +} + +/// Checks if `memref` may potentially alias a MemRef in `otherList`. It is +/// often a requirement of optimization patterns that there cannot be any +/// aliasing memref in order to perform the desired simplification. +static bool potentiallyAliasesMemref(BufferOriginAnalysis &analysis, + ValueRange otherList, Value memref) { + for (auto other : otherList) { + if (!isa(other.getType())) + continue; + if (distinctAllocAndBlockArgument(other, memref)) + continue; + std::optional analysisResult = + analysis.isSameAllocation(other, memref); + if (!analysisResult.has_value() || analysisResult == true) + return true; + } + return false; +} + +/// Check whether a MemRefValue is produced by a set of "hoistable" operations. +/// The operations are hoistable if they are composed of a closed (less function +/// arguments) set of operations involving an allocation and zero or more pure +/// operations. +static FailureOr> +getHoistableOperations(Value value, func::FuncOp func) { + llvm::SetVector slice; + BackwardSliceOptions sliceOptions{}; + sliceOptions.omitBlockArguments = false; + sliceOptions.omitUsesFromAbove = false; + sliceOptions.inclusive = true; + sliceOptions.filter = [](Operation *op) { + return isa(op) || mlir::isPure(op); + }; + mlir::getBackwardSlice(value, &slice, sliceOptions); + + // Slice must be closed. + bool hasAllocation = false; + for (Operation *op : slice) { + if (isa(op)) + hasAllocation = true; + for (Value operand : op->getOperands()) { + if (auto blockArg = dyn_cast(operand)) { + if (blockArg.getOwner()->getParentOp() == func) + continue; + } + if (!slice.contains(operand.getDefiningOp())) + return failure(); + } + } + + if (slice.empty() || !hasAllocation) + return failure(); + + return slice; +} + +/// Return true if the result at `resultIdx` can be promoted to an argument. A +/// result can be promoted to an argument if: +/// - It is a MemRefType +/// - For all return ops in the function, it does not potentially alias any +/// other returned operand. +/// +/// The specific method in which the result is promoted to an argument (e.g. +/// via hoisting allocation or via copy insertion) is not determined here. +static bool checkAliasingPrecondition(func::FuncOp func, unsigned resultIdx, + BufferOriginAnalysis &analysis) { + MemRefType resultType = + dyn_cast(func.getResultTypes()[resultIdx]); + if (!resultType) + return false; + + return succeeded(visitReturnOps(func, [&](func::ReturnOp term) { + SmallVector otherOperands(term->getOperands()); + otherOperands.erase(otherOperands.begin() + resultIdx); + return success(!potentiallyAliasesMemref(analysis, otherOperands, + term.getOperand(resultIdx))); + })); +} + +namespace { +struct ResultPromotionPlan { + + /// Specifies which return values can be dropped and replaced with a block + /// argument without inserting any copies by inserting a single + /// statically-sized allocation at each call site. + BitVector simplyHoistableAllocations; + /// Specifies which return values can be dropped and replaced with a block + /// argument by cloning a tree of hoistable operations at each call site. + /// This requires that each returned operand at the corresponding position in + /// all return ops is the same value. + BitVector hoistableAllocations; + + /// Specifies which return values can be dropped by inserting a new block + /// argument + a copy at each function return. + BitVector promotableToCopyOut; + + /// Specifies which return values are already block arguments. + SmallVector> returnsExistingBlockArg; + + /// The union of `simplyHoistableAllocations`, `hoistableAllocations`, and + /// `returnsExistingBlockArg`. + BitVector resultsToDrop; + + /// Specifies the tree of operations to hoist for each return value that is + /// set true in `hoistableAllocations`. + SmallVector> operationsToHoist; +}; +} // namespace + +/// Returns true if a returned value is "simply hoistable", meaning it is +/// directly produced by an allocation op and has static shape and identity +/// layout. +static bool isSimplyHoistableAllocation(Value value) { + return isa_and_nonnull( + value.getDefiningOp()) && + cast(value.getType()).hasStaticShape() && + cast(value.getType()).getLayout().isIdentity(); +} + +/// Constructs a "ResultPromotionPlan" by identifying all results which can be +/// dropped and passed as MemRef arguments instead. It updates the func op type +/// and entry block arguments. +static FailureOr +updateFuncOp(RewriterBase &rewriter, func::FuncOp func, + SmallVectorImpl &appendedEntryArgs, + BufferOriginAnalysis &analysis) { + auto functionType = func.getFunctionType(); + + ResultPromotionPlan plan{/*simplyHoistableAllocations=*/ + BitVector(functionType.getNumResults(), false), + /*hoistableAllocations=*/ + BitVector(functionType.getNumResults(), false), + /*promotableToCopyOut=*/ + BitVector(functionType.getNumResults(), false), + /*returnsExistingBlockArg=*/ + getReturnedBlockArgs(func), + /*resultsToPromote=*/ + BitVector(functionType.getNumResults(), false), + SmallVector>()}; + + // Collect information about the results will become appended arguments. + SmallVector newBlockArgTypes; + for (auto [idx, resultType] : llvm::enumerate(functionType.getResults())) { + + auto memrefType = dyn_cast(resultType); + if (!memrefType) + continue; + + if (plan.returnsExistingBlockArg[idx]) { + plan.resultsToDrop.set(idx); + continue; + } + + if (!checkAliasingPrecondition(func, idx, analysis)) { + continue; + } + + std::optional> hoistableOperations{}; + bool isSimplyHoistable = true; + if (failed(visitReturnOps( + func, + [&, idx = idx](func::ReturnOp term) { + isSimplyHoistable &= + isSimplyHoistableAllocation(term.getOperand(idx)); + if (isSimplyHoistable) + return success(); + FailureOr> tmp = + getHoistableOperations(term.getOperand(idx), func); + if (failed(tmp)) + return failure(); + if (!hoistableOperations) { + hoistableOperations = std::move(*tmp); + return success(); + } + return success(*hoistableOperations == *tmp); + })) || + (!isSimplyHoistable && !hoistableOperations)) { + if (memrefType.hasStaticShape() && memrefType.getLayout().isIdentity()) { + plan.promotableToCopyOut.set(idx); + plan.resultsToDrop.set(idx); + newBlockArgTypes.push_back(resultType); + } + continue; + } + + if (isSimplyHoistable) { + plan.simplyHoistableAllocations.set(idx); + plan.resultsToDrop.set(idx); + newBlockArgTypes.push_back(resultType); + continue; + } + + plan.operationsToHoist.emplace_back(std::move(*hoistableOperations)); + plan.hoistableAllocations.set(idx); + plan.resultsToDrop.set(idx); + newBlockArgTypes.push_back(resultType); + } + + // Add the new arguments to the function type. + auto newArgTypes = llvm::to_vector( + llvm::concat(functionType.getInputs(), newBlockArgTypes)); + auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes, + functionType.getResults()); + + rewriter.modifyOpInPlace(func, [&]() { func.setType(newFunctionType); }); + + // Transfer the result attributes to arg attributes. + unsigned newArgIdx = functionType.getNumInputs(); + for (auto [idx, erasedResultIdx] : + llvm::enumerate(plan.resultsToDrop.set_bits())) { + if (plan.returnsExistingBlockArg[erasedResultIdx]) + continue; + func.setArgAttrs(newArgIdx, func.getResultAttrs(erasedResultIdx)); + // Set the marker to indicate we promoted this argument from a result. + func.setArgAttr(newArgIdx, + StringAttr::get(func.getContext(), + plan::PlanDialect::kResultArgAttrName), + UnitAttr::get(func.getContext())); + newArgIdx++; + } + + // Erase the results. This takes care of updating the result attributes array. + rewriter.modifyOpInPlace(func, + [&]() { func.eraseResults(plan.resultsToDrop); }); + + // Add the new arguments to the entry block if the function is not external. + if (func.isExternal()) + return plan; + + Location loc = func.getLoc(); + rewriter.modifyOpInPlace(func, [&]() { + for (Type type : newBlockArgTypes) + appendedEntryArgs.push_back(func.front().addArgument(type, loc)); + }); + + return plan; +} + +/// Updates all ReturnOps in the scope of the given func::FuncOp by either +/// keeping them as return values or dropping the return value, replacing uses +/// and inserting copies as required. +static LogicalResult +updateReturnOps(RewriterBase &rewriter, func::FuncOp func, + ArrayRef appendedEntryArgs, + ResultPromotionPlan &plan, + bufferization::BufferResultsToOutParamsOpts &options) { + OpBuilder::InsertionGuard g(rewriter); + + return visitReturnOps(func, [&](func::ReturnOp term) -> LogicalResult { + rewriter.setInsertionPoint(term); + SmallVector keepAsReturnOperands; + llvm::SmallDenseMap copyIntoOutParams; + llvm::SmallDenseMap valuesToHoist; + unsigned appendedEntryArgIdx = 0; + + for (auto [idx, operand] : llvm::enumerate(term.getOperands())) { + if (plan.resultsToDrop.test(idx)) { + if (plan.hoistableAllocations.test(idx) || + plan.simplyHoistableAllocations.test(idx)) { + valuesToHoist[operand] = appendedEntryArgs[appendedEntryArgIdx++]; + } else if (plan.promotableToCopyOut.test(idx)) { + copyIntoOutParams[operand] = appendedEntryArgs[appendedEntryArgIdx++]; + } + continue; + } + keepAsReturnOperands.push_back(operand); + } + + for (auto [hoistable, appendedEntryArg] : valuesToHoist) + rewriter.replaceAllUsesWith(hoistable, appendedEntryArg); + + for (auto [orig, arg] : copyIntoOutParams) { + if (failed(options.memCpyFn(rewriter, term.getLoc(), orig, arg))) + return failure(); + } + + rewriter.modifyOpInPlace(term, [&]() { + term.getOperandsMutable().assign(keepAsReturnOperands); + }); + + return success(); + }); +} + +/// Updates all CallOps in the scope of the given ModuleOp by allocating +/// temporary buffers for newly introduced out params or cloning the required +/// operations to produce the new output buffer. +static LogicalResult +updateCalls(RewriterBase &rewriter, func::FuncOp func, + const ResultPromotionPlan &plan, const SymbolUserMap &symbolUserMap, + const bufferization::BufferResultsToOutParamsOpts &options) { + OpBuilder::InsertionGuard g(rewriter); + for (auto symbolUser : symbolUserMap.getUsers(func)) { + auto call = dyn_cast(symbolUser); + if (!call) + continue; + SmallVector newResultTypes; + SmallVector newOperands(call.getOperands()); + SmallVector replaceWithNewCallResults; + rewriter.setInsertionPoint(call); + auto hoistableOpsIt = plan.operationsToHoist.begin(); + for (auto [idx, result] : llvm::enumerate(call.getResults())) { + if (plan.resultsToDrop.test(idx)) { + auto memrefType = cast(result.getType()); + if (plan.promotableToCopyOut.test(idx) || + plan.simplyHoistableAllocations.test(idx)) { + FailureOr maybeOutParam = + options.allocationFn(rewriter, call.getLoc(), memrefType); + if (failed(maybeOutParam)) + return call.emitError() + << "failed to create allocation when promoting " + "a buffer result to an output parameter"; + rewriter.replaceAllUsesWith(result, *maybeOutParam); + newOperands.push_back(*maybeOutParam); + continue; + } + if (plan.hoistableAllocations.test(idx)) { + const SetVector &hoistableOps = *hoistableOpsIt++; + assert(!hoistableOps.empty() && "hoistableOps is empty"); + IRMapping mapping; + mapping.map(func.getArguments(), call.getOperands()); + for (Operation *op : hoistableOps) + rewriter.clone(*op, mapping); + Value operandReplacement = + mapping.lookup(hoistableOps.back()->getResult(0)); + rewriter.replaceAllUsesWith(result, operandReplacement); + newOperands.push_back(operandReplacement); + continue; + } + if (std::optional existingBlockArg = + plan.returnsExistingBlockArg[idx]) { + rewriter.replaceAllUsesWith(result, + call.getOperand(*existingBlockArg)); + continue; + } + llvm_unreachable("unhandled case"); + } + newResultTypes.push_back(result.getType()); + replaceWithNewCallResults.push_back(result); + } + + auto newCall = rewriter.create( + call.getLoc(), call.getCalleeAttr(), newResultTypes, newOperands); + for (auto [valueToReplace, replacement] : + llvm::zip_equal(replaceWithNewCallResults, newCall.getResults())) + rewriter.replaceAllUsesWith(valueToReplace, replacement); + rewriter.eraseOp(call); + } + + return success(); +} + +namespace { +struct PlanBufferResultsToOutParamsPass + : public plan::impl::PlanBufferResultsToOutParamsPassBase< + PlanBufferResultsToOutParamsPass> { + using Base::Base; + + LogicalResult initialize(MLIRContext *context) override { + options.allocationFn = [](OpBuilder &builder, Location loc, + MemRefType type) -> FailureOr { + return builder.create(loc, type).getResult(); + }; + + options.memCpyFn = [](OpBuilder &builder, Location loc, Value src, + Value dst) -> LogicalResult { + builder.create(loc, src, dst); + return success(); + }; + options.filterFn = [&](func::FuncOp *op) { + if (ignorePublicFunctions && op->isPublic()) + return false; + return !op->isDeclaration(); + }; + + return success(); + } + + void runOnOperation() override { + ModuleOp module = getOperation(); + SymbolTableCollection symbolTables; + SymbolUserMap symbolUserMap(symbolTables, module); + for (auto func : module.getOps()) { + if (!options.filterFn(&func)) + continue; + + BufferOriginAnalysis analysis(func); + + IRRewriter rewriter(func); + SmallVector appendedEntryArgs; + + FailureOr updatePlan = + updateFuncOp(rewriter, func, appendedEntryArgs, analysis); + if (failed(updatePlan)) + return signalPassFailure(); + + if (func.isExternal()) + continue; + + if (failed(updateReturnOps(rewriter, func, appendedEntryArgs, *updatePlan, + options))) + return signalPassFailure(); + + if (failed( + updateCalls(rewriter, func, *updatePlan, symbolUserMap, options))) + return signalPassFailure(); + } + } + +private: + bufferization::BufferResultsToOutParamsOpts options; +}; +} // namespace diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/ModuleBufferization/ModuleBufferizationAnalysis.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/ModuleBufferization/ModuleBufferizationAnalysis.cpp index a0d4219e9..0e2f28320 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/ModuleBufferization/ModuleBufferizationAnalysis.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/ModuleBufferization/ModuleBufferizationAnalysis.cpp @@ -205,36 +205,67 @@ aliasingFuncOpBBArgsAnalysis(func::FuncOp funcOp, OneShotAnalysisState &state, static LogicalResult funcOpBbArgReadWriteAnalysis(func::FuncOp funcOp, OneShotAnalysisState &state, FuncAnalysisState &funcState) { + + auto recordArgAccessKind = [&](int64_t idx, bool isRead, bool isWritten) { + if (state.getOptions().testAnalysisOnly) + annotateFuncArgAccess(funcOp, idx, isRead, isWritten); + if (isRead) + funcState.readBbArgs[funcOp].insert(idx); + if (isWritten) + funcState.writtenBbArgs[funcOp].insert(idx); + }; + for (int64_t idx = 0, e = funcOp.getNumArguments(); idx < e; ++idx) { - // Skip non-tensor arguments. - if (!isa(funcOp.getFunctionType().getInput(idx))) + // Skip arguments that are not tensors or memrefs. + if (!isa(funcOp.getArgumentTypes()[idx])) continue; + + if (funcOp.isDeclaration()) { + // If the function has no body, conservatively assume that all args are + // read + written. + recordArgAccessKind(idx, true, true); + continue; + } + + Value bbArg = funcOp.getArgument(idx); + + // You can't call `state.isValueRead` or `state.isValueWritten` on memref + // values. So search for a `bufferization.to_tensor` op that has a + // `restrict` attribute. + if (auto memrefType = dyn_cast(bbArg.getType())) { + bufferization::ToTensorOp toTensorOp; + for (Operation *user : bbArg.getUsers()) { + if (auto toTensorUser = dyn_cast(user)) { + toTensorOp = toTensorUser; + break; + } + } + // If we fail to find a `bufferization.to_tensor` op that has a + // `restrict` attribute, conservatively assume that the memref bbArg is + // read + written. + if (!toTensorOp || !toTensorOp.getRestrict()) { + recordArgAccessKind(idx, true, true); + continue; + } + bbArg = toTensorOp.getResult(); + } + bool isRead; bool isWritten; if (auto accessAttr = funcOp.getArgAttrOfType( idx, BufferizationDialect::kBufferAccessAttrName)) { - // Buffer access behavior is specified on the function. Skip the analysis. + // Buffer access behavior is specified on the function. Skip the + // analysis. StringRef str = accessAttr.getValue(); isRead = str == "read" || str == "read-write"; isWritten = str == "write" || str == "read-write"; - } else if (funcOp.isDeclaration()) { - // If the function has no body, conservatively assume that all args are - // read + written. - isRead = true; - isWritten = true; } else { // Analyze the body of the function. - BlockArgument bbArg = funcOp.getArgument(idx); isRead = state.isValueRead(bbArg); isWritten = state.isValueWritten(bbArg); } - if (state.getOptions().testAnalysisOnly) - annotateFuncArgAccess(funcOp, idx, isRead, isWritten); - if (isRead) - funcState.readBbArgs[funcOp].insert(idx); - if (isWritten) - funcState.writtenBbArgs[funcOp].insert(idx); + recordArgAccessKind(idx, isRead, isWritten); } return success(); diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OptimizeMemorySpaces.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OptimizeMemorySpaces.cpp new file mode 100644 index 000000000..e85ee9511 --- /dev/null +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OptimizeMemorySpaces.cpp @@ -0,0 +1,552 @@ +//===- OptimizeMemorySpaces.cpp -------------------------------------------===// +// +// SPDX-FileCopyrightText: Copyright 2025 NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// Implementation of the `plan-optimize-memory-spaces` pass. +/// +//===----------------------------------------------------------------------===// +#include "mlir-tensorrt-dialect/Analysis/TensorKindAnalysis.h" +#include "mlir-tensorrt/Dialect/Plan/Transforms/Passes.h" +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir::plan { +#define GEN_PASS_DEF_PLANOPTIMIZEMEMORYSPACESPASS +#include "mlir-tensorrt/Dialect/Plan/Transforms/Passes.h.inc" +} // namespace mlir::plan + +using namespace mlir; +using namespace mlir::plan; + +/// Returns true if the tensor type has a host-visible memory space encoding. +static bool isHostVisible(Value v) { + auto rtt = dyn_cast(v.getType()); + if (!rtt) + return false; + if (auto space = dyn_cast_or_null(rtt.getEncoding())) + return space.isHostVisible(); + return false; +} + +/// Return true if the op is likely in a "compute" region, like the region of +/// `stablehlo.reduce` or `linalg.generic`. For the purposes of this pass, we're +/// defining "compute" region as a region where the normal flow-of-control does +/// not enter from outside. It's only used to define the semantics of the parent +/// operation.a +static bool inComputeRegion(Operation *op) { + Operation *parent = op->getParentOp(); + while (parent) { + // If the parent is a function, then we're in a normal region. + if (isa(parent)) + return false; + // We are in a region which is not a control flow region and not a function, + // so it's probably a "compute" region. + if (!isa(parent)) + return true; + // If we're in a control flow region, we may still be nested in a "compute" + // region. E.g. `scf.if` is allowed in `linalg.generic` region. Keep going + // up to find the parent function. + parent = parent->getParentOp(); + } + return false; +} + +/// Remap relevant analysis state of type T from `original` to `replacement`. +template +static void remapLatticeState(DataFlowSolver &solver, Value original, + Value replacement) { + if constexpr (!std::is_same_v) { + if (const T *lattice = solver.lookupState(original)) { + T *latticeReplacement = solver.getOrCreateState(replacement); + latticeReplacement->getValue() = lattice->getValue(); + } + } else { + // do nothing for liveness analysis for the moment except create the state + if (const auto *oldState = + solver.lookupState(original)) { + dataflow::Executable *newState = solver.getOrCreateState(replacement); + // Set to live if old state is live. We ignore change status. + if (oldState->isLive()) + (void)newState->setToLive(); + } + } +} + +namespace { + +/// Use an explicit 'host_pinned' staging tensor to materialie the +/// 'from_elements' before creating explicitly moving it to the 'device' space. +/// Other optimization patterns below help avoid the host-device transfer when +/// possible. +struct FixUpFromElements : public OpRewritePattern { + FixUpFromElements(MLIRContext *ctx, const DataFlowSolver &solver, + PatternBenefit benefit = 1) + : OpRewritePattern(ctx, benefit), solver(solver) {} + + LogicalResult matchAndRewrite(tensor::FromElementsOp op, + PatternRewriter &rewriter) const override { + auto space = dyn_cast_or_null(op.getType().getEncoding()); + if (!space || space.isHostVisible()) + return rewriter.notifyMatchFailure( + op, "skipping no encoding or already host-visible"); + + const TensorKindLattice *lattice = + solver.lookupState(op.getResult()); + if (!lattice || lattice->getValue().isUninitialized() || + !lattice->getValue().isHostVisible()) { + return rewriter.notifyMatchFailure( + op, "skipping uninitialized TensorKindLattice value"); + } + + RankedTensorType originalType = op.getType(); + RankedTensorType newType = RankedTensorType::get( + originalType.getShape(), originalType.getElementType(), + MemorySpaceAttr::get(originalType.getContext(), + plan::MemorySpace::host)); + auto newOp = rewriter.create(op.getLoc(), newType, + op.getElements()); + rewriter.replaceOpWithNewOp(op, originalType, newOp); + return success(); + } + + const DataFlowSolver &solver; +}; + +/// Absorb cast operations into the while loop 'before' region and init types. +struct SCFWhileAbsorbCastBeforePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::WhileOp op, + PatternRewriter &rewriter) const override { + SmallVector iterArgsToUpdate; + Region &after = op.getAfter(); + Region &before = op.getBefore(); + auto originalYield = cast(after.front().getTerminator()); + + SmallVector newOperands(op.getOperands()); + SmallVector newYieldOperands(originalYield.getOperands()); + SmallVector> blockTypeUpdates; + bool hasUpdate = false; + for (BlockArgument arg : before.getArguments()) { + auto tensorType = dyn_cast(arg.getType()); + if (!tensorType) + continue; + Value aboveOperand = op.getOperand(arg.getArgNumber()); + Value yieldOperand = originalYield.getOperands()[arg.getArgNumber()]; + auto aboveCast = aboveOperand.getDefiningOp(); + auto yieldCast = yieldOperand.getDefiningOp(); + if (!aboveCast || !yieldCast || + aboveCast.getOperand().getType() != yieldCast.getOperand().getType()) + continue; + newOperands[arg.getArgNumber()] = aboveCast.getOperand(); + newYieldOperands[arg.getArgNumber()] = yieldCast.getOperand(); + blockTypeUpdates.emplace_back(arg.getArgNumber(), + aboveCast.getOperand().getType()); + hasUpdate = true; + } + if (!hasUpdate) + return failure(); + + rewriter.modifyOpInPlace(op, [&]() { + op.getInitsMutable().assign(newOperands); + for (auto [argNumber, type] : blockTypeUpdates) + before.getArgument(argNumber).setType(type); + }); + rewriter.modifyOpInPlace(originalYield, [&]() { + originalYield.getResultsMutable().assign(newYieldOperands); + }); + return success(); + } +}; + +/// Absorb cast operations into the while loop 'after' region and result types. +struct SCFWhileAbsorbCastAfterPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::WhileOp op, + PatternRewriter &rewriter) const override { + SmallVector iterArgsToUpdate; + Region &after = op.getAfter(); + Region &before = op.getBefore(); + auto originalCond = cast(before.front().getTerminator()); + + SmallVector newCondOperands(originalCond.getArgs()); + SmallVector> blockTypeUpdates; + bool hasUpdate = false; + for (BlockArgument arg : after.getArguments()) { + auto tensorType = dyn_cast(arg.getType()); + if (!tensorType) + continue; + Value condOperand = originalCond.getArgs()[arg.getArgNumber()]; + Value result = op.getResult(arg.getArgNumber()); + if (!result.hasOneUse()) + continue; + auto condCast = condOperand.getDefiningOp(); + auto resultCast = dyn_cast(*result.user_begin()); + if (!condCast || !resultCast || + condCast.getOperand().getType() != resultCast.getType()) + continue; + newCondOperands[arg.getArgNumber()] = condCast.getOperand(); + blockTypeUpdates.emplace_back(arg.getArgNumber(), + condCast.getOperand().getType()); + hasUpdate = true; + } + if (!hasUpdate) + return failure(); + + rewriter.modifyOpInPlace(op, [&]() { + for (auto [argNumber, type] : blockTypeUpdates) { + after.getArgument(argNumber).setType(type); + op.getResult(argNumber).setType(type); + } + }); + rewriter.modifyOpInPlace(originalCond, [&]() { + originalCond.getArgsMutable().assign(newCondOperands); + }); + return success(); + } +}; + +/// Absorb cast operations into the iteration carried arguments of `scf.for`. +struct SCFForAbsorbCastPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(scf::ForOp op, + PatternRewriter &rewriter) const override { + SmallVector iterArgsToUpdate; + Block *body = op.getBody(); + auto originalYield = cast(body->getTerminator()); + + SmallVector newOperands(op.getInitArgs()); + SmallVector newYieldOperands(originalYield.getOperands()); + SmallVector> blockTypeUpdates; + bool hasUpdate = false; + for (auto [iterArgIdx, arg] : llvm::enumerate(op.getRegionIterArgs())) { + if (!arg.hasOneUse()) + continue; + auto tensorType = dyn_cast(arg.getType()); + if (!tensorType) + continue; + auto argCast = dyn_cast(*arg.user_begin()); + if (!argCast) + continue; + Value yieldOperand = originalYield.getOperands()[iterArgIdx]; + auto yieldCast = yieldOperand.getDefiningOp(); + if (!yieldCast || argCast.getType() != yieldCast.getOperand().getType()) + continue; + Value newToOperand = rewriter.create( + op.getLoc(), argCast.getType(), op.getInitArgs()[iterArgIdx]); + newOperands[iterArgIdx] = newToOperand; + newYieldOperands[iterArgIdx] = yieldCast.getOperand(); + blockTypeUpdates.emplace_back(iterArgIdx, argCast.getType()); + hasUpdate = true; + } + if (!hasUpdate) + return failure(); + + rewriter.setInsertionPointAfter(op); + for (auto [iterArgIdx, type] : blockTypeUpdates) { + Type originalType = op.getResultTypes()[iterArgIdx]; + auto castOp = rewriter.create(op.getLoc(), originalType, + op.getResult(iterArgIdx)); + rewriter.replaceAllUsesExcept(op.getResult(iterArgIdx), castOp, castOp); + } + + rewriter.modifyOpInPlace(op, [&]() { + op.getInitArgsMutable().assign(newOperands); + for (auto [iterArgIdx, type] : blockTypeUpdates) { + op.getRegionIterArg(iterArgIdx).setType(type); + op->getResult(iterArgIdx).setType(type); + } + }); + rewriter.modifyOpInPlace(originalYield, [&]() { + originalYield.getResultsMutable().assign(newYieldOperands); + }); + + return success(); + } +}; + +/// For any 'shape' parameter of a 'tensor.reshape', ensure that the shape is +/// host-visible. +struct ReshapeAbsorbDeviceCast : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensor::ReshapeOp op, + PatternRewriter &rewriter) const override { + // Skip past any explicit host-device transfers or host<->host-pinned + // transfers + if (auto matOp = + op.getShape() + .getDefiningOp()) { + auto source = matOp.getSource(); + if (isHostVisible(source)) { + rewriter.modifyOpInPlace( + op, [&]() { op.getShapeMutable().assign(source); }); + return success(); + } + } + if (auto castOp = op.getShape().getDefiningOp()) { + auto source = castOp.getOperand(); + if (isHostVisible(source)) { + rewriter.modifyOpInPlace( + op, [&]() { op.getShapeMutable().assign(source); }); + return success(); + } + } + // Don't insert explicit cast if the shape is already host-visible. + if (isHostVisible(op.getShape())) + return rewriter.notifyMatchFailure(op, "skipping already host-visible"); + // Otherwise, insert a direct cast-to-host. + auto castOp = rewriter.create( + op.getLoc(), + op.getShape().getType().cloneWithEncoding(plan::MemorySpaceAttr::get( + op->getContext(), plan::MemorySpace::host)), + op.getShape()); + rewriter.modifyOpInPlace(op, + [&]() { op.getShapeMutable().assign(castOp); }); + return success(); + } +}; + +// Abosrb `tensor.cast` into `bufferization.alloc_tensor` (with no copy +// operand). +struct AllocTensorAbsorbCastPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensor::CastOp op, + PatternRewriter &rewriter) const override { + auto source = op.getOperand(); + auto allocOp = source.getDefiningOp(); + if (!allocOp || allocOp.getCopy()) + return failure(); + auto allocType = allocOp.getType(); + auto allocMemorySpace = + llvm::dyn_cast_if_present(allocType.getEncoding()); + if (!allocMemorySpace) + return failure(); + auto castType = dyn_cast(op.getType()); + if (!castType) + return failure(); + auto castMemorySpace = + llvm::dyn_cast_if_present(castType.getEncoding()); + if (!castMemorySpace) + return failure(); + if (castType.getShape() != allocType.getShape()) + return failure(); + rewriter.replaceOpWithNewOp( + op, op.getType(), allocOp.getDynamicSizes(), /*copy=*/Value{}, + /*size_hint=*/Value{}, + /*memory_space=*/castMemorySpace); + return success(); + } +}; + +/// Rewrite `memref.load` that acts on device memory to first copy the buffer to +/// the host and load from the host buffer. +struct TensorDeviceExtractRewriter + : public OpRewritePattern { + + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractOp op, + PatternRewriter &rewriter) const override { + auto source = op.getTensor(); + if (isHostVisible(source)) + return failure(); + + if (inComputeRegion(op)) + return failure(); + + // First check if there is an existing `tensor.cast` which can be absorbed. + if (auto castOp = source.getDefiningOp()) { + if (isHostVisible(castOp.getOperand())) { + rewriter.modifyOpInPlace( + op, [&]() { op.getTensorMutable().assign(castOp.getOperand()); }); + return success(); + } + } + + rewriter.setInsertionPointAfterValue(source); + Value hostTensor = rewriter.create( + op.getLoc(), + RankedTensorType::get(source.getType().getShape(), + source.getType().getElementType(), + plan::MemorySpaceAttr::get( + op->getContext(), plan::MemorySpace::host)), + source); + + rewriter.replaceUsesWithIf(op.getTensor(), hostTensor, [&](OpOperand &use) { + return isa(use.getOwner()); + }); + + return success(); + } +}; + +/// Rewrite `tensor.insert` so that the insertion destination tensor has +/// 'host_pinned' space. +struct DeviceInsertRewriter : public OpRewritePattern { + + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::InsertOp op, + PatternRewriter &rewriter) const override { + auto dest = op.getDest(); + if (isHostVisible(dest)) + return failure(); + if (inComputeRegion(op)) + return failure(); + rewriter.setInsertionPointAfterValue(dest); + auto newType = RankedTensorType::get( + dest.getType().getShape(), dest.getType().getElementType(), + plan::MemorySpaceAttr::get(op->getContext(), plan::MemorySpace::host)); + Value hostTensor = + rewriter.create(op.getLoc(), newType, dest); + rewriter.replaceUsesWithIf(op.getDest(), hostTensor, [&](OpOperand &use) { + return isa(use.getOwner()); + }); + Type originalType = op.getType(); + rewriter.modifyOpInPlace(op, [&]() { op.getResult().setType(newType); }); + rewriter.setInsertionPointAfter(op); + auto castBack = rewriter.create(op.getLoc(), originalType, + op.getResult()); + rewriter.replaceAllUsesExcept(op.getResult(), castBack, castBack); + return success(); + } +}; + +/// 'tensor.from_elements' is not a DPS operation, so if we yield it from +/// a loop, the result of bufferization will always be to create and yield a new +/// allocation from the loop, which is highly sub-optimal. This pattern matches +/// any `tensor.from_elements` operation which is being yielded from a loop +/// region. It rewrites it to have an explicit +/// `bufferization.materialize_in_destination` operation to materialize the +/// result into a empty tensor. The advantage of this is that the empty tensor +/// can be bufferized into a memref which is allocated above the loop and +/// doesn't change between iterations. +/// +/// Note that you could also use `tensor.insert` to assemble the result, but the +/// BufferizableOpInterface implementation for `tensor.insert` is suboptimal. +struct LoopFromElementsPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensor::FromElementsOp op, + PatternRewriter &rewriter) const override { + if (!op->hasOneUse()) + return failure(); + auto user = *op->user_begin(); + if (!user->hasTrait()) + return failure(); + + auto parentOp = op->getParentOp(); + if (!isa(parentOp)) + return failure(); + + rewriter.setInsertionPointAfter(op); + + auto emptyOp = rewriter.create( + op.getLoc(), op.getType().getShape(), op.getType().getElementType(), + op.getType().getEncoding()); + auto matOp = rewriter.create( + op.getLoc(), op.getType(), op.getResult(), emptyOp); + + rewriter.replaceAllUsesExcept(op.getResult(), matOp.getResult(), matOp); + return success(); + } +}; + +/// A rewrite listener that transfers replacements to updates to the solver +/// state. +class SolverStateListener : public RewriterBase::Listener { +public: + SolverStateListener(DataFlowSolver &solver) + : RewriterBase::Listener(), solver(solver) {} + +private: + void notifyOperationReplaced(Operation *op, + ValueRange replacements) override { + for (auto [original, replacement] : + llvm::zip_equal(op->getResults(), replacements)) { + remapLatticeState(solver, original, replacement); + remapLatticeState>( + solver, original, replacement); + remapLatticeState(solver, original, replacement); + } + solver.eraseState(solver.getProgramPointAfter(op)); + } + void notifyOperationReplaced(Operation *op, Operation *replacement) override { + notifyOperationReplaced(op, replacement->getResults()); + } + + void notifyOperationErased(Operation *op) override { + solver.eraseState(solver.getProgramPointAfter(op)); + for (Value res : op->getResults()) + solver.eraseState(res); + } + + DataFlowSolver &solver; +}; + +struct OptimizeMemorySpacesPass + : public plan::impl::PlanOptimizeMemorySpacesPassBase< + OptimizeMemorySpacesPass> { + void runOnOperation() override { + func::FuncOp func = getOperation(); + + SymbolTableCollection symbolTables; + DataFlowSolver solver(DataFlowConfig().setInterprocedural(false)); + solver.load(); + solver.load(); + solver.load(symbolTables); + + if (failed(solver.initializeAndRun(func))) { + emitError(getOperation().getLoc()) << "failed to run TensorKindAnalysis"; + return signalPassFailure(); + } + + SolverStateListener solverAwareListener(solver); + GreedyRewriteConfig config; + config.listener = &solverAwareListener; + RewritePatternSet patterns(&getContext()); + patterns.insert(&getContext(), solver); + tensor::CastOp::getCanonicalizationPatterns(patterns, &getContext()); + // clang-format off + patterns.insert< + AllocTensorAbsorbCastPattern, + DeviceInsertRewriter, + ReshapeAbsorbDeviceCast, + SCFForAbsorbCastPattern, + SCFWhileAbsorbCastAfterPattern, + SCFWhileAbsorbCastBeforePattern, + TensorDeviceExtractRewriter, + LoopFromElementsPattern + >(&getContext()); + + + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { + emitError(func.getLoc()) << "failed to run " << getArgument(); + return signalPassFailure(); + } + } +}; +} // namespace diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/Passes.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/Passes.cpp index 418833088..ca1c93633 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/Passes.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/Passes.cpp @@ -1,7 +1,6 @@ -//===- Passes.cpp -//----------------------------------------------------------===// +//===- Passes.cpp --------------------------------------------------------===// // -// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES. +// SPDX-FileCopyrightText: Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES. // All rights reserved. // SPDX-License-Identifier: Apache-2.0 // @@ -63,6 +62,11 @@ static void buildPlanOneShotBufferizePipelinePipeline( pm.addPass(createInlinerPass()); pm.addPass(bufferization::createEmptyTensorEliminationPass()); pm.addPass(plan::createPlanAssignMemorySpacesPass()); + pm.addNestedPass(plan::createPlanOptimizeMemorySpacesPass()); + pm.addNestedPass( + plan::createPlanPromoteHostTensorsToHostPinnedPass()); + pm.addNestedPass( + plan::createPlanMaterializeExplicitTransfersPass()); pm.addPass(plan::createPlanAllocTensorsPass(opts)); pm.addPass(plan::createPlanModuleBufferizePass()); pm.addPass(mlir::createMemRefCastEliminationPass()); @@ -92,9 +96,10 @@ static void buildPlanBufferDeallocationPipeline( pm.addPass(createCanonicalizerPass()); pm.addPass(bufferization::createBufferDeallocationSimplificationPass()); pm.addPass(bufferization::createLowerDeallocationsPass()); - pm.addPass(mlir::createBufferizationToMemRefPass()); - pm.addPass(createCSEPass()); - pm.addPass(createCanonicalizerPass()); + pm.addNestedPass( + mlir::createConvertBufferizationToMemRefPass()); + pm.addNestedPass(createCSEPass()); + pm.addNestedPass(createCanonicalizerPass()); } namespace { diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/PromoteHostTensorsToHostPinned.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/PromoteHostTensorsToHostPinned.cpp new file mode 100644 index 000000000..1c8692463 --- /dev/null +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/PromoteHostTensorsToHostPinned.cpp @@ -0,0 +1,135 @@ +//===- PromoteHostTensorsToHostPinned.cpp --------------------------------===// +// +// SPDX-FileCopyrightText: Copyright 2025 NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// Implementation of the `plan-promote-host-tensors-to-host-pinned` pass. +/// +//===----------------------------------------------------------------------===// +#include "mlir-tensorrt/Dialect/Plan/IR/Plan.h" +#include "mlir-tensorrt/Dialect/Plan/Transforms/Passes.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AsmState.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "plan-assign-memory-spaces" + +namespace mlir::plan { +#define GEN_PASS_DEF_PLANPROMOTEHOSTTENSORSTOHOSTPINNEDPASS +#include "mlir-tensorrt/Dialect/Plan/Transforms/Passes.h.inc" +} // namespace mlir::plan + +using namespace mlir; +using namespace mlir::plan; + +namespace { + +template +struct CastPromotionPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensor::CastOp op, + PatternRewriter &rewriter) const override { + auto sourceType = dyn_cast(op.getOperand().getType()); + auto destType = dyn_cast(op.getType()); + if (!sourceType || !destType) + return failure(); + auto sourceSpaceAttr = + llvm::dyn_cast_if_present(sourceType.getEncoding()); + auto destSpaceAttr = + llvm::dyn_cast_if_present(destType.getEncoding()); + if (!sourceSpaceAttr || !destSpaceAttr) + return failure(); + if (sourceSpaceAttr.getValue() != sourceSpace || + destSpaceAttr.getValue() != destSpace) + return failure(); + return handleCast(op, sourceType, destType, rewriter); + } + + virtual LogicalResult handleCast(tensor::CastOp op, + RankedTensorType sourceType, + RankedTensorType destType, + PatternRewriter &rewriter) const = 0; + + plan::MemorySpaceAttr hostSpaceAttr = + plan::MemorySpaceAttr::get(getContext(), plan::MemorySpace::host); + plan::MemorySpaceAttr hostPinnedSpaceAttr = + plan::MemorySpaceAttr::get(getContext(), plan::MemorySpace::host_pinned); + plan::MemorySpaceAttr deviceSpaceAttr = + plan::MemorySpaceAttr::get(getContext(), plan::MemorySpace::device); +}; + +// Pattern for promoting device->host cast as device->host-pinned cast if all +// the cast's users can be updated in-place. +struct DeviceToHostCastPattern + : public CastPromotionPattern { + using CastPromotionPattern::CastPromotionPattern; + + LogicalResult handleCast(tensor::CastOp op, RankedTensorType sourceType, + RankedTensorType destType, + PatternRewriter &rewriter) const override { + // TODO: this should be replaced with some more general conditions. To + // handle `tensor.insert`, we would need to cast type of `tensor.insert` + // result back to original, propogate it forward, etc. + if (!llvm::all_of(op->getUsers(), llvm::IsaPred)) + return failure(); + auto newType = destType.cloneWithEncoding(hostPinnedSpaceAttr); + rewriter.replaceOpWithNewOp(op, newType, op.getOperand()); + return success(); + } +}; + +// Propogate host->device cast as host->device-pinned cast if the producer is a +// `tensor.from_elements` operation. +struct FromElementsPromotionPattern + : public CastPromotionPattern { + using CastPromotionPattern::CastPromotionPattern; + + LogicalResult handleCast(tensor::CastOp op, RankedTensorType sourceType, + RankedTensorType destType, + PatternRewriter &rewriter) const override { + auto fromElementsOp = + dyn_cast(op.getOperand().getDefiningOp()); + if (!fromElementsOp) + return failure(); + auto hostPinnedType = sourceType.cloneWithEncoding(hostPinnedSpaceAttr); + rewriter.modifyOpInPlace(fromElementsOp, [&]() { + fromElementsOp.getResult().setType(hostPinnedType); + }); + return success(); + }; +}; + +class PromoteHostTensorsToHostPinnedPass + : public plan::impl::PlanPromoteHostTensorsToHostPinnedPassBase< + PromoteHostTensorsToHostPinnedPass> { + using Base::Base; + void runOnOperation() override { + auto op = getOperation(); + + RewritePatternSet patterns(&getContext()); + patterns.add( + patterns.getContext()); + if (failed(applyPatternsGreedily(op, std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace diff --git a/mlir-tensorrt/compiler/lib/Transforms/SCFDetensorizeLoops/SCFDetensorizeLoops.cpp b/mlir-tensorrt/compiler/lib/Transforms/SCFDetensorizeLoops/SCFDetensorizeLoops.cpp index d37d6615c..a272dbf4f 100644 --- a/mlir-tensorrt/compiler/lib/Transforms/SCFDetensorizeLoops/SCFDetensorizeLoops.cpp +++ b/mlir-tensorrt/compiler/lib/Transforms/SCFDetensorizeLoops/SCFDetensorizeLoops.cpp @@ -23,6 +23,7 @@ //===----------------------------------------------------------------------===// #include "mlir-tensorrt-dialect/Analysis/TensorKindAnalysis.h" #include "mlir-tensorrt/Transforms/Passes.h" +#include "mlir-tensorrt/Transforms/Transforms.h" #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlowFramework.h" @@ -48,133 +49,188 @@ static bool isHostTensor(Value value, const DataFlowSolver &solver) { return lattice->getValue().isHostOnly(); } -/// Returns true if it is OK to scalarize the given loop-carried variable. -static bool isValidToScalarize(BlockArgument arg, - const DataFlowSolver &solver) { - auto tensorType = dyn_cast(arg.getType()); - if (!tensorType || tensorType.getNumElements() != 1) - return false; - - // Check that all uses are by a `tensor.extract` operation or the terminator. - return isHostTensor(arg, solver); -} - namespace { -/// Attempts to rewrite `scf.while` operations to scalarize the loop-carried -/// variables if possible. -struct DetensorizeWhilePattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +/// Absorb cast operations into the while loop 'before' region and init types. +struct WhileScalarizeBeforeArgPattern : public OpRewritePattern { + WhileScalarizeBeforeArgPattern( + MLIRContext *ctx, + ShouldScalarizeWhileBeforeArgFunc shouldScalarizeBeforeArg, + PatternBenefit benefit) + : OpRewritePattern(ctx, benefit), + shouldScalarizeBeforeArg(std::move(shouldScalarizeBeforeArg)) {} - DetensorizeWhilePattern(MLIRContext *ctx, const DataFlowSolver &solver) - : OpRewritePattern(ctx), solver(solver) {} + ShouldScalarizeWhileBeforeArgFunc shouldScalarizeBeforeArg; - LogicalResult matchAndRewrite(WhileOp op, + LogicalResult matchAndRewrite(scf::WhileOp op, PatternRewriter &rewriter) const override { SmallVector iterArgsToUpdate; Region &after = op.getAfter(); Region &before = op.getBefore(); - if (after.getArgumentTypes() != before.getArgumentTypes()) - return rewriter.notifyMatchFailure( - op, "only scf.while with same before/after region argument types are " - "supported"); + auto originalYield = cast(after.front().getTerminator()); - for (BlockArgument arg : after.getArguments()) { - if (!isValidToScalarize(arg, solver) || - !isValidToScalarize(before.getArgument(arg.getArgNumber()), solver)) + SmallVector newOperands(op.getOperands()); + SmallVector newYieldOperands(originalYield.getOperands()); + SmallVector> blockTypeUpdates; + bool hasUpdate = false; + for (BlockArgument arg : before.getArguments()) { + auto tensorType = dyn_cast(arg.getType()); + if (!tensorType || !tensorType.hasStaticShape() || + tensorType.getNumElements() != 1) + continue; + Value aboveOperand = op.getOperand(arg.getArgNumber()); + Value yieldOperand = originalYield.getOperands()[arg.getArgNumber()]; + if (!shouldScalarizeBeforeArg(arg, aboveOperand, yieldOperand)) continue; - iterArgsToUpdate.push_back(arg.getArgNumber()); + rewriter.setInsertionPoint(op); + newOperands[arg.getArgNumber()] = rewriter.create( + op.getLoc(), aboveOperand, + SmallVector( + tensorType.getRank(), + rewriter.create(op.getLoc(), 0))); + rewriter.setInsertionPoint(originalYield); + newYieldOperands[arg.getArgNumber()] = rewriter.create( + op.getLoc(), yieldOperand, + SmallVector( + tensorType.getRank(), + rewriter.create(op.getLoc(), 0))); + blockTypeUpdates.emplace_back(arg.getArgNumber(), + tensorType.getElementType()); + hasUpdate = true; } - - if (iterArgsToUpdate.empty()) + if (!hasUpdate) return failure(); - // Create the `tensor.extract` operations before the loop op. - Value zero = rewriter.create(op.getLoc(), 0); - SmallVector newTypes(op->getResultTypes()); - SmallVector newOperands(op.getOperands()); - for (int64_t idx : iterArgsToUpdate) { - auto tensorType = cast(newTypes[idx]); - newTypes[idx] = tensorType.getElementType(); - newOperands[idx] = rewriter.create( - op.getLoc(), newOperands[idx], - SmallVector(tensorType.getRank(), zero)); + rewriter.setInsertionPointToStart(&before.front()); + for (auto [argNumber, type] : blockTypeUpdates) { + Type originalType = before.getArgument(argNumber).getType(); + auto fromElements = rewriter.create( + op.getLoc(), originalType, before.getArgument(argNumber)); + rewriter.replaceAllUsesExcept(before.getArgument(argNumber), fromElements, + fromElements); } - // Update the `while` op by moving the regions to a new while op. - auto whileOp = rewriter.create(op.getLoc(), newTypes, newOperands); - rewriter.inlineRegionBefore(before, whileOp.getBefore(), - whileOp.getBefore().end()); - rewriter.inlineRegionBefore(after, whileOp.getAfter(), - whileOp.getAfter().end()); - auto yield = cast(whileOp.getAfterBody()->getTerminator()); - auto cond = cast(whileOp.getBeforeBody()->getTerminator()); - - SmallVector newConditionArgs(cond.getArgs()); - SmallVector newYieldArgs(yield.getOperands()); - - for (int64_t idx : iterArgsToUpdate) { - // For each loop-carried arg being transformed, update the block argument - // types. - auto oldType = cast( - whileOp.getBeforeBody()->getArgument(idx).getType()); - whileOp.getBeforeBody()->getArgument(idx).setType(newTypes[idx]); - whileOp.getAfterBody()->getArgument(idx).setType(newTypes[idx]); - - // Update uses. By design of preconditions. - for (BlockArgument arg : {whileOp.getBeforeBody()->getArgument(idx), - whileOp.getAfterBody()->getArgument(idx)}) { - rewriter.setInsertionPointToStart(arg.getOwner()); - - tensor::FromElementsOp replacement = - rewriter.create(arg.getLoc(), oldType, arg); - rewriter.replaceAllUsesExcept(arg, replacement, replacement); - } + rewriter.modifyOpInPlace(op, [&]() { + op.getInitsMutable().assign(newOperands); + for (auto [argNumber, type] : blockTypeUpdates) + before.getArgument(argNumber).setType(type); + }); - // Update the terminator to be a `tensor.extract` if the yielded value is - // not exactly the block argument. - auto getCoord = [&](Location loc) { - return oldType.getRank() == 0 - ? Value{} - : rewriter.create(loc, 0) - .getResult(); - }; - if (isa(cond.getArgs()[idx].getType())) { - rewriter.setInsertionPoint(cond); - Location loc = cond.getArgs()[idx].getLoc(); - Value coord = getCoord(loc); - auto extractOp = rewriter.create( - loc, cond.getArgs()[idx], coord ? ValueRange{coord} : ValueRange{}); - newConditionArgs[idx] = extractOp; - } - if (isa(yield.getOperands()[idx].getType())) { - rewriter.setInsertionPoint(yield); - Location loc = yield.getOperand(idx).getLoc(); - Value coord = getCoord(loc); - auto extractOp = rewriter.create( - loc, yield.getOperand(idx), - coord ? ValueRange{coord} : ValueRange{}); - newYieldArgs[idx] = extractOp; - } + rewriter.setInsertionPointToStart(&before.front()); + rewriter.modifyOpInPlace(originalYield, [&]() { + originalYield.getResultsMutable().assign(newYieldOperands); + }); + return success(); + } +}; + +/// Absorb cast operations into the while loop 'after' region and result types. +struct WhileScalarizeAfterArgPattern : public OpRewritePattern { + WhileScalarizeAfterArgPattern( + MLIRContext *ctx, + ShouldScalarizeWhileAfterArgFunc shouldScalarizeAfterArg, + PatternBenefit benefit) + : OpRewritePattern(ctx, benefit), + shouldScalarizeAfterArg(std::move(shouldScalarizeAfterArg)) {} + + ShouldScalarizeWhileAfterArgFunc shouldScalarizeAfterArg; + + LogicalResult matchAndRewrite(scf::WhileOp op, + PatternRewriter &rewriter) const override { + SmallVector iterArgsToUpdate; + Region &after = op.getAfter(); + Region &before = op.getBefore(); + auto originalCond = cast(before.front().getTerminator()); + + SmallVector newCondOperands(originalCond.getArgs()); + SmallVector> blockTypeUpdates; + bool hasUpdate = false; + for (BlockArgument arg : after.getArguments()) { + auto tensorType = dyn_cast(arg.getType()); + if (!tensorType || !tensorType.hasStaticShape() || + tensorType.getNumElements() != 1) + continue; + Value condOperand = originalCond.getArgs()[arg.getArgNumber()]; + Value result = op.getResult(arg.getArgNumber()); + if (!shouldScalarizeAfterArg(arg, condOperand, result)) + continue; + rewriter.setInsertionPoint(originalCond); + newCondOperands[arg.getArgNumber()] = rewriter.create( + op.getLoc(), condOperand, + SmallVector( + tensorType.getRank(), + rewriter.create(op.getLoc(), 0))); + blockTypeUpdates.emplace_back(arg.getArgNumber(), + tensorType.getElementType()); + hasUpdate = true; } + if (!hasUpdate) + return failure(); - rewriter.modifyOpInPlace( - yield, [&]() { yield.getResultsMutable().assign(newYieldArgs); }); - rewriter.modifyOpInPlace( - cond, [&]() { cond.getArgsMutable().assign(newConditionArgs); }); + for (auto [argNumber, type] : blockTypeUpdates) { + rewriter.setInsertionPointToStart(&after.front()); + Type originalType = op.getResult(argNumber).getType(); + auto fromElements = rewriter.create( + op.getLoc(), originalType, after.getArgument(argNumber)); + rewriter.replaceAllUsesExcept(after.getArgument(argNumber), fromElements, + fromElements); - // Replace the loop with new values. Create the scalars as necessary. - rewriter.setInsertionPointAfter(whileOp); - SmallVector replacements(whileOp.getResults()); - for (int64_t idx : iterArgsToUpdate) - replacements[idx] = rewriter.create( - op.getLoc(), op->getResult(idx).getType(), replacements[idx]); + rewriter.setInsertionPointAfter(op); + auto fromElements2 = rewriter.create( + op.getLoc(), originalType, op.getResult(argNumber)); + rewriter.replaceAllUsesExcept(op.getResult(argNumber), fromElements2, + fromElements2); + } - rewriter.replaceOp(op, replacements); + rewriter.modifyOpInPlace(op, [&]() { + for (auto [argNumber, type] : blockTypeUpdates) { + after.getArgument(argNumber).setType(type); + op.getResult(argNumber).setType(type); + } + }); + rewriter.modifyOpInPlace(originalCond, [&]() { + originalCond.getArgsMutable().assign(newCondOperands); + }); return success(); } - - const DataFlowSolver &solver; }; +} // namespace + +static bool defaultShouldScalarizeBeforeArg(BlockArgument arg, + Value initOperand, + Value yieldOperand) { + RankedTensorType type = cast(arg.getType()); + return isa(type.getElementType()) && + (initOperand.getDefiningOp() || + matchPattern(initOperand, m_Constant())) && + (yieldOperand.getDefiningOp() || + matchPattern(yieldOperand, m_Constant())); +} + +static bool defaultShouldScalarizeAfterArg(BlockArgument arg, Value condOperand, + Value result) { + RankedTensorType type = cast(arg.getType()); + return isa(type.getElementType()) && + result.hasOneUse() && + (isa(condOperand) || + condOperand.getDefiningOp()); +} + +void mlir::populateSCFDetensorizeWhilePatterns( + RewritePatternSet &patterns, + ShouldScalarizeWhileBeforeArgFunc shouldScalarizeBeforeArg, + ShouldScalarizeWhileAfterArgFunc shouldScalarizeAfterArg, + PatternBenefit benefit) { + if (!shouldScalarizeBeforeArg) + shouldScalarizeBeforeArg = defaultShouldScalarizeBeforeArg; + if (!shouldScalarizeAfterArg) + shouldScalarizeAfterArg = defaultShouldScalarizeAfterArg; + patterns.add( + patterns.getContext(), shouldScalarizeBeforeArg, benefit); + patterns.add(patterns.getContext(), + shouldScalarizeAfterArg, benefit); +} + +namespace { class SCFDetensorizeLoopsPass : public impl::SCFDetensorizeLoopsPassBase { @@ -194,8 +250,17 @@ class SCFDetensorizeLoopsPass emitError(op->getLoc()) << "failed to run TensorKindAnalysis"; return signalPassFailure(); } - - patterns.add(ctx, solver); + auto shouldScalarizeBeforeArg = [&](BlockArgument arg, Value initOperand, + Value yieldOperand) { + return isHostTensor(arg, solver); + }; + auto shouldScalarizeAfterArg = [&](BlockArgument arg, Value condOperand, + Value result) { + return isHostTensor(arg, solver); + }; + populateSCFDetensorizeWhilePatterns(patterns, shouldScalarizeBeforeArg, + shouldScalarizeAfterArg, + /*benefit=*/1); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir-tensorrt/compiler/test/Conversion/HostToEmitC/func-to-emitc.mlir b/mlir-tensorrt/compiler/test/Conversion/HostToEmitC/func-to-emitc.mlir new file mode 100644 index 000000000..cc573e880 --- /dev/null +++ b/mlir-tensorrt/compiler/test/Conversion/HostToEmitC/func-to-emitc.mlir @@ -0,0 +1,49 @@ +// RUN: rm -rf %t || true +// RUN: mkdir -p %t +// RUN: mlir-tensorrt-opt -split-input-file -convert-host-to-emitc="artifacts-dir=%t" %s | \ +// RUN: mlir-tensorrt-translate -split-input-file -mlir-to-cpp | FileCheck %s --check-prefix=CPP + +func.func @callee(%arg0: memref) -> memref { + return %arg0 : memref +} + +func.func @caller(%arg0: memref) -> memref { + %0 = call @callee(%arg0) : (memref) -> memref + return %0 : memref +} + +// CPP-LABEL: mtrt::RankedMemRef<1> callee +// CPP-SAME: (mtrt::RankedMemRef<1> [[v1:.+]]) +// CPP: return [[v1]]; +// CPP: mtrt::RankedMemRef<1> caller +// CPP-SAME: (mtrt::RankedMemRef<1> [[v1:.+]]) +// CPP: mtrt::RankedMemRef<1> [[v2:.+]] = callee([[v1]]); +// CPP: return [[v2]]; + +// ----- + +func.func @callee_multiple_return(%arg0: memref, %arg1: index) -> (memref, f32, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1.0 : f32 + return %arg0, %c1, %c0 : memref, f32, index +} + +func.func @caller_multiple_return(%arg0: memref, %arg1: index) -> (memref, f32, index) { + %0:3 = call @callee_multiple_return(%arg0, %arg1) : (memref, index) -> (memref, f32, index) + return %0#0, %0#1, %0#2 : memref, f32, index +} + +// CPP-LABEL: std::tuple, float, size_t> callee_multiple_return +// CPP-SAME: (mtrt::RankedMemRef<1> [[v1:.+]], size_t [[v2:.+]]) +// CPP-DAG: size_t [[v3:.+]] = 0; +// CPP-DAG: float [[v4:.+]] = 1.000000000e+00f; +// CPP-DAG: return std::make_tuple([[v1]], [[v4]], [[v3]]); + + +// CPP-LABEL: std::tuple, float, size_t> caller_multiple_return +// CPP-SAME: (mtrt::RankedMemRef<1> [[v1:.+]], size_t [[v2:.+]]) +// CPP-DAG: mtrt::RankedMemRef<1> [[v3:.+]]; +// CPP-DAG: float [[v4:.+]]; +// CPP-DAG: size_t [[v5:.+]]; +// CPP-DAG: std::tie(v3, v4, v5) = callee_multiple_return(v1, v2); +// CPP-DAG: return std::make_tuple(v3, v4, v5); diff --git a/mlir-tensorrt/compiler/test/Conversion/HostToEmitC/memref-to-emitc.mlir b/mlir-tensorrt/compiler/test/Conversion/HostToEmitC/memref-to-emitc.mlir index 5f6826a2f..445759724 100644 --- a/mlir-tensorrt/compiler/test/Conversion/HostToEmitC/memref-to-emitc.mlir +++ b/mlir-tensorrt/compiler/test/Conversion/HostToEmitC/memref-to-emitc.mlir @@ -1,10 +1,12 @@ // RUN: rm -rf %t || true // RUN: mkdir -p %t -// RUN: mlir-tensorrt-opt -split-input-file -convert-host-to-emitc="artifacts-dir=%t" %s | FileCheck %s -// RUN: file %t/gv3.constant.bin // RUN: mlir-tensorrt-opt -split-input-file -convert-host-to-emitc="artifacts-dir=%t" %s | \ +// RUN: tee %t/out.mlir | \ // RUN: mlir-tensorrt-translate -split-input-file -mlir-to-cpp | FileCheck %s --check-prefix=CPP +// RUN: file %t/gv3.constant.bin +// RUN: FileCheck --check-prefix=CHECK %s < %t/out.mlir + memref.global @gv2 : memref<2x3xf32> = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> @@ -19,13 +21,13 @@ func.func @get_global() { // CHECK: emitc.global @gv2 : !emitc.array<2x3xf32> = dense<{{.*}}> // CHECK: emitc.global @gv3 : !emitc.ptr> -// CHECK-LABEL: emitc.func @get_global +// CHECK-LABEL: func.func @get_global // CHECK: %[[v0:.+]] = "emitc.constant"() <{value = 0 : i32}> : () -> i32 -// CHECK: %[[v1:.+]] = get_global @gv2 : !emitc.array<2x3xf32> -// CHECK: %[[v2:.+]] = call_opaque "mtrt::make_memref_descriptor" -// CHECK: %[[v3:.+]] = get_global @gv3 -// CHECK: %[[v4:.+]] = load %[[v3]] -// CHECK: %[[v5:.+]] = call_opaque "mtrt::make_memref_descriptor" +// CHECK: %[[v1:.+]] = emitc.get_global @gv2 : !emitc.array<2x3xf32> +// CHECK: %[[v2:.+]] = emitc.call_opaque "mtrt::make_memref_descriptor" +// CHECK: %[[v3:.+]] = emitc.get_global @gv3 +// CHECK: %[[v4:.+]] = emitc.load %[[v3]] +// CHECK: %[[v5:.+]] = emitc.call_opaque "mtrt::make_memref_descriptor" // CHECK: return // CHECK-LABEL: emitc.func @unnamed_module_gv3_initialize @@ -79,18 +81,18 @@ func.func @extract_strided_metadata( return } -// CHECK-LABEL: emitc.func @extract_strided_metadata +// CHECK-LABEL: func.func @extract_strided_metadata // CHECK-SAME: (%[[arg0:.+]]: !emitc.opaque<"mtrt::RankedMemRef<2>">) { // CHECK-DAG: %[[v0:.+]] = "emitc.constant"() <{value = 1 : i32}> : () -> i32 // CHECK-DAG: %[[v1:.+]] = "emitc.constant"() <{value = 0 : i32}> : () -> i32 -// CHECK-NEXT: %[[v2:.+]] = call_opaque "mtrt::memref_descriptor_get_allocated_ptr"(%[[arg0]] -// CHECK-NEXT: %[[v3:.+]] = call_opaque "mtrt::memref_descriptor_get_aligned_ptr"(%[[arg0]]) -// CHECK-NEXT: %[[v4:.+]] = call_opaque "mtrt::memref_descriptor_get_offset"(%[[arg0]]) -// CHECK-NEXT: %[[v5:.+]] = call_opaque "mtrt::make_memref_descriptor"(%[[v2]], %[[v3]], %[[v1]]) -// CHECK-NEXT: %[[v6:.+]] = call_opaque "mtrt::memref_descriptor_get_dim_size"(%[[arg0]], %[[v1]]) -// CHECK-NEXT: %[[v7:.+]] = call_opaque "mtrt::memref_descriptor_get_dim_size"(%[[arg0]], %[[v0]]) -// CHECK-NEXT: %[[v8:.+]] = call_opaque "mtrt::memref_descriptor_get_stride"(%[[arg0]], %[[v1]]) -// CHECK-NEXT: %[[v9:.+]] = call_opaque "mtrt::memref_descriptor_get_stride"(%[[arg0]], %[[v0]]) +// CHECK-NEXT: %[[v2:.+]] = emitc.call_opaque "mtrt::memref_descriptor_get_allocated_ptr"(%[[arg0]] +// CHECK-NEXT: %[[v3:.+]] = emitc.call_opaque "mtrt::memref_descriptor_get_aligned_ptr"(%[[arg0]]) +// CHECK-NEXT: %[[v4:.+]] = emitc.call_opaque "mtrt::memref_descriptor_get_offset"(%[[arg0]]) +// CHECK-NEXT: %[[v5:.+]] = emitc.call_opaque "mtrt::make_memref_descriptor"(%[[v2]], %[[v3]], %[[v1]]) +// CHECK-NEXT: %[[v6:.+]] = emitc.call_opaque "mtrt::memref_descriptor_get_dim_size"(%[[arg0]], %[[v1]]) +// CHECK-NEXT: %[[v7:.+]] = emitc.call_opaque "mtrt::memref_descriptor_get_dim_size"(%[[arg0]], %[[v0]]) +// CHECK-NEXT: %[[v8:.+]] = emitc.call_opaque "mtrt::memref_descriptor_get_stride"(%[[arg0]], %[[v1]]) +// CHECK-NEXT: %[[v9:.+]] = emitc.call_opaque "mtrt::memref_descriptor_get_stride"(%[[arg0]], %[[v0]]) // CHECK-NEXT: return // CPP-LABEL: void extract_strided_metadata(mtrt::RankedMemRef<2> v1) { @@ -196,3 +198,23 @@ func.func @extract_aligned_pointer_as_index(%m: memref) -> index { // CPP-NEXT: void* v2 = mtrt::memref_descriptor_get_aligned_ptr(v1); // CPP-NEXT: size_t v3 = (size_t) v2; // CPP-NEXT: return v3; + + +// ----- + +func.func @alloc_of_index() -> memref<42xindex> { + %0 = memref.alloc() : memref<42xindex> + return %0 : memref<42xindex> +} + + +// CPP-LABEL: mtrt::RankedMemRef<1> alloc_of_index() +// CPP-DAG: int32_t [[v1:.+]] = 0; +// CPP-DAG: int32_t [[v2:.+]] = 16; +// CPP-DAG: int64_t [[v3:.+]] = 42; +// CPP-DAG: int64_t [[v4:.+]] = 1; +// CPP-DAG: int32_t [[v5:.+]] = 8; +// CPP-DAG: int64_t [[v6:.+]] = [[v5]] * [[v3]]; +// CPP-DAG: void* [[v7:.+]] = mtrt::host_aligned_alloc([[v6]], [[v2]]); +// CPP-DAG: mtrt::RankedMemRef<1> [[v8:.+]] = mtrt::make_memref_descriptor<1>([[v7]], [[v7]], [[v1]], [[v3]], [[v4]]); +// CPP-DAG: return [[v8]]; diff --git a/mlir-tensorrt/compiler/test/Conversion/StablehloToArith/stablehlo-constant-to-arith.mlir b/mlir-tensorrt/compiler/test/Conversion/StablehloToArith/stablehlo-constant-to-arith.mlir new file mode 100644 index 000000000..388052318 --- /dev/null +++ b/mlir-tensorrt/compiler/test/Conversion/StablehloToArith/stablehlo-constant-to-arith.mlir @@ -0,0 +1,23 @@ +// RUN: mlir-tensorrt-opt %s -split-input-file -convert-stablehlo-constants-to-arith | FileCheck %s + +func.func @test_stablehlo_constant_to_arith() -> tensor<10xui32> { + %0 = stablehlo.constant dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : tensor<10xui32> + return %0 : tensor<10xui32> +} + +// CHECK-LABEL: func.func @test_stablehlo_constant_to_arith +// CHECK-DAG: %[[cst:.+]] = arith.constant dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : tensor<10xi32> +// CHECK-DAG: %[[v0:.+]] = builtin.unrealized_conversion_cast %[[cst]] : tensor<10xi32> to tensor<10xui32> +// CHECK-DAG: return %[[v0]] : tensor<10xui32> + +// ----- + +func.func @test_stablehlo_constant_to_arith() -> tensor<10xui32> { + %0 = stablehlo.constant dense_resource<__elided__> : tensor<10xui32> + return %0 : tensor<10xui32> +} + +// CHECK-LABEL: func.func @test_stablehlo_constant_to_arith +// CHECK-DAG: %[[cst:.+]] = arith.constant dense_resource<__elided__> : tensor<10xi32> +// CHECK-DAG: %[[v0:.+]] = builtin.unrealized_conversion_cast %[[cst]] : tensor<10xi32> to tensor<10xui32> +// CHECK-DAG: return %[[v0]] : tensor<10xui32> diff --git a/mlir-tensorrt/compiler/test/Conversion/StablehloToScf/stablehlo-to-scf.mlir b/mlir-tensorrt/compiler/test/Conversion/StablehloToScf/stablehlo-to-scf.mlir index 8273db8c7..a96f50e41 100644 --- a/mlir-tensorrt/compiler/test/Conversion/StablehloToScf/stablehlo-to-scf.mlir +++ b/mlir-tensorrt/compiler/test/Conversion/StablehloToScf/stablehlo-to-scf.mlir @@ -1,6 +1,6 @@ // RUN: mlir-tensorrt-opt %s -split-input-file -convert-stablehlo-to-scf | FileCheck %s -func.func @stablehlo_while_to_scf(){ +func.func @stablehlo_while_to_scf_for() -> (tensor, tensor) { %init_i = stablehlo.constant dense<1> :tensor %init_sum = stablehlo.constant dense<0> :tensor %one = stablehlo.constant dense<1> :tensor @@ -9,7 +9,7 @@ func.func @stablehlo_while_to_scf(){ %results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({ ^bb0(%arg0: tensor, %arg1: tensor): %cond = "stablehlo.compare"(%arg0, %ten) { - comparison_direction = #stablehlo + comparison_direction = #stablehlo } : (tensor, tensor) -> tensor stablehlo.return %cond : tensor }, { @@ -18,28 +18,142 @@ func.func @stablehlo_while_to_scf(){ %new_i = stablehlo.add %arg0, %one : tensor stablehlo.return %new_i, %new_sum : tensor, tensor }) : (tensor, tensor) -> (tensor, tensor) - return + return %results0, %results1 : tensor, tensor } -// CHECK-LABEL: @stablehlo_while_to_scf -// CHECK-NEXT: %[[v0:.+]] = stablehlo.constant -// CHECK-NEXT: %[[v1:.+]] = stablehlo.constant -// CHECK-NEXT: %[[v2:.+]] = stablehlo.constant -// CHECK-NEXT: %[[v3:.+]] = stablehlo.constant -// CHECK-NEXT: %[[v4:.+]]:2 = scf.while (%[[arg0:.+]] = %[[v0]], %[[arg1:.+]] = %[[v1]]) : {{.*}} { -// CHECK-NEXT: %[[v5:.+]] = stablehlo.compare LT, %[[arg0:.+]], %[[v3]] : {{.*}} -// CHECK-NEXT: %[[extracted:.+]] = tensor.extract %[[v5]][] : {{.*}} -// CHECK-NEXT: scf.condition(%[[extracted]]) %[[arg0]], %[[arg1]] : {{.*}} -// CHECK-NEXT: } do { -// CHECK-NEXT: ^bb0(%[[arg0:.+]]: {{.*}}, %[[arg1:.+]]: {{.*}}): -// CHECK-NEXT: %[[v5:.+]] = stablehlo.add %[[arg1]], %[[v2]] : {{.*}} -// CHECK-NEXT: %[[v6:.+]] = stablehlo.add %[[arg0]], %[[v2]] : {{.*}} -// CHECK-NEXT: scf.yield %[[v6]], %[[v5]] : {{.*}} -// CHECK: return +// CHECK-LABEL: func.func @stablehlo_while_to_scf_for +// CHECK-DAG: %[[cst:.+]] = arith.constant dense<9> : tensor +// CHECK-DAG: %[[c10_i64:.+]] = arith.constant 10 : i64 +// CHECK-DAG: %[[c:.+]] = stablehlo.constant dense<0> : tensor +// CHECK-DAG: %[[c1_i64:.+]] = arith.constant 1 : i64 +// CHECK-DAG: %[[cst1:.+]] = stablehlo.constant dense<1> : tensor +// CHECK: %[[v0:.+]] = scf.for %[[arg0:.+]] = %[[c1_i64]] to %[[c10_i64]] step %[[c1_i64]] +// CHECK-SAME: iter_args(%[[arg1:.+]] = %[[c]]) -> (tensor) +// CHECK-DAG: %[[v1:.+]] = stablehlo.add %[[arg1]], %[[cst1]] : tensor +// CHECK-DAG: scf.yield %[[v1]] : tensor +// CHECK: return %[[cst]], %[[v0]] : tensor, tensor // ----- -func.func @if_ops_true_branch() { +func.func private @condition(%arg0: tensor, %arg1: tensor) -> tensor + +func.func @stablehlo_while_to_scf_while(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %one = stablehlo.constant dense<1> :tensor + %results0, %results1 = "stablehlo.while"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %cond = call @condition(%arg2, %arg3) : (tensor, tensor) -> tensor + stablehlo.return %cond : tensor + }, { + ^bb0(%arg4: tensor, %arg5: tensor): + %new_sum = stablehlo.add %arg5, %one : tensor + %new_i = stablehlo.add %arg4, %one : tensor + stablehlo.return %new_i, %new_sum : tensor, tensor + }) : (tensor, tensor) -> (tensor, tensor) + return %results0, %results1 : tensor, tensor +} + +// CHECK-LABEL: func.func @stablehlo_while_to_scf_while +// CHECK-SAME: (%[[arg0:.+]]: tensor, %[[arg1:.+]]: tensor) +// CHECK: %[[c1_i64:.+]] = stablehlo.constant dense<1> : tensor +// CHECK: %[[v0:.+]]:2 = scf.while (%[[arg2:.+]] = %[[arg0]], %[[arg3:.+]] = %[[arg1]]) +// CHECK-DAG: %[[v1:.+]] = func.call @condition(%[[arg2]], %[[arg3]]) : (tensor, tensor) -> tensor +// CHECK-DAG: %[[extracted:.+]] = tensor.extract %[[v1]][] : tensor +// CHECK-DAG: scf.condition(%[[extracted]]) %[[arg2]], %[[arg3]] : tensor, tensor +// CHECK: } do { +// CHECK: ^bb0(%[[arg2:.+]]: tensor, %[[arg3:.+]]: tensor): +// CHECK-DAG: %[[v1:.+]] = stablehlo.add %[[arg3]], %[[c1_i64]] +// CHECK-DAG: %[[v2:.+]] = stablehlo.add %[[arg2]], %[[c1_i64]] +// CHECK-DAG: scf.yield %[[v2]], %[[v1]] : tensor, tensor +// CHECK: return %[[v0]]#0, %[[v0]]#1 : tensor, tensor + +// ----- + +func.func private @some_compute(tensor) -> tensor<1xf32> + +func.func @stablehlo_while_regression(%arg0: tensor<1xf32>, %arg1: tensor) -> tensor<1xf32> { + %c_33 = stablehlo.constant dense<0> : tensor + %cst = stablehlo.constant dense<0.000000e+00> : tensor<1xf32> + %c_31 = stablehlo.constant dense<1> : tensor + %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor + %5:2 = stablehlo.while(%iterArg = %c_33, %iterArg_34 = %cst) : tensor, tensor<1xf32> + cond { + %6 = stablehlo.compare LT, %iterArg, %c_31, SIGNED : (tensor, tensor) -> tensor + stablehlo.return %6 : tensor + } do { + %6 = stablehlo.compare LT, %iterArg, %c_33, SIGNED : (tensor, tensor) -> tensor + %7 = stablehlo.add %iterArg, %c_31 : tensor + %8 = stablehlo.select %6, %7, %iterArg : tensor, tensor + %10 = stablehlo.dynamic_slice %arg0, %8, sizes = [1] : (tensor<1xf32>, tensor) -> tensor<1xf32> + %11 = stablehlo.reshape %10 : (tensor<1xf32>) -> tensor + + %12 = stablehlo.compare GE, %11, %cst_0, FLOAT : (tensor, tensor) -> tensor + %35 = stablehlo.select %12, %cst_0, %arg1 : tensor, tensor + %39 = func.call @some_compute(%35) : (tensor) -> tensor<1xf32> + stablehlo.return %7, %39 : tensor, tensor<1xf32> + } + return %5#1 : tensor<1xf32> +} + +// CHECK-LABEL: func.func @stablehlo_while_regression +// CHECK: scf.while + +// ----- + + +func.func @dont_scalarize_while(%arg0: tensor) -> tensor { + %0 = stablehlo.while(%iterArg = %arg0) : tensor + cond { + %c0 = stablehlo.constant dense<0.0> : tensor + %1 = stablehlo.compare LT, %iterArg, %c0, SIGNED : (tensor, tensor) -> tensor + stablehlo.return %1 : tensor + } do { + %c1 = stablehlo.constant dense<1.0> : tensor + %2 = stablehlo.subtract %iterArg, %c1 : tensor + stablehlo.return %2 : tensor + } + return %0 : tensor +} + +// CHECK-LABEL: @dont_scalarize_while +// CHECK: scf.while {{.*}} (tensor) -> tensor +// CHECK: scf.condition{{.*}} : tensor +// CHECK: scf.yield{{.*}} : tensor + +// ----- + +func.func @if_op_not_foldable(%arg0: tensor, %arg1: tensor) -> (tensor<2xi64>, tensor<2xi64>) { + %cond = stablehlo.compare LT, %arg0, %arg1 : (tensor, tensor) -> tensor + %result0, %result1 = "stablehlo.if"(%cond) ({ + %0 = stablehlo.constant dense<0> : tensor<2xi64> + %1 = stablehlo.constant dense<2> : tensor<2xi64> + %3 = stablehlo.add %0, %1 : tensor<2xi64> + stablehlo.return %3, %3 : tensor<2xi64>, tensor<2xi64> + }, { + %1 = stablehlo.constant dense<1> : tensor<2xi64> + stablehlo.return %1, %1 : tensor<2xi64>, tensor<2xi64> + }) : (tensor) -> (tensor<2xi64>, tensor<2xi64>) + func.return %result0, %result1 : tensor<2xi64>, tensor<2xi64> +} + +// CHECK-LABEL: func.func @if_op_not_foldable +// CHECK-SAME: (%[[arg0:.+]]: tensor, %[[arg1:.+]]: tensor) -> (tensor<2xi64>, tensor<2xi64>) { +// CHECK-DAG: %[[c:.+]] = stablehlo.constant dense<1> : tensor<2xi64> +// CHECK-DAG: %[[c_0:.+]] = stablehlo.constant dense<2> : tensor<2xi64> +// CHECK-DAG: %[[c_1:.+]] = stablehlo.constant dense<0> : tensor<2xi64> +// CHECK-DAG: %[[extracted:.+]] = tensor.extract %[[arg0]][] : tensor +// CHECK-DAG: %[[extracted_2:.+]] = tensor.extract %[[arg1]][] : tensor +// CHECK-DAG: %[[v0:.+]] = arith.cmpi slt, %[[extracted]], %[[extracted_2]] : i64 +// CHECK: %[[v1]]:2 = scf.if %[[v0]] -> (tensor<2xi64>, tensor<2xi64>) { +// CHECK: %[[v2:.+]] = stablehlo.add %[[c_1]], %[[c_0]] : tensor<2xi64> +// CHECK: scf.yield %[[v2]], %[[v2]] : tensor<2xi64>, tensor<2xi64> +// CHECK: } else { +// CHECK: scf.yield %[[c]], %[[c]] : tensor<2xi64>, tensor<2xi64> +// CHECK: return %[[v1]]#0, %[[v1]]#1 : tensor<2xi64>, tensor<2xi64> + + +// ----- + +func.func @if_ops_true_foldable() -> (tensor<2xi64>, tensor<2xi64>) { %pred = stablehlo.constant dense : tensor %result0, %result1 = "stablehlo.if"(%pred) ({ %0 = stablehlo.constant dense<0> : tensor<2xi64> @@ -50,96 +164,94 @@ func.func @if_ops_true_branch() { %1 = stablehlo.constant dense<1> : tensor<2xi64> stablehlo.return %1, %1 : tensor<2xi64>, tensor<2xi64> }) : (tensor) -> (tensor<2xi64>, tensor<2xi64>) - func.return + func.return %result0, %result1 : tensor<2xi64>, tensor<2xi64> } -// CHECK-LABEL: @if_ops_true_branch() { -// CHECK-NEXT: %[[v0:.+]] = stablehlo.constant dense : tensor -// CHECK-NEXT: %[[extracted:.+]] = tensor.extract %[[v0]][] : tensor -// CHECK-NEXT: %[[v1:.+]]:2 = scf.if %[[extracted]] -> {{.*}} { -// CHECK-NEXT: %[[v2:.+]] = stablehlo.constant {{.*}} -// CHECK-NEXT: %[[v3:.+]] = stablehlo.constant{{.*}} -// CHECK-NEXT: %[[v4:.+]] = stablehlo.add %[[v2]], %[[v3]] : {{.*}} -// CHECK-NEXT: scf.yield %[[v4]], %[[v4]] : {{.*}} -// CHECK-NEXT: } else { -// CHECK-NEXT: %[[v2:.+]] = stablehlo.constant {{.*}} -// CHECK-NEXT: scf.yield %[[v2]], %[[v2]] : {{.*}} -// CHECK: return +// CHECK-LABEL: func.func @if_ops_true_foldable +// CHECK-DAG: %[[c:.+]] = stablehlo.constant dense<2> : tensor<2xi64> +// CHECK-DAG: %[[c_0:.+]] = stablehlo.constant dense<0> : tensor<2xi64> +// CHECK-DAG: %[[v0:.+]] = stablehlo.add %[[c_0]], %[[c]] : tensor<2xi64> +// CHECK-DAG: return %[[v0]], %[[v0]] : tensor<2xi64>, tensor<2xi64> // ----- -func.func @case_one_branch() { +func.func @case_one_branch() -> (tensor<2xi64>, tensor<2xi64>) { %index = stablehlo.constant dense<0> : tensor %result_branch0 = stablehlo.constant dense<0> : tensor<2xi64> %result0, %result1 = "stablehlo.case"(%index) ({ stablehlo.return %result_branch0, %result_branch0 : tensor<2xi64>, tensor<2xi64> }) : (tensor) -> (tensor<2xi64>, tensor<2xi64>) - func.return + func.return %result0, %result1 : tensor<2xi64>, tensor<2xi64> } -// CHECK-LABEL: case_one_branch -// CHECK-NEXT: %[[v0:.+]] = stablehlo.constant -// CHECK-NEXT: %[[v1:.+]] = stablehlo.constant -// CHECK-NEXT: return +// CHECK-LABEL: func.func @case_one_branch +// CHECK: %[[c:.+]] = stablehlo.constant dense<0> : tensor<2xi64> +// CHECK: return %[[c]], %[[c]] : tensor<2xi64>, tensor<2xi64> // ----- -func.func @case_two_branches() { - %index = stablehlo.constant dense<0> : tensor +func.func @case_two_branches(%arg0: tensor) -> (tensor<2xi64>, tensor<2xi64>) { %result_branch0 = stablehlo.constant dense<0> : tensor<2xi64> %result_branch1 = stablehlo.constant dense<1> : tensor<2xi64> - %result0, %result1 = "stablehlo.case"(%index) ({ + %result0, %result1 = "stablehlo.case"(%arg0) ({ stablehlo.return %result_branch0, %result_branch0 : tensor<2xi64>, tensor<2xi64> },{stablehlo.return %result_branch1, %result_branch1 : tensor<2xi64>, tensor<2xi64> }) : (tensor) -> (tensor<2xi64>, tensor<2xi64>) - func.return + func.return %result0, %result1 : tensor<2xi64>, tensor<2xi64> } -// CHECK-LABEL: case_two_branches -// CHECK-NEXT: %[[v0:.+]] = stablehlo.constant -// CHECK-NEXT: %[[v1:.+]] = stablehlo.constant -// CHECK-NEXT: %[[v2:.+]] = stablehlo.constant -// CHECK-NEXT: %[[v3:.+]] = stablehlo.constant -// CHECK-NEXT: %[[v4:.+]] = stablehlo.compare EQ, %[[v0]], %[[v3]] -// CHECK-NEXT: %[[extracted:.+]] = tensor.extract %[[v4]][] -// CHECK-NEXT: %[[v5:.+]]:2 = scf.if %[[extracted]] -// CHECK-NEXT: scf.yield %[[v1]], %[[v1]] -// CHECK-NEXT: } else { -// CHECK-NEXT: scf.yield %[[v2]], %[[v2]] -// CHECK: return +// CHECK-LABEL: func.func @case_two_branches +// CHECK-SAME: (%[[arg0:.+]]: tensor) +// CHECK-DAG: %[[c0_i32:.+]] = arith.constant 0 : i32 +// CHECK-DAG: %[[c:.+]] = stablehlo.constant dense<0> : tensor<2xi64> +// CHECK-DAG: %[[c_0:.+]] = stablehlo.constant dense<1> : tensor<2xi64> +// CHECK-DAG: %[[extracted:.+]] = tensor.extract %[[arg0]][] : tensor +// CHECK-DAG: %[[v0:.+]] = arith.cmpi eq, %[[extracted]], %[[c0_i32]] : i32 +// CHECK-DAG: %[[v1:.+]] = arith.select %[[v0]], %[[c]], %[[c_0]] : tensor<2xi64> +// CHECK-DAG: %[[v2:.+]] = arith.select %[[v0]], %[[c]], %[[c_0]] : tensor<2xi64> +// CHECK: return %[[v1]], %[[v2]] : tensor<2xi64>, tensor<2xi64> // ----- -func.func @case_three_branches() { - %index = stablehlo.constant dense<0> : tensor - %result_branch0 = stablehlo.constant dense<0> : tensor<2xi64> - %result_branch1 = stablehlo.constant dense<1> : tensor<2xi64> - %result_branch2 = stablehlo.constant dense<2> : tensor<2xi64> - %result0, %result1 = "stablehlo.case"(%index) ({ - stablehlo.return %result_branch0, %result_branch0 : tensor<2xi64>, tensor<2xi64> - },{stablehlo.return %result_branch1, %result_branch1 : tensor<2xi64>, tensor<2xi64> - },{stablehlo.return %result_branch2, %result_branch2 : tensor<2xi64>, tensor<2xi64> - }) : (tensor) -> (tensor<2xi64>, tensor<2xi64>) - func.return +func.func @case_three_branches( + %index: tensor, %arg0: tensor<2xi64>, %arg1: tensor<2xi64>, %arg2: tensor<2xi64>) + -> (tensor<2xi64>) { + %result = "stablehlo.case"(%index) ({ + %0 = stablehlo.add %arg0, %arg1 : tensor<2xi64> + %1 = stablehlo.multiply %0, %arg2 : tensor<2xi64> + stablehlo.return %1 : tensor<2xi64> + },{ + %0 = stablehlo.add %arg1, %arg2 : tensor<2xi64> + %1 = stablehlo.multiply %0, %arg0 : tensor<2xi64> + stablehlo.return %1 : tensor<2xi64> + },{ + %0 = stablehlo.add %arg2, %arg0 : tensor<2xi64> + %1 = stablehlo.multiply %0, %arg1 : tensor<2xi64> + stablehlo.return %1 : tensor<2xi64> + }) : (tensor) -> (tensor<2xi64>) + func.return %result : tensor<2xi64> } -// CHECK-LABEL: case_three_branches -// CHECK-NEXT: %[[v0:.+]] = stablehlo.constant -// CHECK-NEXT: %[[v1:.+]] = stablehlo.constant -// CHECK-NEXT: %[[v2:.+]] = stablehlo.constant -// CHECK-NEXT: %[[v3:.+]] = stablehlo.constant -// CHECK-NEXT: %[[v4:.+]] = stablehlo.constant -// CHECK-NEXT: %[[v5:.+]] = stablehlo.compare EQ, %[[v0]], %[[v4]] -// CHECK-NEXT: %[[extracted:.+]] = tensor.extract %[[v5]][] -// CHECK-NEXT: %[[v6:.+]]:2 = scf.if %[[extracted]] -// CHECK-NEXT: scf.yield %[[v1]], %[[v1]] -// CHECK-NEXT: } else { -// CHECK-NEXT: %[[v7:.+]] = stablehlo.constant -// CHECK-NEXT: %[[v8:.+]] = stablehlo.compare EQ, %[[v0]], %[[v7]] -// CHECK-NEXT: %[[extracted_0:.+]] = tensor.extract %[[v8]][] -// CHECK-NEXT: %[[v9:.+]]:2 = scf.if %[[extracted_0]] -// CHECK-NEXT: scf.yield %[[v2]], %[[v2]] -// CHECK-NEXT: } else { -// CHECK-NEXT: scf.yield %[[v3]], %[[v3]] -// CHECK: scf.yield %[[v9]]#0, %[[v9]]#1 -// CHECK: return \ No newline at end of file +// CHECK-LABEL: func.func @case_three_branches +// CHECK-SAME: (%[[arg0:.+]]: tensor, %[[arg1:.+]]: tensor<2xi64>, %[[arg2:.+]]: tensor<2xi64>, %[[arg3:.+]]: tensor<2xi64>) +// CHECK-DAG: %[[c1_i32:.+]] = arith.constant 1 : i32 +// CHECK-DAG: %[[c0_i32:.+]] = arith.constant 0 : i32 +// CHECK-DAG: %[[extracted:.+]] = tensor.extract %[[arg0]][] : tensor +// CHECK-DAG: %[[v0:.+]] = arith.cmpi eq, %[[extracted]], %[[c0_i32]] : i32 +// CHECK-DAG: %[[v1:.+]] = scf.if %[[v0]] -> (tensor<2xi64>) { +// CHECK: %[[v2:.+]] = stablehlo.add %[[arg1]], %[[arg2]] : tensor<2xi64> +// CHECK: %[[v3:.+]] = stablehlo.multiply %[[v2]], %[[arg3]] : tensor<2xi64> +// CHECK: scf.yield %[[v3]] : tensor<2xi64> +// CHECK: } else { +// CHECK-DAG: %[[extracted_0:.+]] = tensor.extract %[[arg0]][] : tensor +// CHECK-DAG: %[[v2:.+]] = arith.cmpi eq, %[[extracted_0]], %[[c1_i32]] : i32 +// CHECK: %[[v3:.+]] = scf.if %[[v2]] -> (tensor<2xi64>) { +// CHECK-DAG: %[[v4:.+]] = stablehlo.add %[[arg2]], %[[arg3]] : tensor<2xi64> +// CHECK-DAG: %[[v5:.+]] = stablehlo.multiply %[[v4]], %[[arg1]] : tensor<2xi64> +// CHECK: scf.yield %[[v5]] : tensor<2xi64> +// CHECK: } else { +// CHECK-DAG: %[[v4:.+]] = stablehlo.add %[[arg3]], %[[arg1]] : tensor<2xi64> +// CHECK-DAG: %[[v5:.+]] = stablehlo.multiply %[[v4]], %[[arg2]] : tensor<2xi64> +// CHECK-DAG: scf.yield %[[v5]] : tensor<2xi64> +// CHECK: scf.yield %[[v3]] : tensor<2xi64> +// CHECK: return %[[v1]] : tensor<2xi64> \ No newline at end of file diff --git a/mlir-tensorrt/compiler/test/Conversion/StablehloToTensorRT/dot-to-einsum.mlir b/mlir-tensorrt/compiler/test/Conversion/StablehloToTensorRT/dot-to-einsum.mlir new file mode 100644 index 000000000..ce2e69104 --- /dev/null +++ b/mlir-tensorrt/compiler/test/Conversion/StablehloToTensorRT/dot-to-einsum.mlir @@ -0,0 +1,141 @@ +// RUN: mlir-tensorrt-opt %s -split-input-file -convert-stablehlo-to-tensorrt=prefer-einsum=true | FileCheck %s + +!lhs = tensor<2x10x20x30x40xf32> +!rhs = tensor<2x10x20x30x40xf32> +!result = tensor<2x10x20x10x20xf32> + +// CHECK-LABEL: @dot_general_multiple_contraction_dims1 +// CHECK-SAME: (%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: tensor<{{.*}}>) +func.func @dot_general_multiple_contraction_dims1(%arg0: !lhs, %arg1: !rhs) -> !result { + // CHECK: %[[v0:.+]] = tensorrt.einsum + // CHECK-SAME: "adebc,afgbc->adefg" + // CHECK-SAME: ins(%[[arg0]], %[[arg1]] + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [0], + lhs_contracting_dimensions = [3,4], + rhs_contracting_dimensions = [3,4]>, + precision_config = [#stablehlo, #stablehlo] + } : (!lhs, !rhs) -> !result + // CHECK: return %[[v0]] + return %0 : !result +} + +// ----- + +!lhs = tensor<2x10x30x20x40xf32> +!rhs = tensor<2x30x10x20x40xf32> +!result = tensor<2x10x20x10x20xf32> + +// CHECK-LABEL: @dot_general_multiple_contraction_dims2 +// CHECK-SAME: (%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: tensor<{{.*}}>) +func.func @dot_general_multiple_contraction_dims2(%arg0: !lhs, %arg1: !rhs) -> !result { + // CHECK: %[[v0:.+]] = tensorrt.einsum + // CHECK-SAME: "adbec,abfgc->adefg" + // CHECK-SAME: ins(%[[arg0]], %[[arg1]] + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2,4], + rhs_contracting_dimensions = [1,4]>, + precision_config = [#stablehlo, #stablehlo] + } : (!lhs, !rhs) -> !result + // CHECK: return %[[v0]] + return %0 : !result +} + +// ----- + +// CHECK-LABEL: @dot_general_multiple_outer_product_dims +// CHECK-SAME: (%[[arg0:.+]]: tensor<32x49x32xf32>, %[[arg1:.+]]: tensor<32x1x32x49xf32>) +func.func @dot_general_multiple_outer_product_dims(%arg0: tensor<32x49x32xf32>, + %arg1: tensor<32x1x32x49xf32>) -> tensor<32x49x1x49xf32> { + // CHECK: %[[v0:.+]] = tensorrt.einsum + // CHECK-SAME: "acb,adbe->acde" + // CHECK-SAME: ins(%[[arg0]], %[[arg1]] : tensor<32x49x32xf32>, tensor<32x1x32x49xf32>) -> tensor<32x49x1x49xf32> + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [2]>, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<32x49x32xf32>, tensor<32x1x32x49xf32>) -> tensor<32x49x1x49xf32> + // CHECK: return %[[v0]] + return %0 : tensor<32x49x1x49xf32> +} + +// ----- + +!lhs = tensor +!rhs = tensor +!result = tensor + +// CHECK-LABEL: @simple_dot_general1 +// CHECK-SAME: (%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: tensor<{{.*}}>) +func.func @simple_dot_general1(%arg0: !lhs, %arg1: !rhs) -> !result { + // CHECK: %[[v0:.+]] = tensorrt.einsum + // CHECK-SAME: "abdc,abce->abde" + // CHECK-SAME: ins(%[[arg0]], %[[arg1]] + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0, 1], + rhs_batching_dimensions = [0, 1], + lhs_contracting_dimensions = [3], + rhs_contracting_dimensions = [2]>, + precision_config = [#stablehlo, #stablehlo] + } : (!lhs, !rhs) -> !result + // CHECK: return %[[v0]] + return %0 : !result +} + + +// ----- + +!lhs = tensor +!rhs = tensor +!result = tensor + +// CHECK-LABEL: @simple_dot_general2 +// CHECK-SAME: (%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: tensor<{{.*}}>) +func.func @simple_dot_general2(%arg0: !lhs, %arg1: !rhs) -> !result { + // CHECK: %[[v0:.+]] = tensorrt.einsum + // CHECK-SAME: "abcd,abce->abde" + // CHECK-SAME: ins(%[[arg0]], %[[arg1]] + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0, 1], + rhs_batching_dimensions = [0, 1], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [2]>, + precision_config = [#stablehlo, #stablehlo] + } : (!lhs, !rhs) -> !result + // CHECK: return %[[v0]] + return %0 : !result +} + + +// ----- + +// CHECK-LABEL: func.func @dot_general_promoted_result_type +// CHECK-SAME: (%[[arg0:.+]]: tensor, %[[arg1:.+]]: tensor) +func.func @dot_general_promoted_result_type(%arg0: tensor, %arg1: tensor) -> tensor { +// CHECK: %[[v0:.+]] = tensorrt.identity %[[arg0]] : tensor to tensor +// CHECK: %[[v1:.+]] = tensorrt.identity %[[arg1]] : tensor to tensor +// CHECK: %[[v2:.+]] = tensorrt.einsum +// CHECK-SAME: "abc,abcd->abd" +// CHECK-SAME: ins(%[[v0]], %[[v1]] : tensor, tensor) -> tensor + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0, 1], + rhs_batching_dimensions = [0, 1], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [2]>, + precision_config = [#stablehlo, #stablehlo] + } : (tensor, tensor) -> tensor + // CHECK: return %[[v2]] + return %0 : tensor +} + diff --git a/mlir-tensorrt/compiler/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt-invalid.mlir b/mlir-tensorrt/compiler/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt-invalid.mlir index fc77bd61f..e15585035 100644 --- a/mlir-tensorrt/compiler/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt-invalid.mlir +++ b/mlir-tensorrt/compiler/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt-invalid.mlir @@ -91,20 +91,3 @@ func.func private @block_dq(%arg0: tensor<258x256xi4>) -> tensor<258x256xf32> at %3 = stablehlo.multiply %2, %1 : tensor<258x256xf32> return %3 : tensor<258x256xf32> } - -// ----- - -// CHECK-LABEL: @unsupported_multiple_outer_product_dims -func.func @unsupported_multiple_outer_product_dims(%arg0: tensor<32x49x32xf32>, - %arg1: tensor<32x1x32x49xf32>) -> tensor<32x49x1x49xf32> { - // CHECK: stablehlo.dot_general - %0 = "stablehlo.dot_general"(%arg0, %arg1) { - dot_dimension_numbers = #stablehlo.dot< - lhs_batching_dimensions = [0], - rhs_batching_dimensions = [0], - lhs_contracting_dimensions = [2], - rhs_contracting_dimensions = [2]>, - precision_config = [#stablehlo, #stablehlo] - } : (tensor<32x49x32xf32>, tensor<32x1x32x49xf32>) -> tensor<32x49x1x49xf32> - return %0 : tensor<32x49x1x49xf32> -} diff --git a/mlir-tensorrt/compiler/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt-trt10.mlir b/mlir-tensorrt/compiler/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt-trt10.mlir index 6f1f821c2..67c30f529 100644 --- a/mlir-tensorrt/compiler/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt-trt10.mlir +++ b/mlir-tensorrt/compiler/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt-trt10.mlir @@ -382,7 +382,7 @@ func.func @expm1_bf16(%arg0: tensor<3xbf16>) -> tensor<3xbf16> { // CHECK-DAG: %[[v3:.+]] = tensorrt.element_wise (%[[v1]], %[[v1]] : tensor<3xf32>, tensor<3xf32>) // CHECK-DAG: %[[v4:.+]] = tensorrt.unary {unaryOperation = #tensorrt.unary_operation} %[[v3]] // CHECK-DAG: %[[v5:.+]] = tensorrt.element_wise (%[[v2]], %[[v4]] : tensor<3xi1>, tensor<3xi1> -// CHECK-DAG: %[[v6:.+]] = tensorrt.element_wise (%[[v1]], %[[cst_f32]] : +// CHECK-DAG: %[[v6:.+]] = tensorrt.element_wise (%[[v1]], %[[cst_f32]] : // CHECK-DAG: %[[v7:.+]] = tensorrt.element_wise (%[[v6]], %[[cst_f32_0]] // CHECK-DAG: %[[v8:.+]] = tensorrt.unary {unaryOperation = #tensorrt.unary_operation} %[[v1]] // CHECK-DAG: %[[v9:.+]] = tensorrt.element_wise (%[[v8]], %[[v1]] @@ -393,3 +393,4 @@ func.func @expm1_bf16(%arg0: tensor<3xbf16>) -> tensor<3xbf16> { // CHECK-DAG: %[[v14:.+]] = tensorrt.select ins(%[[v5]], %[[v0]], %[[v13]] // CHECK-DAG: %[[v15:.+]] = tensorrt.identity %[[v14]] : tensor<3xf32> to tensor<3xbf16> // CHECK-DAG: return %[[v15]] : tensor<3xbf16> + diff --git a/mlir-tensorrt/compiler/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt.mlir b/mlir-tensorrt/compiler/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt.mlir index b40f2f428..aac02c217 100644 --- a/mlir-tensorrt/compiler/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt.mlir +++ b/mlir-tensorrt/compiler/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt.mlir @@ -444,7 +444,7 @@ func.func @hlo_dot_general(%arg0: tensor, %arg1: tensor } -// CHECK-LABEL: @hlo_dot_general +// CHECK-LABEL: @hlo_dot_general( // CHECK: tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation, op1 = #tensorrt.matrix_operation} @@ -504,18 +504,6 @@ func.func @dot_general_promoted_result_type(%arg0: tensor, %arg1: te // ----- -func.func @main(%arg0: tensor<1x1500x384xf32>, %arg1: tensor<384x384xf32>) -> tensor<1x1500x384xf32> { - %0 = "stablehlo.dot_general"(%arg0, %arg1) { - dot_dimension_numbers = #stablehlo.dot< - lhs_contracting_dimensions = [2], - rhs_contracting_dimensions = [0]>, - precision_config = [#stablehlo, #stablehlo] - } : (tensor<1x1500x384xf32>, tensor<384x384xf32>) -> tensor<1x1500x384xf32> - return %0 : tensor<1x1500x384xf32> -} - -// ----- - func.func @hlo_einsum(%arg0: tensor, %arg1: tensor<64x256xf32>) -> tensor { %0 = "stablehlo.einsum"(%arg0, %arg1) {einsum_config = "abcd,de->abce"} : (tensor, tensor<64x256xf32>) -> tensor diff --git a/mlir-tensorrt/compiler/test/Dialect/LinalgExt/to-loops.mlir b/mlir-tensorrt/compiler/test/Dialect/LinalgExt/to-loops.mlir new file mode 100644 index 000000000..caecb2636 --- /dev/null +++ b/mlir-tensorrt/compiler/test/Dialect/LinalgExt/to-loops.mlir @@ -0,0 +1,124 @@ +// RUN: mlir-tensorrt-opt %s -convert-to-loops -split-input-file -cse | FileCheck %s + +func.func @linalg_generic_to_loops(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d1, 0)>, + affine_map<(d0, d1) -> (d1, d0)> + ], + iterator_types = ["parallel", "parallel"] + } ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%arg2 : tensor<4x4xf32>) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): + %0 = arith.addf %arg3, %arg4 : f32 + %1 = arith.addf %0, %arg5 : f32 + linalg.yield %1 : f32 + } -> tensor<4x4xf32> + return %0 : tensor<4x4xf32> +} + +// CHECK-LABEL: func.func @linalg_generic_to_loops +// CHECK-SAME: (%[[arg0:.+]]: tensor<4x4xf32>, %[[arg1:.+]]: tensor<4x4xf32>, %[[arg2:.+]]: tensor<4x4xf32>) +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 +// CHECK-DAG: %[[c4:.+]] = arith.constant 4 +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 +// CHECK: %[[v0:.+]] = scf.for %[[arg3:.+]] = %[[c0]] to %[[c4]] step %[[c1]] iter_args(%[[arg4:.+]] = %[[arg2]]) +// CHECK: %[[v1:.+]] = scf.for %[[arg5:.+]] = %[[c0]] to %[[c4]] step %[[c1]] iter_args(%[[arg6:.+]] = %[[arg4]]) +// CHECK-DAG: %[[extracted:.+]] = tensor.extract %[[arg0]][%[[arg3]], %[[arg5]]] +// CHECK-DAG: %[[extracted_0:.+]] = tensor.extract %[[arg1]][%[[arg5]], %[[c0]]] +// CHECK-DAG: %[[extracted_1:.+]] = tensor.extract %[[arg6]][%[[arg5]], %[[arg3]]] +// CHECK-DAG: %[[v2:.+]] = arith.addf %[[extracted]], %[[extracted_0]] +// CHECK-DAG: %[[v3:.+]] = arith.addf %[[v2]], %[[extracted_1]] +// CHECK-DAG: %[[inserted:.+]] = tensor.insert %[[v3]] into %[[arg6]][%[[arg5]], %[[arg3]]] +// CHECK: scf.yield %[[inserted]] +// CHECK: scf.yield %[[v1]] +// CHECK: return %[[v0]] + +// ----- + +func.func @contraction(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4xf32>) + -> tensor<4xf32> { + %0 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d1, d0)>, + affine_map<(d0, d1) -> (d0)> + ], + iterator_types = ["parallel", "reduction"] + } ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%arg2 : tensor<4xf32>) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): + %0 = arith.addf %arg3, %arg4 : f32 + %1 = arith.addf %0, %arg5 : f32 + linalg.yield %1 : f32 + } -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func.func @c +// CHECK-SAME: (%[[arg0:.+]]: tensor<4x4xf32>, %[[arg1:.+]]: tensor<4x4xf32>, %[[arg2:.+]]: tensor<4xf32>) +// CHECK: %[[c0:.+]] = arith.constant 0 +// CHECK: %[[c4:.+]] = arith.constant 4 +// CHECK: %[[c1:.+]] = arith.constant 1 +// CHECK: %[[v0:.+]] = scf.for %[[arg3:.+]] = %[[c0]] to %[[c4]] step %[[c1]] iter_args(%[[arg4:.+]] = %[[arg2]]) +// CHECK: %[[v1:.+]] = scf.for %[[arg5:.+]] = %[[c0]] to %[[c4]] step %[[c1]] iter_args(%[[arg6:.+]] = %[[arg4]]) +// CHECK: %[[extracted:.+]] = tensor.extract %[[arg0]][%[[arg3]], %[[arg5]]] +// CHECK: %[[extracted_0:.+]] = tensor.extract %[[arg1]][%[[arg5]], %[[arg3]]] +// CHECK: %[[extracted_1:.+]] = tensor.extract %[[arg6]][%[[arg3]]] +// CHECK: %[[v2:.+]] = arith.addf %[[extracted]], %[[extracted_0]] +// CHECK: %[[v3:.+]] = arith.addf %[[v2]], %[[extracted_1]] +// CHECK: %[[inserted:.+]] = tensor.insert %[[v3]] into %[[arg6]][%[[arg3]]] +// CHECK: scf.yield %[[inserted]] +// CHECK: scf.yield %[[v1]] +// CHECK: return %[[v0]] + +// ----- + +func.func @linalg_index_op(%arg0: tensor<4x4xindex>) -> tensor<4x4xindex> { + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %3 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1) -> ()>, + affine_map<(d0, d1) -> (d0, d1)> + ], + iterator_types = ["parallel", "parallel"] + } ins(%c1 : index) outs(%arg0 : tensor<4x4xindex>) { + ^bb0(%arg1: index, %arg2: index): + %l1 = linalg.index 0 : index + %l2 = linalg.index 1 : index + %0 = arith.muli %l1, %c4 : index + %1 = arith.addi %l2, %arg1 : index + %2 = arith.addi %0, %1 : index + linalg.yield %2 : index + } -> tensor<4x4xindex> + return %3 : tensor<4x4xindex> +} + +// CHECK-LABEL: func.func @linalg_index_op +// CHECK-SAME: (%[[arg0:.+]]: tensor<4x4xindex>) +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK: %[[v0:.+]] = scf.for %[[arg1:.+]] = %[[c0]] to %[[c4]] step %[[c1]] iter_args(%[[arg2:.+]] = %[[arg0]]) +// CHECK: %[[v1:.+]] = scf.for %[[arg3:.+]] = %[[c0]] to %[[c4]] step %[[c1]] iter_args(%[[arg4:.+]] = %[[arg2]]) +// CHECK-DAG: %[[v2:.+]] = arith.muli %[[arg1]], %[[c4]] +// CHECK-DAG: %[[v3:.+]] = arith.addi %[[arg3]], %[[c1]] +// CHECK-DAG: %[[v4:.+]] = arith.addi %[[v2]], %[[v3]] +// CHECK-DAG: %[[inserted:.+]] = tensor.insert %[[v4]] into %[[arg4]][%[[arg1]], %[[arg3]]] : tensor<4x4xindex> +// CHECK-DAG: scf.yield %[[inserted]] : tensor<4x4xindex> +// CHECK: scf.yield %[[v1]] +// CHECK: return %[[v0]] + +// ----- + +func.func @linalg_map(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = linalg.map {arith.addf} ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%arg2: tensor<4x4xf32>) + return %0 : tensor<4x4xf32> +} + +// CHECK-LABEL: func.func @linalg_map +// CHECK-COUNT-2: scf.for +// CHECK-COUNT-2: tensor.extract +// CHECK: arith.addf +// CHECK: tensor.insert +// CHECK-COUNT-2: scf.yield \ No newline at end of file diff --git a/mlir-tensorrt/compiler/test/Dialect/Plan/assign-and-optimize-memory-spaces.mlir b/mlir-tensorrt/compiler/test/Dialect/Plan/assign-and-optimize-memory-spaces.mlir new file mode 100644 index 000000000..ca3139943 --- /dev/null +++ b/mlir-tensorrt/compiler/test/Dialect/Plan/assign-and-optimize-memory-spaces.mlir @@ -0,0 +1,238 @@ +// RUN: mlir-tensorrt-opt %s -split-input-file \ +// RUN: -pass-pipeline="builtin.module(plan-assign-memory-spaces,func.func(plan-optimize-memory-spaces))" \ +// RUN: | FileCheck %s + +func.func private @cond() -> i1 + +// CHECK-LABEL: func.func @scf_while_loop_2 +// CHECK: scf.while {{.*}}tensor<1xf32, #plan.memory_space>) -> tensor<1xf32, #plan.memory_space> +// CHECK-NOT: #plan.memory_space +func.func @scf_while_loop_2(%arg0: f32) -> f32 { + %c0 = arith.constant 0 : index + %1 = tensor.from_elements %arg0 : tensor<1xf32> + %2 = scf.while (%arg1 = %1) : (tensor<1xf32>) -> tensor<1xf32> { + %cond = func.call @cond() : () -> i1 + %e = tensor.extract %arg1[%c0] : tensor<1xf32> + %f = arith.addf %e, %e : f32 + %3 = tensor.from_elements %f : tensor<1xf32> + scf.condition(%cond) %3 : tensor<1xf32> + } do { + ^bb0(%arg1: tensor<1xf32>): + %extract = tensor.extract %arg1[%c0] : tensor<1xf32> + %3 = arith.addf %extract, %extract : f32 + %4 = tensor.from_elements %3 : tensor<1xf32> + scf.yield %4 : tensor<1xf32> + } + %3 = tensor.extract %2[%c0] : tensor<1xf32> + return %3 : f32 +} + +// ----- + +// CHECK-LABEL: func.func @arith_constant +// CHECK: arith.constant {{.*}} : tensor<2xf32, #plan.memory_space> +// CHECK: arith.constant {{.*}} : tensor<2xf32, #plan.memory_space> +func.func @arith_constant() -> (tensor<2xf32>, tensor<2xf32>) { + %0 = arith.constant dense<[0.1, 0.2]> : tensor<2xf32> + %1 = arith.constant dense_resource<__elided__> : tensor<2xf32> + return %0, %1 : tensor<2xf32>, tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: module @nested_module +// CHECK-NOT: #plan.memory_space +module @outer { +module @nested_module { + func.func @nested_func() -> tensor<2xf32> { + %0 = arith.constant dense<[0.1, 0.2]> : tensor<2xf32> + return %0 : tensor<2xf32> + } +} +} + +// ----- + +// CHECK-LABEL: func.func @existing_constraint_1 +// CHECK: tensor.extract {{.*}} +func.func @existing_constraint_1(%arg0: tensor<2xf32, #plan.memory_space>) -> f32 { + %c0 = arith.constant 0 : index + %0 = tensor.extract %arg0[%c0] : tensor<2xf32, #plan.memory_space> + return %0 : f32 +} + +// ----- + +// CHECK-LABEL: func.func @existing_constraint_2 +// CHECK-NOT: tensor.cast +// CHECK: tensor.extract {{.*}} +func.func @existing_constraint_2(%arg0: tensor<2xf32, #plan.memory_space>) -> f32 { + %c0 = arith.constant 0 : index + %1 = tensor.cast %arg0 : tensor<2xf32, #plan.memory_space> to tensor<2xf32> + %0 = tensor.extract %1[%c0] : tensor<2xf32> + return %0 : f32 +} + +// ----- + +// CHECK-LABEL: func.func @host_func +// CHECK-SAME: (%[[arg0:.+]]: tensor<2xf32, #plan.memory_space>, %[[arg1:.+]]: tensor<2xf32, #plan.memory_space>) +// CHECK-SAME: -> tensor<2xf32, #plan.memory_space> +func.func @host_func(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> + attributes {plan.cluster_kind = #plan.host_cluster} { + // CHECK: %[[v0:.+]] = arith.addf %[[arg0]], %[[arg1]] : tensor<2xf32, #plan.memory_space> + %0 = arith.addf %arg0, %arg1 : tensor<2xf32> + // CHECK: return %[[v0]] + return %0 : tensor<2xf32> +} + +// CHECK-LABEL: func.func @default_func +// CHECK-SAME: (%[[arg0:.+]]: tensor<2xf32, #plan.memory_space>, %[[arg1:.+]]: tensor<2xf32, #plan.memory_space> +// CHECK-SAME: -> (tensor<2xf32, #plan.memory_space>, tensor<2xf32, #plan.memory_space> {plan.memory_space = #plan.memory_space}) +func.func @default_func(%arg0: tensor<2xf32>, %arg1: tensor<2xf32> {plan.memory_space = #plan.memory_space}) -> (tensor<2xf32>, tensor<2xf32> {plan.memory_space = #plan.memory_space}) { + // CHECK-DAG: %[[cast:.+]] = tensor.cast %[[arg0]] + // CHECK-DAG: %[[v0:.+]] = call @host_func(%[[cast]], %[[arg1]]) : + %0 = func.call @host_func(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + // CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[cast_0:.+]] = tensor.cast %[[v0]] : tensor<2xf32, #plan.memory_space> to tensor<2xf32, #plan.memory_space> + // CHECK-DAG: %[[extracted:.+]] = tensor.extract %[[v0]][%[[c0]]] : tensor<2xf32, #plan.memory_space> + // CHECK-DAG: %[[inserted:.+]] = tensor.insert %[[extracted]] into %[[v0]][%[[c0]]] + // CHECK-DAG: return %[[cast_0]], %[[inserted]] + %c0 = arith.constant 0 : index + %1 = tensor.extract %0[%c0] : tensor<2xf32> + %2 = tensor.insert %1 into %0[%c0] : tensor<2xf32> + return %0, %2 : tensor<2xf32>, tensor<2xf32> +} + +// ----- + + +// CHECK-LABEL: module @test_decl +// CHECK-LABEL: func.func private @decl(tensor<{{.*}}device>>, tensor<{{.*}}host>> {plan.memory_space = #plan.memory_space}) -> (tensor<{{.*}}host>> {plan.memory_space = #plan.memory_space}, tensor<{{.*}}device>>) + +module @test_decl { + +func.func private @decl(tensor<2xf32>, tensor<2xf32> {plan.memory_space = #plan.memory_space}) + -> (tensor<2xf32> {plan.memory_space = #plan.memory_space}, tensor<2xf32>) + +// CHECK-LABEL: func.func @caller +// CHECK-SAME: (%[[arg0:.+]]: tensor<2xf32, #plan.memory_space>, %[[arg1:.+]]: tensor<2xf32, #plan.memory_space> +func.func @caller(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // CHECK-DAG: %[[cast:.+]] = tensor.cast %[[arg1]] : tensor<2xf32, #plan.memory_space> to tensor<2xf32, #plan.memory_space> + // CHECK-DAG: %[[v0:.+]]:2 = call @decl(%[[arg0]], %[[cast]]) + // CHECK-DAG: %[[v1:.+]] = tensor.cast %[[v0]]#0 + // CHECK-DAG: %[[v2:.+]] = arith.addf %[[v1]], %[[v0]]#1 + // CHECK-DAG: return %[[v2]] + %0:2 = func.call @decl(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) + %1 = arith.addf %0#0, %0#1 : tensor<2xf32> + return %1 : tensor<2xf32> +} + +} + +// ----- + +// CHECK-LABEL: func.func @multiple_blocks +func.func @multiple_blocks( + %c: i1, + %th: tensor<10xf32> {plan.memory_space = #plan.memory_space}) + -> (tensor<5xf32> {plan.memory_space = #plan.memory_space}) { + // CHECK: %[[cast:.+]] = tensor.cast %[[arg1]] {{.*}}device>> + // CHECK: %[[v0:.+]] = tensor.empty() {{.*}}device>> + %td = tensor.empty() : tensor<10xf32> + // CHECK: cf.cond_br + cf.cond_br %c, ^bb1, ^bb2 +// CHECK: ^bb1 +^bb1: + // CHECK-DAG: %[[es0:.+]] = tensor.extract_slice %[[v0]] + %0 = tensor.extract_slice %td[0][5][2] : tensor<10xf32> to tensor<5xf32> + // CHECK-DAG: return %[[es0]] {{.*}}device>> + return %0 : tensor<5xf32> +// CHECK: ^bb2 +^bb2: + // CHECK-DAG: %[[es0:.+]] = tensor.extract_slice %[[cast]] + %1 = tensor.extract_slice %th[1][5][1] : tensor<10xf32> to tensor<5xf32> + // CHECK-DAG: return %[[es0]] {{.*}}device>> + return %1 : tensor<5xf32> +} + +// ----- + +// Test that the `plan.memory_space` attribute on a function is respected +// but can be overriden by other constraints. + + +func.func @function_level_override( + %arg0: tensor<2xi32>, + %arg1: tensor<2xi32> {plan.memory_space = #plan.memory_space} +) -> + (tensor<2xindex, #plan.memory_space>, + tensor<2xi32>, + tensor<2xf32> {plan.memory_space = #plan.memory_space}) + attributes {plan.memory_space = #plan.memory_space} { + %cst = arith.constant dense<0> : tensor<2xindex> + %cast = tensor.cast %cst : tensor<2xindex> to tensor<2xindex, #plan.memory_space> + %cst1 = arith.constant dense<1> : tensor<2xi32> + %cst2 = arith.constant dense<2.0> : tensor<2xf32> + return %cast, %cst1, %cst2 + : tensor<2xindex, #plan.memory_space>, tensor<2xi32>, tensor<2xf32> +} + +// CHECK-LABEL: func.func @function_level_override +// CHECK-SAME: (%{{.+}}: tensor<2xi32, #plan.memory_space>, +// CHECK-SAME: %{{.+}}: tensor<2xi32, #plan.memory_space> +// CHECK-SAME: -> (tensor<2xindex, #plan.memory_space>, +// CHECK-SAME: tensor<2xi32, #plan.memory_space>, +// CHECK-SAME: tensor<2xf32, #plan.memory_space> + +// ----- + +func.func @tensor_reshape( + %arg0: tensor, + %arg1: i32, + %arg2: i32 +) -> tensor { + %0 = tensor.from_elements %arg1, %arg2 : tensor<2xi32> + %1 = tensor.reshape %arg0(%0) : (tensor, tensor<2xi32>) -> tensor + return %1 : tensor +} + +// CHECK-LABEL: func.func @tensor_reshape +// CHECK-SAME: (%[[arg0:.+]]: tensor>, %[[arg1:.+]]: i32, %[[arg2:.+]]: i32) +// CHECK-DAG: %[[from_elements:.+]] = tensor.from_elements %[[arg1]], %[[arg2]] : tensor<2xi32, #plan.memory_space> +// CHECK-DAG: %[[reshape:.+]] = tensor.reshape %[[arg0]](%[[from_elements]]) +// CHECK-DAG: return %[[reshape]] + +// ----- + +func.func @alloc_tensor() -> tensor<2x128xf32> { + %0 = bufferization.alloc_tensor() { + memory_space = #plan.memory_space + } : tensor<2x128xf32> + %c0 = arith.constant 0 : index + %cst = arith.constant 1.0 : f32 + %1 = tensor.insert %cst into %0[%c0, %c0] : tensor<2x128xf32> + return %1 : tensor<2x128xf32> +} + +// CHECK-LABEL: func.func @alloc_tensor +// CHECK-SAME: () -> tensor<2x128xf32, #plan.memory_space> +// CHECK-DAG: %[[cst:.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[v1:.+]] = bufferization.alloc_tensor() {memory_space = #plan.memory_space} : tensor<2x128xf32, #plan.memory_space> +// CHECK-DAG: %[[inserted:.+]] = tensor.insert %[[cst]] into %[[v1]][%[[c0]], %[[c0]]] +// CHECK-DAG: %[[cast:.+]] = tensor.cast %[[inserted]] : tensor<2x128xf32, #plan.memory_space> to tensor<2x128xf32, #plan.memory_space> +// CHECK-DAG: return %[[cast]] + +// ----- + +func.func @alloc_tensor_2(%arg0: tensor<2x128xf32>) -> tensor<2x128xf32> { + %0 = bufferization.alloc_tensor() copy(%arg0) { + memory_space = #plan.memory_space + } : tensor<2x128xf32> + return %0 : tensor<2x128xf32> +} + +// CHECK-LABEL: func.func @alloc_tensor_2 +// CHECK-SAME: (%[[arg0:.+]]: tensor<2x128xf32, #plan.memory_space>) -> tensor<2x128xf32, #plan.memory_space> { +// CHECK: return %[[arg0]] diff --git a/mlir-tensorrt/compiler/test/Dialect/Plan/assign-memory-spaces.mlir b/mlir-tensorrt/compiler/test/Dialect/Plan/assign-memory-spaces.mlir deleted file mode 100644 index f87590198..000000000 --- a/mlir-tensorrt/compiler/test/Dialect/Plan/assign-memory-spaces.mlir +++ /dev/null @@ -1,73 +0,0 @@ -// RUN: mlir-tensorrt-opt %s -split-input-file --plan-assign-memory-spaces -canonicalize | FileCheck %s - - -func.func private @cond() -> i1 - -// CHECK-LABEL: func.func @scf_while_loop_2 -// CHECK: scf.while {{.*}}tensor<1xf32, #plan.memory_space>) -> tensor<1xf32, #plan.memory_space> -func.func @scf_while_loop_2(%arg0: f32) -> f32 { - %c0 = arith.constant 0 : index - %1 = tensor.from_elements %arg0 : tensor<1xf32> - %2 = scf.while (%arg1 = %1) : (tensor<1xf32>) -> tensor<1xf32> { - %cond = func.call @cond() : () -> i1 - %e = tensor.extract %arg1[%c0] : tensor<1xf32> - %f = arith.addf %e, %e : f32 - %3 = tensor.from_elements %f : tensor<1xf32> - scf.condition(%cond) %3 : tensor<1xf32> - } do { - ^bb0(%arg1: tensor<1xf32>): - %extract = tensor.extract %arg1[%c0] : tensor<1xf32> - %3 = arith.addf %extract, %extract : f32 - %4 = tensor.from_elements %3 : tensor<1xf32> - scf.yield %4 : tensor<1xf32> - } - %3 = tensor.extract %2[%c0] : tensor<1xf32> - return %3 : f32 -} - -// ----- - -// CHECK-LABEL: func.func @arith_constant -// CHECK: arith.constant {{.*}} : tensor<2xf32, #plan.memory_space> -// CHECK: arith.constant {{.*}} : tensor<2xf32, #plan.memory_space> -func.func @arith_constant() -> (tensor<2xf32>, tensor<2xf32>) { - %0 = arith.constant dense<[0.1, 0.2]> : tensor<2xf32> - %1 = arith.constant dense_resource<__elided__> : tensor<2xf32> - return %0, %1 : tensor<2xf32>, tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: module @nested_module -// CHECK-NOT: #plan.memory_space -module @outer { -module @nested_module { - func.func @nested_func() -> tensor<2xf32> { - %0 = arith.constant dense<[0.1, 0.2]> : tensor<2xf32> - return %0 : tensor<2xf32> - } -} -} - -// ----- - -// CHECK-LABEL: func.func @existing_constraint_1 -// CHECK: tensor.extract {{.*}} -func.func @existing_constraint_1(%arg0: tensor<2xf32, #plan.memory_space>) -> f32 { - %c0 = arith.constant 0 : index - %0 = tensor.extract %arg0[%c0] : tensor<2xf32, #plan.memory_space> - return %0 : f32 -} - -// ----- - -// CHECK-LABEL: func.func @existing_constraint_2 -// CHECK-NOT: tensor.cast -// CHECK: tensor.extract {{.*}} -func.func @existing_constraint_2(%arg0: tensor<2xf32, #plan.memory_space>) -> f32 { - %c0 = arith.constant 0 : index - %1 = tensor.cast %arg0 : tensor<2xf32, #plan.memory_space> to tensor<2xf32> - %0 = tensor.extract %1[%c0] : tensor<2xf32> - return %0 : f32 -} - diff --git a/mlir-tensorrt/compiler/test/Dialect/Plan/buffer-results-to-out-params.mlir b/mlir-tensorrt/compiler/test/Dialect/Plan/buffer-results-to-out-params.mlir new file mode 100644 index 000000000..ce3d642e6 --- /dev/null +++ b/mlir-tensorrt/compiler/test/Dialect/Plan/buffer-results-to-out-params.mlir @@ -0,0 +1,179 @@ +// RUN: mlir-tensorrt-opt %s -split-input-file -plan-buffer-results-to-out-params=ignore-public-functions -canonicalize | FileCheck %s + +func.func private @alloc_size() -> index + +func.func private @callee(%arg0 : index) + -> (memref<5xf32> {plan.tag = "foo"}, index, memref {plan.tag = "bar"}) { + %0 = memref.alloc() : memref<5xf32> + %size = func.call @alloc_size() : () -> index + %1 = memref.alloc(%size) : memref + %c1 = arith.constant 1 : index + %2 = arith.addi %arg0, %c1 : index + return %0, %2, %1 : memref<5xf32>, index, memref +} + +func.func private @callee_external() -> (memref<5xf32> {plan.tag = "foo"}) + +func.func @caller() -> (memref<5xf32>, index, memref, memref<5xf32>) { + %c10 = arith.constant 10 : index + %0:3 = func.call @callee(%c10) : (index) -> (memref<5xf32>, index, memref) + %1 = func.call @callee_external() : () -> (memref<5xf32>) + return %0#0, %0#1, %0#2, %1 : memref<5xf32>, index, memref, memref<5xf32> +} + +// CHECK-LABEL: func.func private @callee( +// CHECK-SAME: %[[arg0:.+]]: index, %[[arg1:.+]]: memref<5xf32> {plan.result_arg, plan.tag = "foo"}) -> (index, memref {plan.tag = "bar"}) +// CHECK: %[[alloc:.+]] = memref.alloc +// CHECK: %[[add:.+]] = arith.addi +// CHECK-NOT: memref.copy +// CHECK: return %[[add]], %[[alloc]] : index, memref + +// CHECK-LABEL: func.func private @callee_external() -> (memref<5xf32> {plan.tag = "foo"}) + +// CHECK-LABEL: func.func @caller +// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index +// CHECK-DAG: %[[alloc:.+]] = memref.alloc() : memref<5xf32> +// CHECK-DAG: %[[v0:.+]]:2 = call @callee(%[[c10]], %[[alloc]]) : (index, memref<5xf32>) -> (index, memref) +// CHECK-DAG: %[[v1:.+]] = call @callee_external() : () -> memref<5xf32> +// CHECK-DAG: return %[[alloc]], %[[v0]]#0, %[[v0]]#1, %[[v1]] + +// ----- + +func.func private @callee_returns_aliasing() -> (memref<10xf32>, memref<5xf32>) { + %0 = memref.alloc() : memref<10xf32> + %1 = memref.subview %0[0][5][1] : memref<10xf32> to memref<5xf32> + return %0, %1 : memref<10xf32>, memref<5xf32> +} + +func.func @caller() -> (memref<10xf32>, memref<5xf32>) { + %0:2 = func.call @callee_returns_aliasing() : () -> (memref<10xf32>, memref<5xf32>) + return %0#0, %0#1 : memref<10xf32>, memref<5xf32> +} + +// CHECK-LABEL: @callee_returns_aliasing() -> (memref<10xf32>, memref<5xf32>) + + +// ----- + +func.func private @callee_returns_block_arg(%arg0: memref<5xf32>) -> (memref<5xf32>) { + return %arg0 : memref<5xf32> +} + +func.func @caller(%arg0: memref<5xf32>) -> (memref<5xf32>) { + %1 = func.call @callee_returns_block_arg(%arg0) : (memref<5xf32>) -> (memref<5xf32>) + return %1 : memref<5xf32> +} + +// CHECK-LABEL: @callee_returns_block_arg +// CHECK-NEXT: return + +// CHECK-LABEL: @caller +// CHECK-NOT: memref.alloc + + +// ----- + +func.func private @callee_aliasing_duplicate_alloc() -> (memref<10xf32>, memref<10xf32>) { + %0 = memref.alloc() : memref<10xf32> + return %0, %0 : memref<10xf32>, memref<10xf32> +} + +func.func @caller(%arg0: index, %arg1: f32) -> (f32) { + %0:2 = func.call @callee_aliasing_duplicate_alloc() : () -> (memref<10xf32>, memref<10xf32>) + memref.store %arg1, %0#0[%arg0] : memref<10xf32> + %1 = memref.load %0#1[%arg0] : memref<10xf32> + return %1 : f32 +} + +// CHECK-LABEL: @callee_aliasing_duplicate_alloc() -> (memref<10xf32>, memref<10xf32>) +// CHECK-LABEL: @callee + + +// ----- + +func.func private @multiple_blocks(%arg0: i1, %arg1: memref<5xf32>, %arg2: memref<5xf32>) -> (memref<5xf32>) { + cf.cond_br %arg0, ^bb1, ^bb2 +^bb1: + return %arg1 : memref<5xf32> +^bb2: + return %arg2 : memref<5xf32> +} + +func.func @caller(%arg0: i1, %arg1: memref<5xf32>, %arg2: memref<5xf32>) -> (memref<5xf32>) { + %0 = func.call @multiple_blocks(%arg0, %arg1, %arg2) : (i1, memref<5xf32>, memref<5xf32>) -> (memref<5xf32>) + return %0 : memref<5xf32> +} + +// CHECK-LABEL: func.func private @multiple_blocks +// CHECK-SAME: (%[[arg0:.+]]: i1, %[[arg1:.+]]: memref<5xf32>, %[[arg2:.+]]: memref<5xf32>, %[[arg3:.+]]: memref<5xf32> {plan.result_arg}) +// CHECK: memref.copy %[[arg1]], %[[arg3]] +// CHECK: return +// CHECK: memref.copy %[[arg2]], %[[arg3]] +// CHECK: return + +// CHECK-LABEL: @caller +// CHECK: %[[alloc:.+]] = memref.alloc() : memref<5xf32> +// CHECK: call @multiple_blocks +// CHECK: return %[[alloc]] + +// ----- + +// This test uses a bunch of convoluted code to verify +// that complicated sequence of operations rooted at an allocation +// can be hoisted. + +// Note that ultimately the allocation and unused function arguments +// are dropped in follow on passes (canonicalize, remove-dead-values). + +!result_type1 = memref +!result_type2 = memref> + +func.func private @callee_returns_complicated( + %arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (!result_type1, !result_type2) { + %0 = memref.alloc(%arg0) : memref + %1 = memref.alloc(%arg1) : memref + + %c0 = arith.constant 0 : index + %3 = arith.constant 1.0 : f32 + %cond = arith.cmpi slt, %arg3, %arg4 : index + %4 = scf.if %cond -> !result_type2 { + %sv0 = memref.subview %1[%arg2][%arg3][%arg4] : memref to !result_type2 + scf.yield %sv0 : !result_type2 + } else { + %sv1 = memref.subview %1[%arg4][%arg2][%arg3] : memref to !result_type2 + scf.yield %sv1 : !result_type2 + } + memref.store %3, %0[%c0] : !result_type1 + memref.store %3, %4[%c0] : !result_type2 + return %0, %4 : !result_type1, !result_type2 +} + +func.func @caller(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) + -> (!result_type1, !result_type2) { + %0:2 = func.call @callee_returns_complicated(%arg0, %arg1, %arg2, %arg3, %arg4) + : (index, index, index, index, index) -> (!result_type1, !result_type2) + return %0#0, %0#1 : !result_type1, !result_type2 +} + +// CHECK-LABEL: func.func private @callee_returns_complicated +// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index, +// CHECK-SAME: %[[arg5:.+]]: memref {plan.result_arg}, %[[arg6:.+]]: memref> {plan.result_arg}) { +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[cst:.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: memref.store %[[cst]], %[[arg5]][%[[c0]]] : memref +// CHECK-DAG: memref.store %[[cst]], %[[arg6]][%[[c0]]] : memref> +// CHECK: return + +// CHECK-LABEL: func.func @caller +// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index) +// CHECK-DAG: %[[alloc:.+]] = memref.alloc(%[[arg0]]) : memref +// CHECK-DAG: %[[alloc_0:.+]] = memref.alloc(%[[arg1]]) : memref +// CHECK-DAG: %[[v0:.+]] = arith.cmpi slt, %[[arg3]], %[[arg4]] : index +// CHECK-DAG: %[[v1:.+]] = scf.if %[[v0]] -> (memref>) { +// CHECK: %[[subview:.+]] = memref.subview %[[alloc_0]][%[[arg2]]] [%[[arg3]]] [%[[arg4]]] +// CHECK: scf.yield %[[subview]] : +// CHECK: } else { +// CHECK: %[[subview:.+]] = memref.subview %[[alloc_0]][%[[arg4]]] [%[[arg2]]] [%[[arg3]]] +// CHECK: scf.yield %[[subview]] +// CHECK: call @callee_returns_complicated(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]], %[[alloc]], %[[v1]]) +// CHECK: return %[[alloc]], %[[v1]] diff --git a/mlir-tensorrt/compiler/test/Dialect/Plan/materialize-explicit-transfers.mlir b/mlir-tensorrt/compiler/test/Dialect/Plan/materialize-explicit-transfers.mlir new file mode 100644 index 000000000..ce7005fdd --- /dev/null +++ b/mlir-tensorrt/compiler/test/Dialect/Plan/materialize-explicit-transfers.mlir @@ -0,0 +1,62 @@ +// RUN: mlir-tensorrt-opt %s --plan-materialize-explicit-transfers -split-input-file | FileCheck %s + +!host_type = tensor<1xf32, #plan.memory_space> +!device_type = tensor<1xf32, #plan.memory_space> + +// CHECK-LABEL: func.func @tensor_cast +// CHECK-SAME: (%[[arg0:.+]]: tensor<1xf32, #plan.memory_space> +func.func @tensor_cast(%arg0: !host_type) -> !device_type { + // CHECK: %[[v0:.+]] = bufferization.alloc_tensor() + // CHECK-SAME: memory_space = #plan.memory_space + // CHECK-SAME: tensor<1xf32, #plan.memory_space> + + // CHECK: %[[v1:.+]] = bufferization.materialize_in_destination + // CHECK-SAME: %[[arg0]] in %[[v0]] + %1 = tensor.cast %arg0 : !host_type to !device_type + // CHECK: return %[[v1]] : tensor<1xf32, #plan.memory_space> + return %1 : !device_type +} + +// ----- + +!host_type = tensor> +!device_type = tensor> + +// CHECK-LABEL: func.func @dynamic_shape +// CHECK-SAME: (%[[arg0:.+]]: tensor> +func.func @dynamic_shape(%arg0: !host_type) -> !device_type { +// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[dim:.+]] = tensor.dim %[[arg0]], %[[c0]] : +// CHECK-DAG: %[[dim_0:.+]] = tensor.dim %[[arg0]], %[[c2]] : + // CHECK: %[[v0:.+]] = bufferization.alloc_tensor(%[[dim]], %[[dim_0]]) + // CHECK-SAME: memory_space = #plan.memory_space + // CHECK-SAME: tensor> + + // CHECK: %[[v1:.+]] = bufferization.materialize_in_destination + // CHECK-SAME: %[[arg0]] in %[[v0]] + %1 = tensor.cast %arg0 : !host_type to !device_type + // CHECK: return %[[v1]] : tensor> + return %1 : !device_type +} + + +// ----- + +!host_type = tensor<4xf32, #plan.memory_space> +!device_type = tensor<4xf32, #plan.memory_space> + +func.func @redundant_materialize_in_dest(%arg0: !device_type) -> !host_type { + %0 = tensor.empty() : !host_type + %1 = bufferization.materialize_in_destination %arg0 in %0 + : (!device_type, !host_type) -> !host_type + %2 = bufferization.materialize_in_destination %arg0 in %1 + : (!device_type, !host_type) -> !host_type + return %2 : !host_type +} + +// CHECK-LABEL: func.func @redundant_materialize_in_dest +// CHECK-SAME: (%[[arg0:.+]]: tensor<4xf32, #plan.memory_space>) -> tensor<4xf32, #plan.memory_space> { +// CHECK: %[[v0:.+]] = tensor.empty() : tensor<4xf32, #plan.memory_space> +// CHECK: %[[v1:.+]] = bufferization.materialize_in_destination %[[arg0]] in %[[v0]] +// CHECK: return %[[v1]] : tensor<4xf32, #plan.memory_space> diff --git a/mlir-tensorrt/compiler/test/Dialect/Plan/materialize-shape-calculations-composite.mlir b/mlir-tensorrt/compiler/test/Dialect/Plan/materialize-shape-calculations-composite.mlir index 82675fcd0..95e151c9a 100644 --- a/mlir-tensorrt/compiler/test/Dialect/Plan/materialize-shape-calculations-composite.mlir +++ b/mlir-tensorrt/compiler/test/Dialect/Plan/materialize-shape-calculations-composite.mlir @@ -85,7 +85,7 @@ func.func private @pt_q(%arg0: tensor) -> tensor attrib // SHAPE: return %[[arg0]], %[[c3]], %[[arg1]], %[[arg2]] : index, index, index, index // SHAPE: } // SHAPE-LABEL: func.func @stablehlo_composite_dynamic_shapes_get_shapes -// SHAPE-SAME: (%[[arg0:.+]]: tensor<4xindex, #plan.memory_space> {tensorrt.host_tensor}) +// SHAPE-SAME: (%[[arg0:.+]]: tensor<4xindex, #plan.memory_space>) // SHAPE-DAG: %[[c0:.+]] = arith.constant 0 : index // SHAPE-DAG: %[[extracted:.+]] = tensor.extract %[[arg0]][%[[c0]]] : tensor<4xindex, #plan.memory_space> // SHAPE-DAG: %[[c2:.+]] = arith.constant 2 : index diff --git a/mlir-tensorrt/compiler/test/Dialect/Plan/materialize-shape-calculations.mlir b/mlir-tensorrt/compiler/test/Dialect/Plan/materialize-shape-calculations.mlir index 935ab15a6..ed31982c9 100644 --- a/mlir-tensorrt/compiler/test/Dialect/Plan/materialize-shape-calculations.mlir +++ b/mlir-tensorrt/compiler/test/Dialect/Plan/materialize-shape-calculations.mlir @@ -21,7 +21,7 @@ func.func @test_simple(%arg0: tensor) -> tensor { // SHAPE-NEXT: %[[c10:.+]] = arith.constant 10 : index // SHAPE-NEXT: return %[[arg0]], %[[c10]] : index, index // SHAPE-LABEL: @test_simple_get_shapes -// SHAPE-SAME: (%[[arg0:.+]]: tensor<2xindex, #plan.memory_space> {tensorrt.host_tensor}) +// SHAPE-SAME: (%[[arg0:.+]]: tensor<2xindex, #plan.memory_space>) // SHAPE-NEXT: %[[c0:.+]] = arith.constant 0 : index // SHAPE-NEXT: %[[extracted:.+]] = tensor.extract %[[arg0]][%[[c0]]] : tensor<2xindex, #plan.memory_space> // SHAPE-NEXT: %[[v0:.+]]:2 = call @shape_test_simple_result_0(%[[extracted]]) : (index) -> (index, index) @@ -52,7 +52,7 @@ func.func @test_dynamic_reshape(%arg0: tensor<4xf32>, %arg1: tensor<2xi32>) -> t // SHAPE-SAME: %[[arg1:.+]]: i32 {plan.shape_func_arg = {argument = 1 : index, indices = array}}) // SHAPE: return %[[arg0]], %[[arg1]] : i32, i32 // SHAPE-LABEL: @test_dynamic_reshape_get_shapes -// SHAPE-SAME: (%[[arg0:.+]]: tensor<1xindex, #plan.memory_space> {tensorrt.host_tensor}, %[[arg1:.+]]: tensor<2xi32> {tensorrt.host_tensor}) -> (tensor<2xindex, #plan.memory_space> {tensorrt.host_tensor}) +// SHAPE-SAME: (%[[arg0:.+]]: tensor<1xindex, #plan.memory_space>, %[[arg1:.+]]: tensor<2xi32>) -> tensor<2xindex, #plan.memory_space> // SHAPE: %[[c0:.+]] = arith.constant 0 : index // SHAPE: %[[extracted:.+]] = tensor.extract %[[arg1]][%[[c0]]] : tensor<2xi32> // SHAPE: %[[c1:.+]] = arith.constant 1 : index @@ -120,7 +120,7 @@ func.func @test_get_dim_size_max(%arg0: tensor, %arg1: tensor) // SHAPE: return %[[v4]], %[[v5]] : i32, i32 // SHAPE-LABEL: func.func @test_get_dim_size_max_get_shapes -// SHAPE-SAME: (%[[arg0:.+]]: tensor<2xindex, #plan.memory_space> {tensorrt.host_tensor}, %[[arg1:.+]]: tensor<2xindex, #plan.memory_space> {tensorrt.host_tensor}) -> (tensor<2xindex, #plan.memory_space> {tensorrt.host_tensor}) +// SHAPE-SAME: (%[[arg0:.+]]: tensor<2xindex, #plan.memory_space>, %[[arg1:.+]]: tensor<2xindex, #plan.memory_space>) -> tensor<2xindex, #plan.memory_space> // SHAPE-DAG: %[[c0:.+]] = arith.constant 0 : index // SHAPE-DAG: %[[extracted:.+]] = tensor.extract %[[arg0]][%[[c0]]] : tensor<2xindex, #plan.memory_space> // SHAPE-DAG: %[[c1:.+]] = arith.constant 1 : index @@ -206,8 +206,8 @@ func.func @dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tensor<1 // SHAPE-DAG: %[[v6:.+]] = arith.addi %[[v5]], %[[arg3]] : index // SHAPE-DAG: return %[[v6]] // SHAPE-LABEL: @dynamic_pad_get_shapes -// SHAPE-SAME: (%[[arg0:.+]]: tensor<1xindex, #plan.memory_space> {tensorrt.host_tensor}, %[[arg1:.+]]: tensor<1xindex, #plan.memory_space> {tensorrt.host_tensor}, -// SHAPE-SAME: %[[arg2:.+]]: tensor<1xindex> {tensorrt.host_tensor}, %[[arg3:.+]]: tensor<1xindex> {tensorrt.host_tensor}, %[[arg4:.+]]: tensor<1xindex> {tensorrt.host_tensor}) +// SHAPE-SAME: (%[[arg0:.+]]: tensor<1xindex, #plan.memory_space>, %[[arg1:.+]]: tensor<1xindex, #plan.memory_space>, +// SHAPE-SAME: %[[arg2:.+]]: tensor<1xindex>, %[[arg3:.+]]: tensor<1xindex>, %[[arg4:.+]]: tensor<1xindex>) // SHAPE: %[[c0:.+]] = arith.constant 0 : index // SHAPE: %[[extracted:.+]] = tensor.extract %[[arg0]][%[[c0]]] : tensor<1xindex, #plan.memory_space> // SHAPE: %[[c0_0:.+]] = arith.constant 0 : index @@ -372,7 +372,7 @@ func.func @add_dynamic_derive_shape( // SHAPE-DAG: return %[[v3]] : i32 // SHAPE-LABEL: func.func @add_dynamic_derive_shape_get_shapes -// SHAPE-SAME: (%[[arg0:.+]]: tensor<1xindex, #plan.memory_space> {tensorrt.host_tensor}, %[[arg1:.+]]: tensor<1xindex, #plan.memory_space> {tensorrt.host_tensor}) -> (tensor<1xindex, #plan.memory_space> {tensorrt.host_tensor}) +// SHAPE-SAME: (%[[arg0:.+]]: tensor<1xindex, #plan.memory_space>, %[[arg1:.+]]: tensor<1xindex, #plan.memory_space>) -> tensor<1xindex, #plan.memory_space> // SHAPE: %[[c0:.+]] = arith.constant 0 : index // SHAPE: %[[extracted:.+]] = tensor.extract %[[arg0]][%[[c0]]] : tensor<1xindex, #plan.memory_space> // SHAPE: %[[c0_0:.+]] = arith.constant 0 : index @@ -745,7 +745,7 @@ func.func @bufferization_aloc_tensor(%arg0: tensor<1xindex>) -> tensor { // SHAPE-NEXT: return %[[arg0]] : // SHAPE-LABEL: @bufferization_aloc_tensor_get_shapes -// SHAPE-SAME: (%[[arg0:.+]]: tensor<1xindex> {tensorrt.host_tensor}) -> (tensor<1xindex, #plan.memory_space> {tensorrt.host_tensor}) +// SHAPE-SAME: (%[[arg0:.+]]: tensor<1xindex>) -> tensor<1xindex, #plan.memory_space> // SHAPE: %[[c0:.+]] = arith.constant 0 : index // SHAPE: %[[extracted:.+]] = tensor.extract %[[arg0]][%[[c0]]] : tensor<1xindex> // SHAPE: %[[v0:.+]] = call @shape_bufferization_aloc_tensor_result_0(%[[extracted]]) : @@ -886,7 +886,7 @@ func.func @slice_with_repetetive_max(%arg0: tensor<2xi32>, %arg1: tensor<1xf32>) // SHAPE-DAG: %[[v0:.+]] = arith.maxsi %[[arg0]], %[[arg1]] : i32 // SHAPE-DAG: return %[[v0]] : i32 // SHAPE-LABEL: func.func @slice_with_repetetive_max_get_shapes -// SHAPE-SAME: (%[[arg0:.+]]: tensor<2xi32> {tensorrt.host_tensor}, %[[arg1:.+]]: tensor<1xindex, #plan.memory_space> {tensorrt.host_tensor}) -> (tensor<1xindex, #plan.memory_space> {tensorrt.host_tensor}) +// SHAPE-SAME: (%[[arg0:.+]]: tensor<2xi32>, %[[arg1:.+]]: tensor<1xindex, #plan.memory_space>) -> tensor<1xindex, #plan.memory_space> // SHAPE-DAG: %[[c0:.+]] = arith.constant 0 : index // SHAPE-DAG: %[[extracted:.+]] = tensor.extract %[[arg0]][%[[c0]]] : tensor<2xi32> // SHAPE-DAG: %[[c1:.+]] = arith.constant 1 : index @@ -1111,4 +1111,4 @@ func.func @simplify_extract_of_reshape_negative(%arg0: tensor<1x?x3x4xf32>) -> f // CHECK-NEXT: %[[v0:.+]] = plan.with_shape %[[arg0]](%[[c1]], %[[dim]], %[[c3]], %[[c4]]) // CHECK-NEXT: %[[v1:.+]] = stablehlo.reshape %[[v0]] // CHECK-NEXT: %[[extracted:.+]] = tensor.extract %[[v1]][%[[c0]], %[[c1]], %[[c2]]] -// CHECK-NEXT: return %extracted \ No newline at end of file +// CHECK-NEXT: return %extracted \ No newline at end of file diff --git a/mlir-tensorrt/compiler/test/Dialect/Plan/plan-bufferize-pipeline.mlir b/mlir-tensorrt/compiler/test/Dialect/Plan/plan-bufferize-pipeline.mlir index baf0a8c8f..310b012ff 100644 --- a/mlir-tensorrt/compiler/test/Dialect/Plan/plan-bufferize-pipeline.mlir +++ b/mlir-tensorrt/compiler/test/Dialect/Plan/plan-bufferize-pipeline.mlir @@ -27,29 +27,19 @@ func.func @small_host_tensor_constant(%arg0: tensor) -> (tensor } -// TODO: This test shows that the pre-processing prior to one-shot-bufferization is -// sub-optimal. We allocate two host buffers to hold the `tensor<4xindex>` for some -// reason. +// There should be a copy since `%arg0` is not writable under our default settings. +// We can only avoid a copy if we use `force-entrypoints-return-allocs`. +// CHECK: memref.global "private" constant // CHECK-LABEL: func.func @small_host_tensor_constant // CHECK-SAME: (%[[arg0:.+]]: memref>, %[[arg1:.+]]: memref> {plan.result_arg}) { -// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index -// CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index -// CHECK-DAG: %[[alloc:.+]] = memref.alloc() {alignment = 16 : i64} : memref<4xindex, #plan.memory_space> -// CHECK-DAG: memref.store %[[c1]], %[[alloc]][%[[c0]]] : memref<4xindex, #plan.memory_space> -// CHECK-DAG: memref.store %[[c2]], %[[alloc]][%[[c1]]] : memref<4xindex, #plan.memory_space> -// CHECK-DAG: memref.store %[[c3]], %[[alloc]][%[[c2]]] : memref<4xindex, #plan.memory_space> -// CHECK: %[[alloc_0:.+]] = memref.alloc() {alignment = 16 : i64} : memref<4xindex, #plan.memory_space> -// CHECK: memref.copy %[[alloc]], %[[alloc_0]] : memref<4xindex, #plan.memory_space> to memref<4xindex, #plan.memory_space> -// CHECK: memref.store %[[c4]], %[[alloc_0]][%[[c3]]] : memref<4xindex, #plan.memory_space> -// CHECK: %[[reshape:.+]] = memref.reshape %[[arg0]](%[[alloc_0]]) : -// CHECK: memref.copy %[[reshape]], %[[arg1]] : memref> to memref> -// CHECK: memref.dealloc %[[alloc]] : memref<4xindex, #plan.memory_space> -// CHECK: memref.dealloc %[[alloc_0]] : memref<4xindex, #plan.memory_space> -// CHECK: return +// CHECK: %[[v0:.+]] = memref.get_global +// CHECK: %[[alloc:.+]] = memref.alloc() {{.*}} : memref<4xindex, #plan.memory_space> +// CHECK: memref.copy %[[v0]], %[[alloc]] +// CHECK: %[[reshape:.+]] = memref.reshape %[[arg0]](%[[alloc]]) +// CHECK: memref.copy %[[reshape]], %[[arg1]] +// CHECK: memref.dealloc %[[alloc]] + // ----- @@ -59,71 +49,208 @@ func.func @small_host_and_device_tensor_constant(%arg0: tensor) -> (ten return %1, %0 : tensor, tensor<4xindex> } -// CHECK: memref.global "private" constant @__constant_4xindex : memref<4xindex, #plan.memory_space> = dense<[1, 2, 3, 4]> {alignment = 16 : i64} +// CHECK: memref.global "private" constant @__constant_4xindex // CHECK-LABEL: func.func @small_host_and_device_tensor_constant -// CHECK-SAME: (%[[arg0:.+]]: memref>, -// CHECK-SAME: %[[arg1:.+]]: memref> {plan.result_arg}, -// CHECK-SAME: %[[arg2:.+]]: memref<4xindex, #plan.memory_space> {plan.result_arg}) { -// CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index -// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index -// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index -// CHECK: %[[v0:.+]] = memref.get_global @__constant_4xindex : memref<4xindex, #plan.memory_space> -// CHECK: memref.copy %[[v0]], %[[arg2]] : memref<4xindex, #plan.memory_space> to memref<4xindex, #plan.memory_space> -// CHECK: %[[alloc:.+]] = memref.alloc() {alignment = 16 : i64} : memref<4xindex, #plan.memory_space> -// CHECK-DAG: memref.store %[[c1]], %[[alloc]][%[[c0]]] : memref<4xindex, #plan.memory_space> -// CHECK-DAG: memref.store %[[c2]], %[[alloc]][%[[c1]]] : memref<4xindex, #plan.memory_space> -// CHECK-DAG: memref.store %[[c3]], %[[alloc]][%[[c2]]] : memref<4xindex, #plan.memory_space> -// CHECK: %[[alloc_0:.+]] = memref.alloc() {alignment = 16 : i64} : memref<4xindex, #plan.memory_space> -// CHECK: memref.copy %[[alloc]], %[[alloc_0]] : memref<4xindex, #plan.memory_space> to memref<4xindex, #plan.memory_space> -// CHECK: memref.store %[[c4]], %[[alloc_0]][%[[c3]]] : memref<4xindex, #plan.memory_space> -// CHECK: %[[reshape:.+]] = memref.reshape %[[arg0]](%[[alloc_0]]) : (memref>, memref<4xindex, #plan.memory_space>) -// CHECK: memref.copy %[[reshape]], %[[arg1]] : memref> to memref> -// CHECK: memref.dealloc %[[alloc]] : memref<4xindex, #plan.memory_space> -// CHECK: memref.dealloc %[[alloc_0]] : memref<4xindex, #plan.memory_space> +// CHECK-SAME: (%[[arg0:.+]]: memref>, %[[arg1:.+]]: memref> {plan.result_arg}, %[[arg2:.+]]: memref<4xindex, #plan.memory_space> {plan.result_arg}) +// CHECK: %[[v0:.+]] = memref.get_global {{.*}} #plan.memory_space> +// CHECK: memref.copy %[[v0]], %[[arg2]] : +// CHECK: %[[alloc:.+]] = memref.alloc() {{.*}} #plan.memory_space +// CHECK: memref.copy %[[arg2]], %[[alloc]] +// CHECK: %[[reshape:.+]] = memref.reshape %[[arg0]](%[[alloc]]) +// CHECK: memref.copy %[[reshape]], %[[arg1]] +// CHECK: memref.dealloc %[[alloc]] // CHECK: return // ----- +module @while_loop { + func.func private @cond() -> i1 // The test case illustrates a while loop that for whatever reason may not // have been "detensorized" earlier in the pipeline. The TensorKindAnalysis -// will show that all tensors are "host-only", but currently bufferization +// will show that all tensors are "host-only", but currently bufferization // does not deduce this via its memory space inference logic. Therefore, the // loop will be bufferized so that the buffers are in the device // space at branch points, which means lots of copies are inserted. Before -// adding the 'plan-assign-memory-spaces' pass, we would get a failure here +// adding the 'plan-assign-memory-spaces' pass, we would get a failure here // due to mixed types of init arg and yielded value inferred by bufferization. -// In the future, we can optimize this case by adding support for rewriting +// In the future, we can optimize this case by adding support for rewriting // the encoding attribute of loop-carried tensors to be host for this case. func.func @while_loop_host_tensor_carried(%arg0: f32) -> f32 { %c0 = arith.constant 0 : index %1 = tensor.from_elements %arg0 : tensor<1xf32> - %2 = scf.while (%arg1 = %1) : (tensor<1xf32>) -> tensor<1xf32> { + %2 = scf.while (%arg1 = %1) : (tensor<1xf32>) -> tensor<1xf32> { %cond = func.call @cond() : () -> i1 %e = tensor.extract %arg1[%c0] : tensor<1xf32> %f = arith.addf %e, %e : f32 %3 = tensor.from_elements %f : tensor<1xf32> scf.condition(%cond) %3 : tensor<1xf32> } do { - ^bb0(%arg1: tensor<1xf32>): - %extract = tensor.extract %arg1[%c0] : tensor<1xf32> + ^bb0(%arg1: tensor<1xf32>): + %extract = tensor.extract %arg1[%c0] : tensor<1xf32> %3 = arith.addf %extract, %extract : f32 - %4 = tensor.from_elements %3 : tensor<1xf32> + %4 = tensor.from_elements %3 : tensor<1xf32> scf.yield %4 : tensor<1xf32> } %3 = tensor.extract %2[%c0] : tensor<1xf32> return %3 : f32 } +} + // CHECK-LABEL: func.func @while_loop_host_tensor_carried // CHECK: scf.while : () -> () -// CHECK-COUNT-2: memref.copy +// CHECK-COUNT-1: memref.copy // CHECK: scf.condition -// CHECK-COUNT-2: memref.copy -// CHECK: scf.yield // CHECK-COUNT-1: memref.copy +// CHECK: scf.yield // CHECK-NOT: memref.copy + +// ----- + +// This test checks that if we create a function with specific constraints, +// then we should not insert unnecessary copies to tranfer between other spaces. + +module @shape_func_with_constraints { + +func.func @shape_func_with_constraints( + %arg0: tensor<2xindex, #plan.memory_space>, + %arg1: tensor<2xindex, #plan.memory_space>) + -> tensor<2xindex, #plan.memory_space> attributes { + plan.memory_space = #plan.memory_space + } { + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %extracted = tensor.extract %arg0[%c0] : tensor<2xindex, #plan.memory_space> + %extracted_0 = tensor.extract %arg1[%c0] : tensor<2xindex, #plan.memory_space> + %0 = arith.index_cast %extracted : index to i32 + %1 = arith.index_cast %extracted_0 : index to i32 + %2 = arith.maxsi %0, %1 : i32 + %3 = arith.index_cast %2 : i32 to index + %from_elements = tensor.from_elements %3, %c2 : tensor<2xindex, #plan.memory_space> + return %from_elements : tensor<2xindex, #plan.memory_space> +} + +} + +// CHECK-LABEL: func.func @shape_func_with_constraints +// CHECK-SAME: (%[[arg0:.+]]: memref<2xindex, #plan.memory_space>, %[[arg1:.+]]: memref<2xindex, #plan.memory_space>, %[[arg2:.+]]: memref<2xindex, #plan.memory_space> {plan.result_arg}) attributes {plan.memory_space = #plan.memory_space} { +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[v0:.+]] = memref.load %[[arg0]][%[[c0]]] +// CHECK-DAG: %[[v1:.+]] = memref.load %[[arg1]][%[[c0]]] +// CHECK-DAG: %[[v2:.+]] = arith.index_cast %[[v0]] : index to i32 +// CHECK-DAG: %[[v3:.+]] = arith.index_cast %[[v1]] : index to i32 +// CHECK-DAG: %[[v4:.+]] = arith.maxsi %[[v2]], %[[v3]] : i32 +// CHECK-DAG: %[[v5:.+]] = arith.index_cast %[[v4]] : i32 to index +// CHECK-DAG: %[[alloc:.+]] = memref.alloc() +// CHECK-DAG: memref.store %[[v5]], %[[alloc]][%[[c0]]] +// CHECK-DAG: memref.store %[[c2]], %[[alloc]][%[[c1]]] +// CHECK-DAG: memref.copy %[[alloc]], %[[arg2]] +// CHECK-DAG: memref.dealloc %[[alloc]] +// CHECK: return + +// ----- + +// This test checks that we don't produce incorrect IR when using +// `bufferization.alloc_tensor` to allocate a tensor in a different space. +// Currently `bufferization.alloc_tensor` also requires a cast on the result. +// TODO: remove when upstream is fixed. + +func.func @test_alloc_tensor_copy_to_space(%arg0: tensor<2xindex, #plan.memory_space>) + -> tensor<2xindex, #plan.memory_space> { + %0 = bufferization.alloc_tensor () copy (%arg0) { + memory_space = #plan.memory_space + } : tensor<2xindex, #plan.memory_space> + %1 = tensor.cast %0 + : tensor<2xindex, #plan.memory_space> to tensor<2xindex, #plan.memory_space> + return %1 : tensor<2xindex, #plan.memory_space> +} + +// CHECK-LABEL: func.func @test_alloc_tensor_copy_to_space +// CHECK-SAME: (%[[arg0:.+]]: memref<2xindex, #plan.memory_space>, +// CHECK-SAME: %[[arg1:.+]]: memref<2xindex, #plan.memory_space> {plan.result_arg}) +// CHECK: memref.copy %[[arg0]], %[[arg1]] +// CHECK-NEXT: return + +// ----- + +// This test checks that we produce bufferized IR where: +// - The `scf.for` loops don't have unnecessary transfers inside the +// loop bodies. +// - We produce correct result when "bufferization.alloc_tensor" is used +// at the frontend and the memory space constraint is specified using +// the `memory_space` attribute. + +func.func @fill_buffers_using_for_loops() -> (tensor<2x128xf32>, tensor<128xf32>) { + %c0_i32 = arith.constant 0 : i32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + // Create input tensor in host space and fill it with values. + %0 = bufferization.alloc_tensor() { + memory_space = #plan.memory_space + } : tensor<2x128xf32> + + // Create the second tensor in device space, but it is also + // filled with values in a loop. Optimization should make this + // a host tensor allocation. + %01 = bufferization.alloc_tensor() { + memory_space = #plan.memory_space + } : tensor<128xf32> + + %lhs_host = scf.for %i = %c0 to %c256 step %c1 iter_args(%iter = %0) -> (tensor<2x128xf32>) { + %coords:2 = affine.delinearize_index %i into(%c2, %c128) : index, index + %v = arith.index_cast %i : index to i32 + %vf = arith.sitofp %v : i32 to f32 + %y = tensor.insert %vf into %iter[%coords#0, %coords#1] : tensor<2x128xf32> + scf.yield %y : tensor<2x128xf32> + } + %rhs_host = scf.for %i = %c0 to %c128 step %c1 iter_args(%iter = %01) -> (tensor<128xf32>) { + %coords:2 = affine.delinearize_index %i into(%c2, %c128) : index, index + %v = arith.index_cast %i : index to i32 + %vf = arith.sitofp %v : i32 to f32 + %y = tensor.insert %vf into %iter[%i] : tensor<128xf32> + scf.yield %y : tensor<128xf32> + } + + // Clone input host tensor into device memory space. + %lhs = bufferization.alloc_tensor() copy(%lhs_host) { + memory_space = #plan.memory_space + } : tensor<2x128xf32> + %rhs = bufferization.alloc_tensor() copy(%rhs_host) { + memory_space = #plan.memory_space + } : tensor<128xf32> + + return %lhs, %rhs : tensor<2x128xf32>, tensor<128xf32> +} + +// CHECK-LABEL: func.func @fill_buffers_using_for_loops +// CHECK-SAME: (%[[arg0:.+]]: memref<2x128xf32, #plan.memory_space> {plan.result_arg}, %[[arg1:.+]]: memref<128xf32, #plan.memory_space> {plan.result_arg}) +// CHECK-DAG: %[[alloc:.+]] = memref.alloc() {{.*}} #plan.memory_space> +// CHECK: scf.for %[[arg2:.+]] = +// CHECK-DAG: %[[v0]]:2 = affine.delinearize_index %[[arg2]] into (2, 128) : index, index +// CHECK-DAG: %[[v1:.+]] = arith.index_cast %[[arg2]] : index to i32 +// CHECK-DAG: %[[v2:.+]] = arith.sitofp %[[v1]] : i32 to f32 +// CHECK-DAG: memref.store %[[v2]], %[[alloc]][%[[v0]]#0, %[[v0]]#1] +// CHECK: } +// CHECK: memref.copy %[[alloc]], %[[arg0]] : memref<2x128xf32, #plan.memory_space> to memref<2x128xf32, #plan.memory_space> +// CHECK: %[[alloc_0:.+]] = memref.alloc() {{.*}} #plan.memory_space> +// CHECK: scf.for %[[arg2:.+]] = +// CHECK-DAG: %[[v0:.+]] = arith.index_cast %[[arg2]] : index to i32 +// CHECK-DAG: %[[v1:.+]] = arith.sitofp %[[v0]] : i32 to f32 +// CHECK-DAG: memref.store %[[v1]], %[[alloc_0]][%[[arg2]]] : memref<128xf32, #plan.memory_space> +// CHECK: } +// CHECK: memref.copy %[[alloc_0]], %[[arg1]] : memref<128xf32, #plan.memory_space> to memref<128xf32, #plan.memory_space> +// CHECK: memref.dealloc %[[alloc]] : memref<2x128xf32, #plan.memory_space> +// CHECK: memref.dealloc %[[alloc_0]] : memref<128xf32, #plan.memory_space> +// CHECK: return diff --git a/mlir-tensorrt/compiler/test/Dialect/Plan/promote-host-tensors-to-host-pinned.mlir b/mlir-tensorrt/compiler/test/Dialect/Plan/promote-host-tensors-to-host-pinned.mlir new file mode 100644 index 000000000..3f0299983 --- /dev/null +++ b/mlir-tensorrt/compiler/test/Dialect/Plan/promote-host-tensors-to-host-pinned.mlir @@ -0,0 +1,30 @@ +// RUN: mlir-tensorrt-opt %s -plan-promote-host-tensors-to-host-pinned -split-input-file | FileCheck %s + +func.func @test(%arg0: tensor<10xf32, #plan.memory_space>) -> f32 { + %c0 = arith.constant 0 : index + %0 = tensor.cast %arg0 : tensor<10xf32, #plan.memory_space> to tensor<10xf32, #plan.memory_space> + %1 = tensor.extract %0[%c0] : tensor<10xf32, #plan.memory_space> + return %1 : f32 +} + +// CHECK-LABEL: func.func @test +// CHECK-SAME: %[[ARG0:.*]]: tensor<10xf32, #plan.memory_space> +// CHECK: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<10xf32, #plan.memory_space> to tensor<10xf32, #plan.memory_space> +// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[CAST]]{{.*}} : tensor<10xf32, #plan.memory_space> +// CHECK: return %[[EXTRACT]] : f32 + +// ----- + +func.func @from_elements_case(%arg0 :f32) -> tensor> { + %0 = tensor.from_elements %arg0 : tensor> + %1 = tensor.cast %0 : tensor> to tensor> + return %1 : tensor> +} + + +// CHECK-LABEL: func.func @from_elements_case +// CHECK-SAME: (%[[arg0:.+]]: f32) -> tensor> +// CHECK: %[[from_elements:.+]] = tensor.from_elements %[[arg0]] : tensor> +// CHECK: %[[cast:.+]] = tensor.cast %[[from_elements]] +// CHECK: return %[[cast]] : + diff --git a/mlir-tensorrt/compiler/test/Pipelines/StableHloInputPipeline/preprocessing-pipeline.mlir b/mlir-tensorrt/compiler/test/Pipelines/StableHloInputPipeline/preprocessing-pipeline.mlir new file mode 100644 index 000000000..a1b8dd47a --- /dev/null +++ b/mlir-tensorrt/compiler/test/Pipelines/StableHloInputPipeline/preprocessing-pipeline.mlir @@ -0,0 +1,19 @@ +// RUN: mlir-tensorrt-opt %s -stablehlo-preprocessing-pipeline + +func.func public @composite_call(%arg0: tensor<4xf32>) -> (tensor<4xf32> {jax.result_info = ""}) { + %0 = stablehlo.composite "my.tangent" %arg0 {composite_attributes = {dtype = f32, int = 1 : i64, str = "bar", tensor = dense<0.000000e+00> : tensor<1x2xf32>, tensor_r1 = dense<0.000000e+00> : tensor<2xf32>}, decomposition = @my.tangent} : (tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} +func.func private @my.tangent(%arg0: tensor<4xf32>) -> tensor<4xf32> { + %0 = stablehlo.sine %arg0 : tensor<4xf32> + %1 = stablehlo.cosine %arg0 : tensor<4xf32> + %2 = stablehlo.divide %0, %1 : tensor<4xf32> + return %2 : tensor<4xf32> +} + +// CHECK-LABEL: @composite_call +// CHECK-SAME: (%[[arg0:.+]]: tensor<4xf32>) +// CHECK-NEXT: %[[v0:.+]] = stablehlo.sine %[[arg0]] : tensor<4xf32> +// CHECK-NEXT: %[[v1:.+]] = stablehlo.cosine %[[arg0]] : tensor<4xf32> +// CHECK-NEXT: %[[v2:.+]] = stablehlo.divide %[[v0]], %[[v1]] : tensor<4xf32> +// CHECK-NEXT: return %[[v2]] : tensor<4xf32> \ No newline at end of file diff --git a/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/ClusteringDynamicShape/end-to-end-binary.mlir b/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/ClusteringDynamicShape/end-to-end-binary.mlir index 5a34a461a..2bd7c24ae 100644 --- a/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/ClusteringDynamicShape/end-to-end-binary.mlir +++ b/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/ClusteringDynamicShape/end-to-end-binary.mlir @@ -1,9 +1,10 @@ -// RUN: mlir-tensorrt-opt %s \ +// RUN: %pick-one-gpu mlir-tensorrt-opt %s \ // RUN: -pass-pipeline="builtin.module(stablehlo-preprocessing-pipeline{disable-inliner},\ // RUN: stablehlo-clustering-pipeline{entrypoint=}, \ // RUN: post-clustering-pipeline, \ // RUN: executor-lowering-pipeline)" \ -// RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable -allow-unregistered-dialect | mlir-tensorrt-runner -input-type=rtexe | FileCheck %s +// RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable -allow-unregistered-dialect | \ +// RUN: %pick-one-gpu mlir-tensorrt-runner -input-type=rtexe -features=core,cuda,tensorrt | FileCheck %s #profile = #tensorrt.shape_profile #profile1 = #tensorrt.shape_profile diff --git a/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/ClusteringDynamicShape/end-to-end-unary.mlir b/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/ClusteringDynamicShape/end-to-end-unary.mlir index 82b4c7322..c6b42aa59 100644 --- a/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/ClusteringDynamicShape/end-to-end-unary.mlir +++ b/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/ClusteringDynamicShape/end-to-end-unary.mlir @@ -1,10 +1,10 @@ -// RUN: mlir-tensorrt-opt %s \ +// RUN: %pick-one-gpu mlir-tensorrt-opt %s \ // RUN: -pass-pipeline="builtin.module(stablehlo-preprocessing-pipeline{disable-inliner},\ // RUN: stablehlo-clustering-pipeline{entrypoint=}, \ // RUN: post-clustering-pipeline, \ // RUN: executor-lowering-pipeline)" \ // RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable -allow-unregistered-dialect | \ -// RUN: mlir-tensorrt-runner -input-type=rtexe | FileCheck %s +// RUN: %pick-one-gpu mlir-tensorrt-runner -input-type=rtexe -features=core,cuda,tensorrt | FileCheck %s #profile0 = #tensorrt.shape_profile #profile1 = #tensorrt.shape_profile diff --git a/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-bf16.mlir b/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-bf16.mlir index 15f652aac..f3e984839 100644 --- a/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-bf16.mlir +++ b/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-bf16.mlir @@ -1,7 +1,7 @@ // REQUIRES: host-has-at-least-1-gpus // RUN: mlir-tensorrt-opt %s -convert-memref-to-cuda -convert-plan-to-executor -convert-cuda-to-executor -executor-lowering-pipeline \ // RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable \ -// RUN: | mlir-tensorrt-runner -input-type=rtexe | FileCheck %s +// RUN: | mlir-tensorrt-runner -input-type=rtexe -features=core,cuda | FileCheck %s !descriptor1D = !executor.table, !executor.ptr, index, index, index> !hostMemRef = memref<4xbf16, #plan.memory_space> diff --git a/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-dynamic.mlir b/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-dynamic.mlir index 73c1cd690..17b7120d4 100644 --- a/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-dynamic.mlir +++ b/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-dynamic.mlir @@ -1,7 +1,7 @@ // REQUIRES: host-has-at-least-1-gpus // RUN: mlir-tensorrt-opt %s -convert-memref-to-cuda -convert-cuda-to-executor -executor-lowering-pipeline \ // RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable \ -// RUN: | mlir-tensorrt-runner -input-type=rtexe | FileCheck %s +// RUN: | mlir-tensorrt-runner -input-type=rtexe -features=core,cuda | FileCheck %s func.func @run_with_shape_2d(%arg0: memref, %arg1: memref<2xindex>) { %c0 = arith.constant 0 : index diff --git a/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-f16.mlir b/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-f16.mlir index 448b88c6f..4c7a04572 100644 --- a/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-f16.mlir +++ b/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-f16.mlir @@ -1,7 +1,7 @@ // REQUIRES: host-has-at-least-1-gpus // RUN: mlir-tensorrt-opt %s -convert-memref-to-cuda -convert-plan-to-executor -convert-cuda-to-executor -executor-lowering-pipeline \ // RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable \ -// RUN: | mlir-tensorrt-runner -input-type=rtexe | FileCheck %s +// RUN: | mlir-tensorrt-runner -input-type=rtexe -features=core,cuda | FileCheck %s !descriptor1D = !executor.table, !executor.ptr, index, index, index> !hostMemRef = memref<4xf16, #plan.memory_space> diff --git a/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-f32.mlir b/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-f32.mlir index 0d16f189a..917dbb484 100644 --- a/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-f32.mlir +++ b/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-f32.mlir @@ -1,7 +1,7 @@ // REQUIRES: host-has-at-least-1-gpus // RUN: mlir-tensorrt-opt %s -convert-memref-to-cuda -convert-plan-to-executor -convert-cuda-to-executor -executor-lowering-pipeline \ // RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable \ -// RUN: | mlir-tensorrt-runner -input-type=rtexe | FileCheck %s +// RUN: | mlir-tensorrt-runner -input-type=rtexe -features=core,cuda | FileCheck %s !descriptor1D = !executor.table, !executor.ptr, index, index, index> !hostMemRef = memref<4xf32, #plan.memory_space> diff --git a/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-f8E4M3FN.mlir b/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-f8E4M3FN.mlir index 7b3ae4765..0852feac1 100644 --- a/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-f8E4M3FN.mlir +++ b/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-f8E4M3FN.mlir @@ -2,7 +2,7 @@ // REQUIRES: all-gpus-support-fp8 // RUN: mlir-tensorrt-opt %s -convert-memref-to-cuda -convert-plan-to-executor -convert-cuda-to-executor -executor-lowering-pipeline \ // RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable \ -// RUN: | mlir-tensorrt-runner -input-type=rtexe | FileCheck %s +// RUN: | mlir-tensorrt-runner -input-type=rtexe -features=core,cuda | FileCheck %s !descriptor1D = !executor.table, !executor.ptr, index, index, index> !hostMemRef = memref<4xf8E4M3FN, #plan.memory_space> diff --git a/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-i1.mlir b/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-i1.mlir index f44da93c5..dd2a26032 100644 --- a/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-i1.mlir +++ b/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-i1.mlir @@ -1,7 +1,7 @@ // REQUIRES: host-has-at-least-1-gpus // RUN: mlir-tensorrt-opt %s -convert-memref-to-cuda -convert-plan-to-executor -convert-cuda-to-executor -executor-lowering-pipeline \ // RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable \ -// RUN: | mlir-tensorrt-runner -input-type=rtexe | FileCheck %s +// RUN: | mlir-tensorrt-runner -input-type=rtexe -features=core,cuda | FileCheck %s !descriptor1D = !executor.table, !executor.ptr, index, index, index> !hostMemRef = memref<4xi1, #plan.memory_space> diff --git a/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-i4.mlir b/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-i4.mlir index 766bec84f..86cffd061 100644 --- a/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-i4.mlir +++ b/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/buffer-ops-i4.mlir @@ -1,7 +1,7 @@ // REQUIRES: host-has-at-least-1-gpus // RUN: mlir-tensorrt-opt %s -convert-memref-to-cuda -convert-plan-to-executor -convert-cuda-to-executor -executor-lowering-pipeline \ // RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable \ -// RUN: | mlir-tensorrt-runner -input-type=rtexe | FileCheck %s +// RUN: | mlir-tensorrt-runner -input-type=rtexe -features=core,cuda | FileCheck %s !descriptor1D = !executor.table, !executor.ptr, index, index, index> !hostMemRef = memref<4xi4, #plan.memory_space> diff --git a/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/lit.local.cfg b/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/lit.local.cfg new file mode 100644 index 000000000..cb43cb4ee --- /dev/null +++ b/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/lit.local.cfg @@ -0,0 +1 @@ +config.parallelism_group = "non-collective" diff --git a/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/memcpy-strided.mlir b/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/memcpy-strided.mlir index 0abcfec01..67122053f 100644 --- a/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/memcpy-strided.mlir +++ b/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/memcpy-strided.mlir @@ -1,7 +1,7 @@ // REQUIRES: host-has-at-least-1-gpus // RUN: mlir-tensorrt-opt %s -convert-memref-to-cuda -convert-plan-to-executor -convert-cuda-to-executor -executor-lowering-pipeline \ // RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable \ -// RUN: | mlir-tensorrt-runner -input-type=rtexe | FileCheck %s +// RUN: | mlir-tensorrt-runner -input-type=rtexe -features=core,cuda | FileCheck %s func.func @main() -> index { %c0 = arith.constant 0 : index diff --git a/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/memcpy.mlir b/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/memcpy.mlir index f750810c8..0ad4ec53e 100644 --- a/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/memcpy.mlir +++ b/mlir-tensorrt/compiler/test/Target/Lua/IntegrationTests/memcpy.mlir @@ -1,7 +1,7 @@ // REQUIRES: host-has-at-least-1-gpus // RUN: mlir-tensorrt-opt %s -convert-memref-to-cuda -convert-plan-to-executor -convert-cuda-to-executor -executor-lowering-pipeline \ // RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable \ -// RUN: | mlir-tensorrt-runner -input-type=rtexe | FileCheck %s +// RUN: | mlir-tensorrt-runner -input-type=rtexe -features=core,cuda | FileCheck %s func.func @main() -> i32 { %c0_i32 = arith.constant 0 : i32 diff --git a/mlir-tensorrt/compiler/test/Transforms/SCFDetensorizeLoops/scf-detensorize-loops.mlir b/mlir-tensorrt/compiler/test/Transforms/SCFDetensorizeLoops/scf-detensorize-loops.mlir index 6d0b5049e..ff19cdfe6 100644 --- a/mlir-tensorrt/compiler/test/Transforms/SCFDetensorizeLoops/scf-detensorize-loops.mlir +++ b/mlir-tensorrt/compiler/test/Transforms/SCFDetensorizeLoops/scf-detensorize-loops.mlir @@ -35,8 +35,8 @@ func.func @detensorize_while(%arg0: tensor, %arg1: tensor<1xi32>) // CHECK-NEXT: ^bb0(%[[arg2:.+]]: i32, %[[arg3:.+]]: i32): // CHECK-NEXT: %[[v1:.+]] = arith.addi %[[arg2]], %[[arg3]] : i32 // CHECK-NEXT: scf.yield %[[v1]], %[[v1]] : i32, i32 -// CHECK: %[[from_elements:.+]] = tensor.from_elements %[[v0]]#0 : tensor -// CHECK-NEXT: %[[from_elements_1:.+]] = tensor.from_elements %[[v0]]#1 : tensor<1xi32> +// CHECK-DAG: %[[from_elements:.+]] = tensor.from_elements %[[v0]]#0 : tensor +// CHECK-DAG: %[[from_elements_1:.+]] = tensor.from_elements %[[v0]]#1 : tensor<1xi32> // CHECK-NEXT: return %[[from_elements]], %[[from_elements_1]] : tensor, tensor<1xi32> // ----- diff --git a/mlir-tensorrt/compiler/test/lit.cfg.py b/mlir-tensorrt/compiler/test/lit.cfg.py index 240fdc4f3..a21462e67 100644 --- a/mlir-tensorrt/compiler/test/lit.cfg.py +++ b/mlir-tensorrt/compiler/test/lit.cfg.py @@ -57,9 +57,13 @@ def estimate_paralllelism( devices, gb_gpu_mem_required ) return int( - min( - parallelism, - (psutil.virtual_memory().available / (1024**3)) // gb_sys_mem_required, + max( + min( + parallelism, + (0.5 * psutil.virtual_memory().available / (1024**3)) + // gb_sys_mem_required, + ), + 1, ) ) except: @@ -83,15 +87,14 @@ def estimate_paralllelism( # Setup the parallelism groups. Note that just instantiating the TRT builder # requires ~2.5 GB of system memory, so we use 3.0 as a baseline limit. -lit_config.parallelism_groups["default"] = estimate_paralllelism( - 2.0, gb_sys_mem_required=3.0 +lit_config.parallelism_groups["non-collective"] = estimate_paralllelism( + 2.0, gb_sys_mem_required=5.0 ) +lit_config.parallelism_groups["collective"] = 1 lit_config.parallelism_groups["models"] = estimate_paralllelism( - 8.0, gb_sys_mem_required=4.0 + 8.0, gb_sys_mem_required=6.0 ) -lit_config.parallelism_groups["heavy"] = 1 - -lit_config.parallelism_group = "default" +config.parallelism_group = "non-collective" print(f"Parallelism Groups: {lit_config.parallelism_groups}", file=sys.stderr) diff --git a/mlir-tensorrt/compiler/test/python/IntegrationTests/Torch/test_torch_add.py b/mlir-tensorrt/compiler/test/python/IntegrationTests/Torch/test_torch_add.py index 590097442..316fa82f5 100644 --- a/mlir-tensorrt/compiler/test/python/IntegrationTests/Torch/test_torch_add.py +++ b/mlir-tensorrt/compiler/test/python/IntegrationTests/Torch/test_torch_add.py @@ -1,4 +1,4 @@ -# RUN: %PYTHON %s +# RUN: %pick-one-gpu %PYTHON %s import mlir_tensorrt.compiler.api as compiler import mlir_tensorrt.compiler.ir as ir diff --git a/mlir-tensorrt/compiler/test/python/IntegrationTests/lit.local.cfg b/mlir-tensorrt/compiler/test/python/IntegrationTests/lit.local.cfg index 8e80f44d9..9e414568f 100644 --- a/mlir-tensorrt/compiler/test/python/IntegrationTests/lit.local.cfg +++ b/mlir-tensorrt/compiler/test/python/IntegrationTests/lit.local.cfg @@ -2,5 +2,3 @@ if not config.enable_bindings_python: config.unsupported = True if not "host-has-at-least-1-gpus" in config.available_features: config.unsupported = True - -config.parallelism_group = "heavy" diff --git a/mlir-tensorrt/compiler/test/python/IntegrationTests/test_call_validation.py b/mlir-tensorrt/compiler/test/python/IntegrationTests/test_call_validation.py index f55b984ef..bc8b67168 100644 --- a/mlir-tensorrt/compiler/test/python/IntegrationTests/test_call_validation.py +++ b/mlir-tensorrt/compiler/test/python/IntegrationTests/test_call_validation.py @@ -1,4 +1,4 @@ -# RUN: %PYTHON %s | FileCheck %s +# RUN: %pick-one-gpu %PYTHON %s | FileCheck %s import mlir_tensorrt.compiler.api as compiler import mlir_tensorrt.compiler.ir as ir import mlir_tensorrt.runtime.api as runtime diff --git a/mlir-tensorrt/compiler/test/python/IntegrationTests/test_non_dps_cconv.py b/mlir-tensorrt/compiler/test/python/IntegrationTests/test_non_dps_cconv.py index 88a519f33..40d2189a9 100644 --- a/mlir-tensorrt/compiler/test/python/IntegrationTests/test_non_dps_cconv.py +++ b/mlir-tensorrt/compiler/test/python/IntegrationTests/test_non_dps_cconv.py @@ -1,4 +1,4 @@ -# RUN: %PYTHON %s +# RUN: %pick-one-gpu %PYTHON %s import time import mlir_tensorrt.compiler.api as compiler diff --git a/mlir-tensorrt/compiler/test/python/IntegrationTests/test_return_allocation_loop.py b/mlir-tensorrt/compiler/test/python/IntegrationTests/test_return_allocation_loop.py index 2e4c4ab07..3643e386e 100644 --- a/mlir-tensorrt/compiler/test/python/IntegrationTests/test_return_allocation_loop.py +++ b/mlir-tensorrt/compiler/test/python/IntegrationTests/test_return_allocation_loop.py @@ -1,4 +1,4 @@ -# RUN: %PYTHON %s +# RUN: %pick-one-gpu %PYTHON %s # Creates a program that requries ~1GB of memory to run. # We execute it in a loop, and in each execution the program needs to allocate a new output buffer # of size ~1GB. diff --git a/mlir-tensorrt/compiler/test/python/IntegrationTests/test_stablehlo_add.py b/mlir-tensorrt/compiler/test/python/IntegrationTests/test_stablehlo_add.py index f997cbac8..df95d46ed 100644 --- a/mlir-tensorrt/compiler/test/python/IntegrationTests/test_stablehlo_add.py +++ b/mlir-tensorrt/compiler/test/python/IntegrationTests/test_stablehlo_add.py @@ -1,4 +1,4 @@ -# RUN: %PYTHON %s +# RUN: %pick-one-gpu %PYTHON %s import time import mlir_tensorrt.compiler.api as compiler diff --git a/mlir-tensorrt/compiler/test/python/IntegrationTests/test_stablehlo_dynamic.py b/mlir-tensorrt/compiler/test/python/IntegrationTests/test_stablehlo_dynamic.py index bbc342e49..59fafad38 100644 --- a/mlir-tensorrt/compiler/test/python/IntegrationTests/test_stablehlo_dynamic.py +++ b/mlir-tensorrt/compiler/test/python/IntegrationTests/test_stablehlo_dynamic.py @@ -1,4 +1,4 @@ -# RUN: %PYTHON %s | FileCheck %s +# RUN: %pick-one-gpu %PYTHON %s | FileCheck %s from typing import Iterable import mlir_tensorrt.compiler.api as compiler diff --git a/mlir-tensorrt/compiler/test/python/IntegrationTests/test_stablehlo_dynamic_iota.py b/mlir-tensorrt/compiler/test/python/IntegrationTests/test_stablehlo_dynamic_iota.py index eb00699be..1e8e5ea06 100644 --- a/mlir-tensorrt/compiler/test/python/IntegrationTests/test_stablehlo_dynamic_iota.py +++ b/mlir-tensorrt/compiler/test/python/IntegrationTests/test_stablehlo_dynamic_iota.py @@ -1,4 +1,4 @@ -# RUN: %PYTHON %s +# RUN: %pick-one-gpu %PYTHON %s # REQUIRES: tensorrt-version-ge-10.0 import mlir_tensorrt.compiler.api as compiler import mlir_tensorrt.compiler.ir as ir diff --git a/mlir-tensorrt/compiler/test/python/IntegrationTests/test_tensorrt10_data_type_support.py b/mlir-tensorrt/compiler/test/python/IntegrationTests/test_tensorrt10_data_type_support.py index a20bf3bcc..e50ccafdc 100644 --- a/mlir-tensorrt/compiler/test/python/IntegrationTests/test_tensorrt10_data_type_support.py +++ b/mlir-tensorrt/compiler/test/python/IntegrationTests/test_tensorrt10_data_type_support.py @@ -1,4 +1,4 @@ -# RUN: %PYTHON %s | FileCheck %s +# RUN: %pick-one-gpu %PYTHON %s | FileCheck %s # REQUIRES: all-gpus-support-fp8 # REQUIRES: tensorrt-version-ge-10.0 from dataclasses import dataclass diff --git a/mlir-tensorrt/compiler/test/python/IntegrationTests/test_tensorrt_add.py b/mlir-tensorrt/compiler/test/python/IntegrationTests/test_tensorrt_add.py index ecf098f6f..d426b86d3 100644 --- a/mlir-tensorrt/compiler/test/python/IntegrationTests/test_tensorrt_add.py +++ b/mlir-tensorrt/compiler/test/python/IntegrationTests/test_tensorrt_add.py @@ -1,4 +1,4 @@ -# RUN: %PYTHON %s +# RUN: %pick-one-gpu %PYTHON %s | FileCheck %s # Restricted to TRT 10+ due to use of "strongly-typed" mode below. # REQUIRES: tensorrt-version-ge-10.0 import time @@ -89,14 +89,14 @@ def tensorrt_add(): # CHECK: [ 0. 2. 4. 6.] # CHECK-NEXT: [ 8. 10. 12. 14.] # CHECK-NEXT: [16. 18. 20. 22.]] -# CHECK-NEXT: -# CHECK-NEXT: [24. 26. 28. 30.] + +# CHECK: [24. 26. 28. 30.] # CHECK-NEXT: [32. 34. 36. 38.] # CHECK-NEXT: [40. 42. 44. 46.]]] # CHECK-NEXT: [ 0. 32. 64. 96.] # CHECK-NEXT: [128. 160. 192. 224.] # CHECK-NEXT: [256. 288. 320. 352.]] -# CHECK-NEXT: -# CHECK-NEXT: [384. 416. 448. 480.] + +# CHECK: [384. 416. 448. 480.] # CHECK-NEXT: [512. 544. 576. 608.] # CHECK-NEXT: [640. 672. 704. 736.] diff --git a/mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/compiler_api/test_compiler_api.py b/mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/compiler_api/test_compiler_api.py index d0cee1d40..d8b20c813 100644 --- a/mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/compiler_api/test_compiler_api.py +++ b/mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/compiler_api/test_compiler_api.py @@ -1,7 +1,7 @@ # REQUIRES: tensorrt-version-ge-10.0 # REQUIRES: host-has-at-least-1-gpus # REQUIRES: debug-print -# RUN: %PYTHON %s 2>&1 | FileCheck %s +# RUN: %pick-one-gpu %PYTHON %s 2>&1 | FileCheck %s # This test requires TensorRT >= 10.0 since we are testing ability # to set the 'tensorrt-strongly-typed' flag. diff --git a/mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/compiler_api/test_compiler_debug_dump.py b/mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/compiler_api/test_compiler_debug_dump.py index 26fa48c8b..2710afeaf 100644 --- a/mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/compiler_api/test_compiler_debug_dump.py +++ b/mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/compiler_api/test_compiler_debug_dump.py @@ -1,4 +1,4 @@ -# RUN: %PYTHON %s 2>&1 +# RUN: %pick-one-gpu %PYTHON %s 2>&1 # REQUIRES: host-has-at-least-1-gpus import os import tempfile diff --git a/mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/compiler_api/test_plugin_schema_api.py b/mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/compiler_api/test_plugin_schema_api.py index 7497ad966..00aa754eb 100644 --- a/mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/compiler_api/test_plugin_schema_api.py +++ b/mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/compiler_api/test_plugin_schema_api.py @@ -1,5 +1,4 @@ # REQUIRES: tensorrt-version-ge-10.0 -# REQUIRES: host-has-at-least-1-gpus # RUN: %PYTHON %s 2>&1 | FileCheck %s import ctypes diff --git a/mlir-tensorrt/compiler/test/python/mlir_tensorrt_runtime/test_runtime_api.py b/mlir-tensorrt/compiler/test/python/mlir_tensorrt_runtime/test_runtime_api.py index ddd92886b..fe1edfa38 100644 --- a/mlir-tensorrt/compiler/test/python/mlir_tensorrt_runtime/test_runtime_api.py +++ b/mlir-tensorrt/compiler/test/python/mlir_tensorrt_runtime/test_runtime_api.py @@ -1,5 +1,4 @@ # RUN: %pick-one-gpu %PYTHON %s | FileCheck %s -# REQUIRES: host-has-at-least-1-gpus from typing import Callable @@ -219,8 +218,10 @@ def test_devices(): @make_test def test_stream(): client = runtime.RuntimeClient() + devices = client.get_devices() + if len(devices) == 0: + return stream = client.create_stream() - assert isinstance(stream.ptr, int) diff --git a/mlir-tensorrt/compiler/test/python/mlir_tensorrt_runtime/test_runtime_debug_dump.py b/mlir-tensorrt/compiler/test/python/mlir_tensorrt_runtime/test_runtime_debug_dump.py index 984d9deef..238a179b3 100644 --- a/mlir-tensorrt/compiler/test/python/mlir_tensorrt_runtime/test_runtime_debug_dump.py +++ b/mlir-tensorrt/compiler/test/python/mlir_tensorrt_runtime/test_runtime_debug_dump.py @@ -1,5 +1,4 @@ # RUN: %PYTHON %s 2>&1 -# REQUIRES: host-has-at-least-1-gpus import mlir_tensorrt.runtime.api as runtime diff --git a/mlir-tensorrt/executor/CMakeLists.txt b/mlir-tensorrt/executor/CMakeLists.txt index 7a934de5f..01883627c 100644 --- a/mlir-tensorrt/executor/CMakeLists.txt +++ b/mlir-tensorrt/executor/CMakeLists.txt @@ -60,12 +60,14 @@ if(MLIR_EXECUTOR_ENABLE_CUDA) mlir_executor_find_and_patch_libnvptxcompiler(CUDANVPTXCompilerLibraryPatched) endif() -mlir_executor_add_nvtx() +find_package(NVTX REQUIRED) mlir_executor_add_lua() mlir_executor_add_sol2() mlir_tensorrt_find_dlpack() find_package(Flatbuffers REQUIRED) +find_package(MLIRTensorRTCommon REQUIRED) + if(MLIR_EXECUTOR_ENABLE_MPI) find_package(MPI COMPONENTS C) endif() diff --git a/mlir-tensorrt/executor/cmake/ExecutorDependencies.cmake b/mlir-tensorrt/executor/cmake/ExecutorDependencies.cmake index b65935533..ad6013708 100644 --- a/mlir-tensorrt/executor/cmake/ExecutorDependencies.cmake +++ b/mlir-tensorrt/executor/cmake/ExecutorDependencies.cmake @@ -134,28 +134,6 @@ macro(mlir_executor_add_sol2) target_compile_definitions(sol2 INTERFACE SOL_ALL_SAFETIES_ON=1) endmacro() -# ----------------------------------------------------------------------------- -# Downloads NVTX from Github and adds it to the build -# ----------------------------------------------------------------------------- -function(mlir_executor_add_nvtx) - CPMAddPackage( - NAME nvtx - GIT_REPOSITORY https://github.com/NVIDIA/NVTX.git - GIT_TAG v3.1.0 - GIT_SHALLOW TRUE - SOURCE_SUBDIR c - EXCLUDE_FROM_ALL TRUE - DOWNLOAD_ONLY TRUE - ) - add_library(nvtx3-cpp INTERFACE IMPORTED) - target_include_directories(nvtx3-cpp INTERFACE - "$") - # Ignore some warnings due to NVTX3 code style. - target_compile_options(nvtx3-cpp INTERFACE - -Wno-missing-braces) -endfunction() - - # ----------------------------------------------------------------------------- # Finds the NCCL headers and creates an interface target `NCCL`. # ----------------------------------------------------------------------------- diff --git a/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h b/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h index ae62663f0..ed7bb0b1e 100644 --- a/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h +++ b/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h @@ -418,6 +418,11 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionOptionsCreate( int32_t numDevices, int32_t deviceId, MTRT_StringView ncclUuid, MTRT_RuntimeSessionOptions *options); +/// Enable a particular feature for the runtime session. +MLIR_CAPI_EXPORTED void +mtrtRuntimeSessionOptionsEnableFeature(MTRT_RuntimeSessionOptions options, + MTRT_StringView feature); + /// Destroy `options` and free any resources. MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionOptionsDestroy(MTRT_RuntimeSessionOptions options); diff --git a/mlir-tensorrt/executor/include/mlir-executor/Conversion/ConvertToExecutorCommon.h b/mlir-tensorrt/executor/include/mlir-executor/Conversion/ConvertToExecutorCommon.h index e4e5cb841..5d3d080b8 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Conversion/ConvertToExecutorCommon.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Conversion/ConvertToExecutorCommon.h @@ -289,36 +289,11 @@ class ConvertOpToExecutorPattern : public ConvertToExecutorPattern { : ConvertToExecutorPattern(typeConverter, SourceOp::getOperationName(), benefit, context) {} - /// Wrappers around the ConversionPattern methods that pass the derived op - /// type. - LogicalResult match(Operation *op) const final { - return match(cast(op)); - } - void rewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - if constexpr (SourceOp::hasProperties()) - return rewrite(cast(op), - OpAdaptor(operands, op->getAttrDictionary(), - cast(op).getProperties()), - rewriter); - rewrite(cast(op), OpAdaptor(operands, op->getAttrDictionary()), - rewriter); - } - void rewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - auto sourceOp = cast(op); - rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter); - } LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - if constexpr (SourceOp::hasProperties()) - return matchAndRewrite(cast(op), - OpAdaptor(operands, op->getAttrDictionary(), - cast(op).getProperties()), - rewriter); - return matchAndRewrite(cast(op), - OpAdaptor(operands, op->getAttrDictionary()), + auto sourceOp = cast(op); + return matchAndRewrite(cast(op), OpAdaptor(operands, sourceOp), rewriter); } LogicalResult @@ -329,32 +304,10 @@ class ConvertOpToExecutorPattern : public ConvertToExecutorPattern { rewriter); } - /// Rewrite and Match methods that operate on the SourceOp type. These must be - /// overridden by the derived pattern class. - virtual LogicalResult match(SourceOp op) const { - (void)op; - llvm_unreachable("must override match or matchAndRewrite"); - } - virtual void rewrite(SourceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - (void)op; - (void)adaptor; - (void)rewriter; - llvm_unreachable("must override matchAndRewrite or a rewrite method"); - } - virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - SmallVector oneToOneOperands = - getOneToOneAdaptorOperands(adaptor.getOperands()); - rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); - } virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - if (failed(match(op))) - return failure(); - rewrite(op, adaptor, rewriter); - return success(); + llvm_unreachable("must override matchAndRewrite"); } virtual LogicalResult matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, diff --git a/mlir-tensorrt/executor/include/mlir-executor/Executor/IR/ExecutorOps.td b/mlir-tensorrt/executor/include/mlir-executor/Executor/IR/ExecutorOps.td index 5e3c1ab82..93f55faba 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Executor/IR/ExecutorOps.td +++ b/mlir-tensorrt/executor/include/mlir-executor/Executor/IR/ExecutorOps.td @@ -1288,7 +1288,7 @@ def Executor_StoreOp : Executor_Op<"store", [ }]; } -def Executor_MemcpyOp : Executor_Op<"memcpy", [Executor_LowerToFuncCallTrait, +def Executor_MemcpyOp : Executor_Op<"memcpy", [ AllTypesMatch<["src_offset", "dest_offset", "num_bytes"]> ]> { let description = [{ @@ -1338,8 +1338,7 @@ def Executor_StridedMemrefCopyOp : Executor_Op<"strided_memref_copy", [ class Executor_PointerCastOp traits = []> : Executor_Op])> { + !listconcat(traits, [Pure])> { let assemblyFormat = [{ attr-dict $arg `:` functional-type(operands, results) }]; @@ -1348,36 +1347,11 @@ class Executor_PointerCastOp traits = []> : def Executor_PtrToIntOp : Executor_PointerCastOp<"ptrtoint"> { let arguments = (ins Executor_Ptr:$arg); let results = (outs Executor_Index:$result); - let extraClassDefinition = [{ - FailureOr - $cppClass::getRuntimeBuiltinFunctionName(const DataLayout &dataLayout) { - std::string result; - llvm::raw_string_ostream ss(result); - uint64_t ptrWidth = dataLayout.getTypeSizeInBits(getArg().getType()); - ss << "_ptrtoint_i" << ptrWidth - << "_" << getType(); - ss.flush(); - return result; - } - }]; } def Executor_IntToPtrOp : Executor_PointerCastOp<"inttoptr"> { let arguments = (ins Executor_Index:$arg); let results = (outs Executor_Ptr:$result); - let extraClassDefinition = [{ - /// Specify the function type. Different pointer types can be used, so treat as variadic. - FailureOr - $cppClass::getRuntimeBuiltinFunctionName(const DataLayout &dataLayout) { - std::string result; - uint64_t ptrWidth = dataLayout.getTypeSizeInBits(getType()); - llvm::raw_string_ostream ss(result); - ss << "_inttoptr_i" << ptrWidth - << "_" << getArg().getType(); - ss.flush(); - return result; - } - }]; } #endif // MLIR_TENSORRT_DIALECT_EXECUTOR_IR_EXECUTOROPS_TD diff --git a/mlir-tensorrt/executor/include/mlir-executor/InitAllPasses.h b/mlir-tensorrt/executor/include/mlir-executor/InitAllPasses.h index 492ad0d32..5c09bcf1e 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/InitAllPasses.h +++ b/mlir-tensorrt/executor/include/mlir-executor/InitAllPasses.h @@ -19,7 +19,7 @@ //===----------------------------------------------------------------------===// #include "mlir-executor/Conversion/Passes.h" #include "mlir-executor/Executor/Transforms/Passes.h" -#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Conversion/Passes.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/Func/Transforms/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" @@ -36,9 +36,7 @@ inline void registerAllPasses() { mlir::func::registerDuplicateFunctionEliminationPass(); mlir::memref::registerMemRefPasses(); mlir::registerTransformsPasses(); - mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { - return mlir::createConvertSCFToCFPass(); - }); + mlir::registerSCFToControlFlowPass(); } } // namespace mlir::executor diff --git a/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h b/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h index aff2fb09d..d417ee11a 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h @@ -22,8 +22,8 @@ /// the compielr-runtime or runtime-user interface. /// //===----------------------------------------------------------------------===// -#ifndef MLIR_TENSORRT_RUNTIME_API_API -#define MLIR_TENSORRT_RUNTIME_API_API +#ifndef MLIR_EXECUTOR_RUNTIME_API_API +#define MLIR_EXECUTOR_RUNTIME_API_API #include "dlpack/dlpack.h" #include "mlir-executor/Support/Allocators.h" @@ -34,8 +34,10 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSet.h" #include "llvm/Support/MemoryBuffer.h" #include +#include #include #include @@ -284,7 +286,7 @@ class MemRefTypeView : public FlatbufferTypeObjectView ScalarValue(T data_, ScalarType type) : RuntimeValue(Kind::Scalar), type(type) { - static_assert(sizeof(T) <= sizeof(int64_t) && - alignof(T) <= alignof(int64_t), + static_assert(sizeof(T) <= sizeof(uint64_t) && + alignof(T) <= alignof(uint64_t), "expected scalar type size to be <= 8 bytes"); - *reinterpret_cast(&data) = data_; + *reinterpret_cast(&data.real) = data_; + } + + template + ScalarValue(T real_, T imag_, ScalarType type) + : RuntimeValue(Kind::Scalar), type(type) { + static_assert(std::is_same_v || std::is_same_v, + "complex constructor only valid for float and double types."); + static_assert(sizeof(T) <= sizeof(uint64_t) && + alignof(T) <= alignof(uint64_t), + "expected scalar type size to be <= 8 bytes."); + assert(isComplex() && + "complex value constructor used for non-complex scalar type."); + data.complex = std::make_unique>(real_, imag_).release(); } + // Delete copy constructors. + ScalarValue(const ScalarValue &other) = delete; + ScalarValue &operator=(const ScalarValue &other) = delete; + + // Move constructors. + ScalarValue(ScalarValue &&other) noexcept; + ScalarValue &operator=(ScalarValue &&other) noexcept; + + ~ScalarValue(); + ScalarType getType() const { return type; } template T get() const { - static_assert(sizeof(T) <= sizeof(int64_t), - "expected scalar type size to be <= 8 bytes"); - return *reinterpret_cast(&data); + static_assert(sizeof(T) <= sizeof(uint64_t), + "expected scalar type size to be <= 8 bytes."); + assert(!isComplex() && "use `getComplex()` for complex scalar."); + return *reinterpret_cast(&data.real); } - void *getRaw() { return &data; } + template + std::complex getComplex() const { + static_assert(std::is_same_v || std::is_same_v, + "getComplex() only supports float and double types."); + static_assert(sizeof(T) <= sizeof(uint64_t) && + alignof(T) <= alignof(uint64_t), + "expected scalar type size to be <= 8 bytes."); + assert(isComplex() && + "complex value constructor used for non-complex scalar type."); + if constexpr (std::is_same_v) { + assert( + type.getCode() == ScalarTypeCode::complex32 && + "Type mismatch: expected scalar type code of complex32 for float."); + } else { + assert( + type.getCode() == ScalarTypeCode::complex64 && + "Type mismatch: expected scalar type code of complex64 for double."); + } + return *static_cast *>(data.complex); + } + + bool isComplex() const { + return type.getCode() == ScalarTypeCode::complex32 || + type.getCode() == ScalarTypeCode::complex64; + } + + void *getRaw() { return isComplex() ? data.complex : &data.real; } static bool classof(const RuntimeValue *v) { return v->getKind() == Kind::Scalar; } private: - int64_t data; + void cleanup(); + Storage data; ScalarType type; }; @@ -882,8 +940,13 @@ class RuntimeSessionOptions { /// devices, and NCCL UUID. Single-device sessions can use the default /// options. RuntimeSessionOptions(int32_t numDevices = 1, int32_t deviceId = 0, - llvm::StringRef ncclUuid = "") - : numDevices(numDevices), deviceId(deviceId), ncclUuid(ncclUuid) {} + llvm::StringRef ncclUuid = ""); + + /// Enable the specified features for the runtime session. + void enableFeatures(llvm::ArrayRef features); + + /// Returns true if the given feature is enabled. + bool isFeatureEnabled(llvm::StringRef feature) const; /// Populates the runtime session options using the MPI calls. Each MPI /// process is expected to be associated with a CUDA device associated with @@ -905,10 +968,17 @@ class RuntimeSessionOptions { /// one device.a llvm::StringRef getNcclUuid() const { return ncclUuid; } + /// Return the set of features that are enabled for this session. + const llvm::StringSet<> &getEnabledFeatures() const { return features; } + private: int32_t numDevices; int32_t deviceId; std::string ncclUuid; + + /// A list of features names (e.g. module names) that should be enabled for + /// this session. + llvm::StringSet<> features; }; //===----------------------------------------------------------------------===// @@ -1123,7 +1193,7 @@ class RuntimeClient { copyToDevice(const MemRefValue &hostBuffer, const Device &device, std::optional stream); - /// Allocates a new device buffer and fills it with data present on the host + /// Allocates a new host buffer and fills it with data present on the device /// in the specified buffer. The allocation and copy are performed on the /// given stream. StatusOr> @@ -1203,4 +1273,4 @@ inline llvm::raw_ostream &print(llvm::raw_ostream &os, } // namespace mlirtrt::runtime -#endif // MLIR_TENSORRT_RUNTIME_API_API +#endif // MLIR_EXECUTOR_RUNTIME_API_API diff --git a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaExtensionRegistry.h b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaExtensionRegistry.h index f37eadb5b..2ee50f28c 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaExtensionRegistry.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaExtensionRegistry.h @@ -37,11 +37,14 @@ struct LuaRuntimeExtension { void registerLuaRuntimeExtension(llvm::StringRef name, LuaRuntimeExtension extensionInfo); -void populateRuntimeExtensions(const RuntimeSessionOptions &options, - lua_State *state, - PinnedMemoryAllocator *pinnedMemoryAllocator, - AllocTracker *allocTracker, - ResourceTracker *resourceTracker); +/// Enable the Lua runtime extension modules that are specified in the features +/// of 'options'. If an enabled feature is not present in the registry, then an +/// error is returned. +Status populateRuntimeExtensions(const RuntimeSessionOptions &options, + lua_State *state, + PinnedMemoryAllocator *pinnedMemoryAllocator, + AllocTracker *allocTracker, + ResourceTracker *resourceTracker); } // namespace mlirtrt::runtime diff --git a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h index 92bfda786..b2414aa8e 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h @@ -78,7 +78,7 @@ class LuaRuntimeSession : public RuntimeSession { /// integer result (which is returned if the execution is successful). /// TODO: this should take a handle to a function for streaming output/errors. StatusOr runExecutorLuaScript( - std::string_view luaScript, + RuntimeSessionOptions options, std::string_view luaScript, LuaRuntimeSession::LuaModuleRegistrationFunc registerExtraLuaFuncs = {}); /// Synchronously run a serialized executor Executable one time. An `Executable` @@ -92,7 +92,7 @@ StatusOr runExecutorLuaScript( /// TODO: this should take a handle to a function for /// streaming output/errors. StatusOr runExecutorExecutable( - std::unique_ptr executable, + RuntimeSessionOptions options, std::unique_ptr executable, LuaRuntimeSession::LuaModuleRegistrationFunc registerExtraLuaFuncs = {}); /// Execute a named function in the session with the specified input args and diff --git a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Utils/NvtxUtils.h b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Utils/NvtxUtils.h index 06a0ab7ef..bc5ba5c2c 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Utils/NvtxUtils.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Utils/NvtxUtils.h @@ -24,8 +24,6 @@ #ifndef MLIR_TENSORRT_RUNTIME_BACKEND_UTILS_NVTXUTILS_H #define MLIR_TENSORRT_RUNTIME_BACKEND_UTILS_NVTXUTILS_H -#ifdef MLIR_TRT_ENABLE_NVTX - #if defined(__GNUC__) || defined(__clang__) #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wmissing-braces" @@ -71,16 +69,4 @@ constexpr inline nvtx3::rgb CudaModuleColor() { mlirtrt::runtime::NvtxRange r(mlirtrt::runtime::tracing::CudaModuleColor(), \ funcName) -#else - -#define ADD_RUNTIME_MODULE_RANGE(funcName) - -#define ADD_CORE_MODULE_RANGE(funcName) - -#define ADD_TENSORRT_MODULE_RANGE(funcName) - -#define ADD_CUDA_MODULE_RANGE(funcName) - -#endif - #endif // MLIR_TENSORRT_RUNTIME_BACKEND_UTILS_NVTXUTILS_H diff --git a/mlir-tensorrt/executor/include/mlir-executor/Support/Status.h b/mlir-tensorrt/executor/include/mlir-executor/Support/Status.h index 8b8918fa2..3de8f448e 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Support/Status.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Support/Status.h @@ -23,8 +23,8 @@ /// mechanisms. /// //===----------------------------------------------------------------------===// -#ifndef MLIR_TENSORRT_SUPPORT_STATUS_H -#define MLIR_TENSORRT_SUPPORT_STATUS_H +#ifndef MLIR_EXECUTOR_SUPPORT_STATUS +#define MLIR_EXECUTOR_SUPPORT_STATUS #include "mlir-executor/Utils/ADTExtras.h" #include "llvm/ADT/StringExtras.h" @@ -260,4 +260,4 @@ class StatusOr { } // namespace mlirtrt -#endif // MLIR_TENSORRT_SUPPORT_STATUS_H +#endif // MLIR_EXECUTOR_SUPPORT_STATUS diff --git a/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp b/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp index 27263d2bc..196262fe7 100644 --- a/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp +++ b/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp @@ -725,6 +725,12 @@ mtrtRuntimeSessionOptionsDestroy(MTRT_RuntimeSessionOptions options) { return mtrtStatusGetOk(); } +void mtrtRuntimeSessionOptionsEnableFeature(MTRT_RuntimeSessionOptions options, + MTRT_StringView feature) { + RuntimeSessionOptions *cppOptions = unwrap(options); + cppOptions->enableFeatures(std::string(feature.data, feature.length)); +} + //===----------------------------------------------------------------------===// // MTRT_RuntimeSession //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/executor/lib/Executor/IR/Executor.cpp b/mlir-tensorrt/executor/lib/Executor/IR/Executor.cpp index 1e2583b3e..a025fdc54 100644 --- a/mlir-tensorrt/executor/lib/Executor/IR/Executor.cpp +++ b/mlir-tensorrt/executor/lib/Executor/IR/Executor.cpp @@ -484,12 +484,6 @@ uint64_t PointerType::getABIAlignment(const DataLayout &dataLayhout, return kDefaultPointerAlignment; } -uint64_t -PointerType::getPreferredAlignment(const DataLayout &dataLayout, - DataLayoutEntryListRef params) const { - return getABIAlignment(dataLayout, params); -} - //===----------------------------------------------------------------------===// // ExecutorTableType //===----------------------------------------------------------------------===// @@ -527,11 +521,6 @@ uint64_t TableType::getABIAlignment(const DataLayout &dataLayout, return structAlignment; } -uint64_t TableType::getPreferredAlignment(const DataLayout &dataLayout, - DataLayoutEntryListRef params) const { - return getABIAlignment(dataLayout, params); -} - //===----------------------------------------------------------------------===// // AssertOp //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/executor/lib/Executor/Transforms/Passes.cpp b/mlir-tensorrt/executor/lib/Executor/Transforms/Passes.cpp index be334241c..3a3ffe873 100644 --- a/mlir-tensorrt/executor/lib/Executor/Transforms/Passes.cpp +++ b/mlir-tensorrt/executor/lib/Executor/Transforms/Passes.cpp @@ -21,8 +21,8 @@ #include "mlir-executor/Conversion/Passes.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" +#include "mlir/Conversion/Passes.h" #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" -#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Pass/PassManager.h" @@ -41,7 +41,7 @@ void executor::buildExecutorLoweringPipeline( OpPassManager &pm, const ConvertStdToExecutorPassOptions &stdToExecutorOpts) { pm.addPass(createConvertComplexToStandardPass()); - pm.addPass(createConvertSCFToCFPass()); + pm.addPass(mlir::createSCFToControlFlowPass()); pm.addPass(memref::createFoldMemRefAliasOpsPass()); pm.addPass(memref::createExpandOpsPass()); pm.addPass(memref::createExpandStridedMetadataPass()); diff --git a/mlir-tensorrt/executor/lib/Runtime/API/API.cpp b/mlir-tensorrt/executor/lib/Runtime/API/API.cpp index 947d6a516..36eed3ae1 100644 --- a/mlir-tensorrt/executor/lib/Runtime/API/API.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/API/API.cpp @@ -297,6 +297,40 @@ Status Executable::verify() const { //===----------------------------------------------------------------------===// // RuntimeSessionOptions //===----------------------------------------------------------------------===// +RuntimeSessionOptions::RuntimeSessionOptions(int32_t numDevices, + int32_t deviceId, + llvm::StringRef ncclUuid) + : numDevices(numDevices), deviceId(deviceId), ncclUuid(ncclUuid) { + this->enableFeatures({"core"}); +} + +void RuntimeSessionOptions::enableFeatures( + llvm::ArrayRef toEnable) { + for (const auto &feature : toEnable) { + if (feature == "cuda") { + // If the "cuda" feature is enabled, then we need to enable a module + // that provides the device ID and number of devices. This is either SPMD + // (default module for single-device, no NCCL/MPI) or NCCL. + if (!features.contains("single-device") && !features.contains("nccl")) + this->enableFeatures({"single-device"}); + } + // If we enable the NCCL feature, disable the SPMD feature if present. + if (feature == "nccl") { + // If the "nccl" feature is enabled, then we override the default + // "single-device" feature. + if (auto it = features.find("single-device"); it != features.end()) + features.erase(it); + + // If NCCL is enabled, then enable CUDA as well. + this->features.insert("cuda"); + } + this->features.insert(feature); + } +} + +bool RuntimeSessionOptions::isFeatureEnabled(llvm::StringRef feature) const { + return features.contains(feature); +} StatusOr RuntimeSessionOptions::createUsingSingleHostMpi() { @@ -583,6 +617,46 @@ StatusOr> Device::create(int32_t deviceNumber) { return std::unique_ptr(new Device(deviceNumber)); } +//===----------------------------------------------------------------------===// +// ScalarValue +//===----------------------------------------------------------------------===// + +ScalarValue::ScalarValue(ScalarValue &&other) noexcept + : RuntimeValue(Kind::Scalar), type(other.type) { + if (other.isComplex()) { + data.complex = other.data.complex; + other.data.complex = nullptr; + } else { + data.real = other.data.real; + } +} + +ScalarValue &ScalarValue::operator=(ScalarValue &&other) noexcept { + if (this != &other) { + cleanup(); + type = other.type; + if (other.isComplex()) { + data.complex = other.data.complex; + other.data.complex = nullptr; + } else { + data.real = other.data.real; + } + } + return *this; +} + +ScalarValue::~ScalarValue() { cleanup(); } + +void ScalarValue::cleanup() { + if (isComplex()) { + if (type.getCode() == ScalarTypeCode::complex32) { + delete static_cast *>(data.complex); + } else { + delete static_cast *>(data.complex); + } + } +} + //===----------------------------------------------------------------------===// // MemRefValue //===----------------------------------------------------------------------===// @@ -985,8 +1059,9 @@ StatusOr> RuntimeClient::create() { // Setup device objects. Create a view of the device pointers. llvm::SmallVector> devices; mlirtrt::Status status = populateDevices(devices); - if (!status.isOk()) - return status; + if (!status.isOk()) { + // TODO: we should emit a warning here. + } auto defaultAllocator = std::make_unique(); diff --git a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaExtensionRegistry.cpp b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaExtensionRegistry.cpp index 8ca904d74..ca5b02fd9 100644 --- a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaExtensionRegistry.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaExtensionRegistry.cpp @@ -24,6 +24,7 @@ //===----------------------------------------------------------------------===// #include "mlir-executor/Runtime/Backend/Lua/LuaExtensionRegistry.h" #include "mlir-executor/Runtime/API/API.h" +#include "mlir-executor/Runtime/Support/Support.h" #include "llvm/ADT/StringMap.h" #include "llvm/Support/ManagedStatic.h" @@ -38,11 +39,27 @@ void runtime::registerLuaRuntimeExtension(llvm::StringRef name, (*extensionRegistry)[name] = std::move(extensionInfo); } -void runtime::populateRuntimeExtensions( +Status runtime::populateRuntimeExtensions( const RuntimeSessionOptions &options, lua_State *state, PinnedMemoryAllocator *pinnedMemoryAllocator, AllocTracker *allocTracker, ResourceTracker *resourceTracker) { - for (const auto &[key, ext] : *extensionRegistry) - ext.populateLuaState(options, state, pinnedMemoryAllocator, allocTracker, - resourceTracker); + for (const auto &[key, ext] : *extensionRegistry) { + if (options.isFeatureEnabled(key)) { + MTRT_DBG("Enabling Lua runtime module: {0}", key); + ext.populateLuaState(options, state, pinnedMemoryAllocator, allocTracker, + resourceTracker); + continue; + } + MTRT_DBG("Disabling Lua runtime module: {0}", key); + } + + // Check for features that are enabled but not supported by the runtime. + for (const auto &feature : options.getEnabledFeatures()) { + if (!extensionRegistry->contains(feature.getKey())) { + return getInvalidArgStatus( + "feature {0} is enabled but not supported by the runtime", + feature.getKey()); + } + } + return getOkStatus(); } diff --git a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp index 420e0846a..d212d7996 100644 --- a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp @@ -30,7 +30,6 @@ #include "mlir-executor/Runtime/Backend/Common/CommonRuntime.h" #include "mlir-executor/Runtime/Backend/Lua/LuaExtensionRegistry.h" #include "mlir-executor/Runtime/Backend/Lua/LuaExtensions.h" -#include "mlir-executor/Runtime/Backend/Lua/LuaRegistration.h" #include "mlir-executor/Runtime/Backend/Lua/Modules/Utils/MemRefUtils.h" #include "mlir-executor/Runtime/Backend/Utils/NvtxUtils.h" #include "mlir-executor/Runtime/Support/Support.h" @@ -61,11 +60,11 @@ using namespace mlirtrt::runtime; static constexpr uint64_t kMinConstantBufferByteAlignment = 8; -#ifndef MLIR_EXECUTOR_ENABLE_NCCL -/// If the runtime is not built with MLIR_EXECUTOR_ENABLE_NCCL, then this -/// function registers default implementations for the required SPMD functions, -/// reflecting that the executable is expected to run against a single fixed -/// CUDA device and is not part of a larger device grid. +/// This function registers default implementations for the required SPMD +/// functions, reflecting that the executable is expected to run against a +/// single fixed CUDA device and is not part of a larger device grid. +/// These functions are only used if the runtime session is created with the +/// "single-device" feature enabled. static void registerDefaultDeviceDependentMethods(lua_State *state, int32_t numDevices, int32_t deviceIdx) { @@ -77,7 +76,6 @@ static void registerDefaultDeviceDependentMethods(lua_State *state, return deviceIdx; }; } -#endif // MLIR_EXECUTOR_ENABLE_NCCL namespace mlirtrt::runtime { void registerLuaCoreRuntimeExtension(); @@ -108,9 +106,12 @@ void runtime::registerLuaRuntimeExtensions() { #endif #ifdef MLIR_EXECUTOR_ENABLE_NCCL registerLuaNcclRuntimeExtension(); -#else +#endif + + // The "single-device" module provides default implementation for the SPMD + // device rank/num rank functions which just map to the one enabled device . registerLuaRuntimeExtension( - "spmd", + "single-device", LuaRuntimeExtension{ [](const RuntimeSessionOptions &options, lua_State *state, PinnedMemoryAllocator *pinnedMemoryAllocator, @@ -118,15 +119,6 @@ void runtime::registerLuaRuntimeExtensions() { registerDefaultDeviceDependentMethods( state, options.getNumDevices(), options.getDeviceId()); }}); -#endif -} - -void mlirtrt::runtime::registerLuaRuntimeMethods( - lua_State *state, const RuntimeSessionOptions &options, - PinnedMemoryAllocator *pinnedMemoryAllocator, AllocTracker *allocTracker, - ResourceTracker *resourceTracker) { - populateRuntimeExtensions(options, state, pinnedMemoryAllocator, allocTracker, - resourceTracker); } /// If the program was compiled with NCCL enabled, then check for the @@ -275,10 +267,10 @@ LuaRuntimeSession::create(RuntimeSessionOptions options, lua.open_libraries(sol::lib::base, sol::lib::string, sol::lib::coroutine); // Register builtin methods. - registerLuaRuntimeMethods(lua.lua_state(), session->getOptions(), - &session->getPinnedMemoryAllocator(), - &session->getAllocTracker(), - &session->getResourceTracker()); + MTRT_RETURN_IF_ERROR(populateRuntimeExtensions( + session->getOptions(), lua.lua_state(), + &session->getPinnedMemoryAllocator(), &session->getAllocTracker(), + &session->getResourceTracker())); // Register user-provided methods. if (registerExtraLuaFuncs) @@ -347,7 +339,7 @@ Status LuaRuntimeSession::setCudaStream(CudaStream stream) { //===----------------------------------------------------------------------===// StatusOr mlirtrt::runtime::runExecutorLuaScript( - std::string_view luaScript, + RuntimeSessionOptions options, std::string_view luaScript, LuaRuntimeSession::LuaModuleRegistrationFunc registerExtraLuaFuncs) { ADD_RUNTIME_MODULE_RANGE("runtime_runExecutorLuaScript"); @@ -355,20 +347,13 @@ StatusOr mlirtrt::runtime::runExecutorLuaScript( if (!client.isOk()) return client.getStatus(); -#ifdef MLIR_EXECUTOR_ENABLE_NCCL - StatusOr options = - RuntimeSessionOptions::createUsingSingleHostMpi(); -#else - StatusOr options = RuntimeSessionOptions(); -#endif - MTRT_ASSIGN_OR_RETURN( std::unique_ptr session, - LuaRuntimeSession::create(std::move(*options), ExecutableView(nullptr), + LuaRuntimeSession::create(std::move(options), ExecutableView(nullptr), std::move(registerExtraLuaFuncs))); sol::state_view lua = session->getLuaState(); - sol::protected_function_result result = lua.script(luaScript); + sol::protected_function_result result = lua.safe_script(luaScript); if (!result.valid()) { sol::error err = result; return getStatusWithMsg(StatusCode::InternalError, @@ -396,25 +381,16 @@ StatusOr mlirtrt::runtime::runExecutorLuaScript( } StatusOr mlirtrt::runtime::runExecutorExecutable( - std::unique_ptr executable, + RuntimeSessionOptions options, std::unique_ptr executable, LuaRuntimeSession::LuaModuleRegistrationFunc registerExtraLuaFuncs) { StatusOr> client = RuntimeClient::create(); if (!client.isOk()) return client.getStatus(); -#ifdef MLIR_EXECUTOR_ENABLE_NCCL - StatusOr options = - RuntimeSessionOptions::createUsingSingleHostMpi(); -#else - StatusOr options = RuntimeSessionOptions(); -#endif - if (!options.isOk()) - return options.getStatus(); - MTRT_ASSIGN_OR_RETURN( std::unique_ptr session, - LuaRuntimeSession::create(*options, executable->getView(), + LuaRuntimeSession::create(std::move(options), executable->getView(), std::move(registerExtraLuaFuncs))); // Call the main function, if present. @@ -658,7 +634,7 @@ getScalarValue(const sol::protected_function_result &pfr, int index, case ScalarTypeCode::i64: return std::make_unique(pfr[index].get(), code); case ScalarTypeCode::f8e4m3fn: - return std::make_unique(pfr[index].get(), code); + return std::make_unique(pfr[index].get<__nv_fp8_e4m3>(), code); case ScalarTypeCode::f16: return std::make_unique(pfr[index].get<__half>(), code); case ScalarTypeCode::bf16: @@ -667,20 +643,22 @@ getScalarValue(const sol::protected_function_result &pfr, int index, return std::make_unique(pfr[index].get(), code); case ScalarTypeCode::f64: return std::make_unique(pfr[index].get(), code); + case ScalarTypeCode::complex32: + return std::make_unique( + static_cast(static_cast(pfr[index])[1]), + static_cast(static_cast(pfr[index])[2]), code); + case ScalarTypeCode::complex64: + return std::make_unique( + static_cast(static_cast(pfr[index])[1]), + static_cast(static_cast(pfr[index])[2]), code); default: - return getInvalidArgStatus("Unsupported scalar type code: ", + return getInvalidArgStatus("Unsupported scalar type code: {0}", impl::EnumNameScalarTypeCode(code)); } } /// Parses the results of a function call, handling both scalar and MemRef /// return types. -/// -/// @param pfr The protected function result to parse. -/// @param sig The function signature view. -/// @param session Lua runtime session. -/// @param client Optional runtime client pointer. -/// @return A vector of unique pointers to RuntimeValue, or an error status. static StatusOr>> parseResults(const sol::protected_function_result &pfr, const FunctionSignatureView &sig, LuaRuntimeSession &session, diff --git a/mlir-tensorrt/executor/lib/Target/Lua/TranslateToLua.cpp b/mlir-tensorrt/executor/lib/Target/Lua/TranslateToLua.cpp index 2918cf11b..3e0eac54d 100644 --- a/mlir-tensorrt/executor/lib/Target/Lua/TranslateToLua.cpp +++ b/mlir-tensorrt/executor/lib/Target/Lua/TranslateToLua.cpp @@ -45,7 +45,8 @@ namespace { /// common components without the code duplication here. class LuaEmitter { public: - explicit LuaEmitter(MLIRContext *ctx, raw_ostream &os); + explicit LuaEmitter(MLIRContext *ctx, raw_ostream &os, + const DataLayout &dataLayout); /// Emit Lua ofr a "module-like" operation. This creates a new scope for all /// resources. It is expected that this is only used as the top-level @@ -149,6 +150,9 @@ class LuaEmitter { MLIRContext *ctx; raw_indented_ostream os; + + /// The data layout of the module. + mlir::DataLayout moduleDataLayout; }; } // namespace @@ -400,6 +404,29 @@ static LogicalResult printOperation(LuaEmitter &emitter, return success(); } +static LogicalResult printPtrToIntOp(LuaEmitter &emitter, + executor::PtrToIntOp op, + const DataLayout &dataLayout) { + if (failed(emitter.emitAssignPrefix(op))) + return failure(); + + uint64_t ptrWidth = dataLayout.getTypeSizeInBits(op.getArg().getType()); + emitter << "_ptrtoint_i" << ptrWidth << "_" << op.getType() << "(" + << emitter.getVariableName(op.getArg()) << ");\n"; + return success(); +} + +static LogicalResult printIntToPtrOp(LuaEmitter &emitter, + executor::IntToPtrOp op, + const DataLayout &dataLayout) { + if (failed(emitter.emitAssignPrefix(op))) + return failure(); + uint64_t ptrWidth = dataLayout.getTypeSizeInBits(op.getType()); + emitter << "_inttoptr_i" << ptrWidth << "_" << op.getOperand().getType() + << "(" << emitter.getVariableName(op.getArg()) << ");\n"; + return success(); +} + static LogicalResult printExecutorBinaryInfixOperation(LuaEmitter &emitter, Operation *op) { Value lhs = op->getOperand(0); @@ -499,6 +526,15 @@ static LogicalResult printOperation(LuaEmitter &emitter, executor::PrintOp op) { return success(); } +static LogicalResult printMemCpyOp(LuaEmitter &emitter, executor::MemcpyOp op) { + emitter << "executor_memcpy(" << emitter.getVariableName(op.getSrc()) << ", " + << emitter.getVariableName(op.getSrcOffset()) << ", " + << emitter.getVariableName(op.getDest()) << ", " + << emitter.getVariableName(op.getDestOffset()) << ", " + << emitter.getVariableName(op.getNumBytes()) << ");\n"; + return success(); +} + /// Emit binary arithmetic op that calls out to a special function. /// Since we don't want this function name to be calculated until all types are /// resolved, we avoid handling this in "executor expand ops" pass. Also, this @@ -674,7 +710,9 @@ static LogicalResult printOperation(LuaEmitter &emitter, executor::FuncOp op) { // LuaEmitter implementation //===----------------------------------------------------------------------===// -LuaEmitter::LuaEmitter(MLIRContext *ctx, raw_ostream &os) : ctx(ctx), os(os) { +LuaEmitter::LuaEmitter(MLIRContext *ctx, raw_ostream &os, + const DataLayout &dataLayout) + : ctx(ctx), os(os), moduleDataLayout(dataLayout) { localsInScopeCount.push(0); labelInScopeCount.push(0); globalsInScopeCount.push(0); @@ -925,6 +963,8 @@ LogicalResult LuaEmitter::emitOperation(Operation &op) { [&](auto op) { return printOperation(*this, op); }) .Case( [&](auto op) { return printOperation(*this, op); }) + .Case( + [&](auto op) { return printMemCpyOp(*this, op); }) .Case( [&](auto op) { return printOperation(*this, op); }) @@ -933,6 +973,12 @@ LogicalResult LuaEmitter::emitOperation(Operation &op) { .Case( [&](auto op) { return printOperation(*this, op); }) + .Case([&](auto op) { + return printPtrToIntOp(*this, op, moduleDataLayout); + }) + .Case([&](auto op) { + return printIntToPtrOp(*this, op, moduleDataLayout); + }) .Default([&](Operation *) { return op.emitOpError("unable to find printer for op"); }); @@ -952,7 +998,7 @@ LogicalResult LuaEmitter::emitOperation(Operation &op) { } LogicalResult mlir::translateToLua(Operation *op, raw_ostream &os) { - LuaEmitter luaEmitter(op->getContext(), os); + LuaEmitter luaEmitter(op->getContext(), os, DataLayout::closest(op)); if (isa(op)) return luaEmitter.emitOperation(*op); if (isModuleLike(*op)) diff --git a/mlir-tensorrt/executor/lib/Target/Lua/TranslateToRuntimeExecutable.cpp b/mlir-tensorrt/executor/lib/Target/Lua/TranslateToRuntimeExecutable.cpp index 7a47cf834..69eea0402 100644 --- a/mlir-tensorrt/executor/lib/Target/Lua/TranslateToRuntimeExecutable.cpp +++ b/mlir-tensorrt/executor/lib/Target/Lua/TranslateToRuntimeExecutable.cpp @@ -134,6 +134,58 @@ class FBBuilder : public fb::FlatBufferBuilder64 { } // namespace +static bool isElidedResourceElementsAttr(ElementsAttr attr) { + auto denseResourceAttr = dyn_cast(attr); + if (!denseResourceAttr) + return false; + DenseResourceElementsHandle handle = denseResourceAttr.getRawHandle(); + if (handle.getKey() != "__elided__") + return false; + return true; +} + +static FailureOr +getDenseElementsAttrOfOnes(ElementsAttr attr) { + ShapedType tensorType = cast(attr.getType()); + Type elementType = tensorType.getElementType(); + if (elementType.isInteger(1)) + return DenseElementsAttr::get(tensorType, true); + if (elementType.isInteger(8)) + return DenseElementsAttr::get(tensorType, APInt(8, 1)); + if (elementType.isInteger(16)) + return DenseElementsAttr::get(tensorType, APInt(16, 1)); + if (elementType.isInteger(32)) + return DenseElementsAttr::get(tensorType, APInt(32, 1)); + if (elementType.isInteger(64)) + return DenseElementsAttr::get(tensorType, APInt(64, 1)); + if (isa(elementType)) + return DenseElementsAttr::get(tensorType, + APFloat::getOne(APFloat::Float8E4M3FN())); + if (elementType.isF16()) + return DenseElementsAttr::get(tensorType, + APFloat::getOne(APFloat::IEEEhalf())); + if (elementType.isBF16()) + return DenseElementsAttr::get(tensorType, + APFloat::getOne(APFloat::BFloat())); + if (elementType.isF32()) + return DenseElementsAttr::get(tensorType, + APFloat::getOne(APFloat::IEEEsingle())); + if (elementType.isF64()) + return DenseElementsAttr::get(tensorType, + APFloat::getOne(APFloat::IEEEdouble())); + if (elementType == + ComplexType::get(Float32Type::get(elementType.getContext()))) { + std::complex complexOne(1.0f, 1.0f); + return DenseElementsAttr::get(tensorType, complexOne); + } + if (elementType == + ComplexType::get(Float64Type::get(elementType.getContext()))) { + std::complex complexOne(1.0, 1.0); + return DenseElementsAttr::get(tensorType, complexOne); + } + return failure(); +} + template class OffsetT, template class VectorT> FailureOr>> @@ -141,6 +193,14 @@ FBBuilder::serialize64(Location loc, const DataLayout &dataLayout, ElementsAttr attr, std::optional alignment) { FlatbufferElementsSerializer serializer(*this, dataLayout); + if (isElidedResourceElementsAttr(attr)) { + // Elided attribute can't be serialized so we create splat + // of `1`s (splat of `true` in case of boolean). + auto attrOfOnes = getDenseElementsAttrOfOnes(attr); + if (failed(attrOfOnes)) + return failure(); + attr = *attrOfOnes; + } if (failed(mlir::serializeElementsAttr(loc, attr, dataLayout, serializer, alignment))) return failure(); @@ -240,7 +300,8 @@ translateTypeVariant(FBBuilder &fbBuilder, Type t) { << "unhandled type (" << t << ") in Executor function metadata"; }; - if (!isa(t)) + if (!isa(t)) return emitTranslateFailure(t); // Encode as a memref. diff --git a/mlir-tensorrt/executor/lib/Tools/ExecutorRunnerMain.cpp b/mlir-tensorrt/executor/lib/Tools/ExecutorRunnerMain.cpp index 0b3eb521e..bbd9a596d 100644 --- a/mlir-tensorrt/executor/lib/Tools/ExecutorRunnerMain.cpp +++ b/mlir-tensorrt/executor/lib/Tools/ExecutorRunnerMain.cpp @@ -110,6 +110,10 @@ struct Options { cl::values(clEnumValN(Lua, "lua", "interpret the input as Lua code")), cl::values(clEnumValN(ExecutorRuntimeExecutable, "rtexe", "load the input file as an Executor executable"))}; + + cl::list features{ + "features", llvm::cl::list_init({"core"}), + cl::CommaSeparated, cl::desc("runtime features/modules to enable")}; }; } // namespace @@ -138,8 +142,30 @@ static LogicalResult initializeCudaRuntime() { llvm::errs() << "cudaFree failed: " << cudaGetErrorString(result); return failure(); } -#endif return success(); +#else + llvm::errs() << "runtime was not built with CUDA support\n"; + return failure(); +#endif +} + +static StatusOr +getRuntimeSessionOptions(const Options &options, + ArrayRef features) { +#ifdef MLIR_EXECUTOR_ENABLE_NCCL + if (llvm::is_contained(features, "nccl")) { + StatusOr opts = + RuntimeSessionOptions::createUsingSingleHostMpi(); + if (!opts.isOk()) + return opts.getStatus(); + opts->enableFeatures(features); + return opts; + } +#endif + + RuntimeSessionOptions opts; + opts.enableFeatures(features); + return opts; } LogicalResult executor::ExecutorRunnerMain( @@ -148,18 +174,25 @@ LogicalResult executor::ExecutorRunnerMain( registerExtraLuaFuncs) { llvm::InitLLVM initLLVM(argc, argv); - Status mpiStatus = maybeInitializeMpi(); - - if (!mpiStatus.isOk()) { - llvm::errs() << "failed to initialize MPI: " << mpiStatus.getString() - << "\n"; - return failure(); - } - // Register and parse CLI args. Options options; cl::ParseCommandLineOptions(argc, argv, "MLIR-TensorRT Runtime Interpreter"); + if (!options.dumpFunctionSignature) { + if (llvm::is_contained(options.features, "nccl")) { + Status mpiStatus = maybeInitializeMpi(); + if (!mpiStatus.isOk()) { + llvm::errs() << "failed to initialize MPI: " << mpiStatus.getString() + << "\n"; + return failure(); + } + } + + if (llvm::is_contained(options.features, "cuda") && + failed(initializeCudaRuntime())) + return failure(); + } + if (postInitCallback) postInitCallback(); @@ -194,14 +227,15 @@ LogicalResult executor::ExecutorRunnerMain( if (options.inputType != InputType::ExecutorRuntimeExecutable) return emitError(loc) << "function signature can only be dumped with " "Runtime Executable inputs"; - } else { - // We only need to initialize the CUDA runtime if we are going to run - // something. - // TODO: Allow enable/disable CUDA requirement on command-line. - if (failed(initializeCudaRuntime())) - return failure(); } + StatusOr sessionOptions = + getRuntimeSessionOptions(options, options.features); + if (!sessionOptions.isOk()) + return emitError(UnknownLoc::get(&context)) + << "failed to get runtime session options: " + << sessionOptions.getStatus().getString(); + // Read the buffer as a Lua script and execute. auto processBuffer = [&](std::unique_ptr input, llvm::raw_ostream &os) -> LogicalResult { @@ -209,8 +243,8 @@ LogicalResult executor::ExecutorRunnerMain( assert(!options.dumpFunctionSignature && "Can not dump function signature for Lua input type."); mlirtrt::StatusOr result = - mlirtrt::runtime::runExecutorLuaScript(input->getBuffer(), - registerExtraLuaFuncs); + mlirtrt::runtime::runExecutorLuaScript( + *sessionOptions, input->getBuffer(), registerExtraLuaFuncs); if (!result.isOk()) return emitError(UnknownLoc::get(&context)) << result.getString(); return success(*result == 0); @@ -256,7 +290,8 @@ LogicalResult executor::ExecutorRunnerMain( mlirtrt::StatusOr executionResult = mlirtrt::runtime::runExecutorExecutable( - std::move(*executable), std::move(registerExtraLuaFuncs)); + *sessionOptions, std::move(*executable), + std::move(registerExtraLuaFuncs)); if (!executionResult.isOk()) return emitError(UnknownLoc::get(&context)) << "failed to load and run executable: " diff --git a/mlir-tensorrt/executor/lib/Utils/CMakeLists.txt b/mlir-tensorrt/executor/lib/Utils/CMakeLists.txt index d5ced1340..aab87884e 100644 --- a/mlir-tensorrt/executor/lib/Utils/CMakeLists.txt +++ b/mlir-tensorrt/executor/lib/Utils/CMakeLists.txt @@ -1,7 +1,3 @@ -if(MLIR_EXECUTOR_ENABLE_TENSORRT) - add_subdirectory(TensorRTDynamicLoader EXCLUDE_FROM_ALL) -endif() - add_mlir_executor_library(MLIRExecutorCommonUtils PARTIAL_SOURCES_INTENDED RegionUtils.cpp diff --git a/mlir-tensorrt/executor/test/Executor/lower-builtins.mlir b/mlir-tensorrt/executor/test/Executor/lower-builtins.mlir index 7af1cc5e9..341224c23 100644 --- a/mlir-tensorrt/executor/test/Executor/lower-builtins.mlir +++ b/mlir-tensorrt/executor/test/Executor/lower-builtins.mlir @@ -1,31 +1,5 @@ // RUN: executor-opt %s -split-input-file -executor-lower-to-runtime-builtins | FileCheck %s -func.func @pointer_cast_lowering() -> (i32, i64) { - %cst0 = executor.constant 123 : i32 - %ptr0 = executor.inttoptr %cst0 : (i32) -> !executor.ptr - %int0 = executor.ptrtoint %ptr0 : (!executor.ptr) -> i32 - - %cst1 = executor.constant 456 : i64 - %ptr1 = executor.inttoptr %cst1 : (i64) -> !executor.ptr - %int1 = executor.ptrtoint %ptr1 : (!executor.ptr) -> i64 - return %int0, %int1 : i32, i64 -} -// CHECK-DAG: executor.func private @_ptrtoint_i64_i64(!executor.ptr) -> i64 -// CHECK-DAG: executor.func private @_inttoptr_i64_i64(i64) -> !executor.ptr -// CHECK-DAG: executor.func private @_ptrtoint_i64_i32(!executor.ptr) -> i32 -// CHECK-DAG: executor.func private @_inttoptr_i64_i32(i32) -> !executor.ptr -// CHECK-LABEL: func.func @pointer_cast_lowering -// CHECK-SAME: () -> (i32, i64) { -// CHECK: %[[c123_i32:.+]] = executor.constant 123 : i32 -// CHECK: %[[v0:.+]] = executor.call @_inttoptr_i64_i32(%[[c123_i32]]) : (i32) -> !executor.ptr -// CHECK: %[[v1:.+]] = executor.call @_ptrtoint_i64_i32(%[[v0]]) : (!executor.ptr) -> i32 -// CHECK: %[[c456_i64:.+]] = executor.constant 456 : i64 -// CHECK: %[[v2:.+]] = executor.call @_inttoptr_i64_i64(%[[c456_i64]]) : (i64) -> !executor.ptr -// CHECK: %[[v3:.+]] = executor.call @_ptrtoint_i64_i64(%[[v2]]) : (!executor.ptr) -> i64 -// CHECK: return %[[v1]], %[[v3]] : i32, i64 - -// ----- - func.func @alignto_lowering(%arg0: i32, %arg1: i64) -> (i32, i64) { %0 = executor.alignto %arg0, 2 : i32 %1 = executor.alignto %arg1, 4 : i64 diff --git a/mlir-tensorrt/executor/test/IntegrationTests/arithmetic.mlir b/mlir-tensorrt/executor/test/IntegrationTests/arithmetic.mlir index fdb6c1b78..c0ce5aa88 100644 --- a/mlir-tensorrt/executor/test/IntegrationTests/arithmetic.mlir +++ b/mlir-tensorrt/executor/test/IntegrationTests/arithmetic.mlir @@ -2,10 +2,10 @@ // RUN: executor-opt %s -executor-lowering-pipeline -o %t.mlir // RUN: executor-translate -mlir-to-lua %t.mlir \ -// RUN: | executor-runner -input-type=lua | FileCheck %s +// RUN: | executor-runner -input-type=lua -features=core | FileCheck %s // RUN: executor-translate -mlir-to-runtime-executable %t.mlir \ -// RUN: | executor-runner -input-type=rtexe | FileCheck %s +// RUN: | executor-runner -input-type=rtexe -features=core | FileCheck %s func.func @test_addi(%arg0: i64, %arg1: i64) { %0 = executor.addi %arg0, %arg1 : i64 @@ -1080,8 +1080,8 @@ func.func @main() -> i64 { %cn1_i8 = executor.constant -1 : i8 func.call @test_trunci_i8_to_i1(%cn1_i8) : (i8) -> () func.call @test_trunci_i8_to_i1(%c1_i8) : (i8) -> () - - + + return %ic0 : i64 diff --git a/mlir-tensorrt/executor/test/IntegrationTests/assertion.mlir b/mlir-tensorrt/executor/test/IntegrationTests/assertion.mlir index 369582ee5..f0526827d 100644 --- a/mlir-tensorrt/executor/test/IntegrationTests/assertion.mlir +++ b/mlir-tensorrt/executor/test/IntegrationTests/assertion.mlir @@ -1,6 +1,6 @@ // RUN: executor-opt %s -executor-lowering-pipeline \ // RUN: | executor-translate -mlir-to-lua \ -// RUN: | not executor-runner -input-type=lua 2>&1 | FileCheck %s +// RUN: | not executor-runner -input-type=lua -features=core 2>&1 | FileCheck %s func.func @main() -> i32 { %c0 = executor.constant 0 : i32 diff --git a/mlir-tensorrt/executor/test/IntegrationTests/complex.mlir b/mlir-tensorrt/executor/test/IntegrationTests/complex.mlir index 5ddc543fc..e9abe1e36 100644 --- a/mlir-tensorrt/executor/test/IntegrationTests/complex.mlir +++ b/mlir-tensorrt/executor/test/IntegrationTests/complex.mlir @@ -1,7 +1,7 @@ // RUN: executor-opt %s \ // RUN: -executor-lowering-pipeline \ // RUN: | executor-translate -mlir-to-runtime-executable \ - // RUN: | executor-runner -input-type=rtexe \ + // RUN: | executor-runner -input-type=rtexe -features=core \ // RUN: | FileCheck %s func.func @print_complex(%arg0: complex) { diff --git a/mlir-tensorrt/executor/test/IntegrationTests/control-flow-nested.mlir b/mlir-tensorrt/executor/test/IntegrationTests/control-flow-nested.mlir index 1ed802724..3556a2d54 100644 --- a/mlir-tensorrt/executor/test/IntegrationTests/control-flow-nested.mlir +++ b/mlir-tensorrt/executor/test/IntegrationTests/control-flow-nested.mlir @@ -1,6 +1,6 @@ // RUN: executor-opt %s -executor-lowering-pipeline \ // RUN: | executor-translate -mlir-to-lua \ -// RUN: | executor-runner -input-type=lua | FileCheck %s +// RUN: | executor-runner -input-type=lua -features=core | FileCheck %s func.func @test_for(%lb: index, %ub: index, %step: index) { %c0 = executor.constant 0 : index diff --git a/mlir-tensorrt/executor/test/IntegrationTests/control-flow.mlir b/mlir-tensorrt/executor/test/IntegrationTests/control-flow.mlir index 415284908..8e3b4d8a8 100644 --- a/mlir-tensorrt/executor/test/IntegrationTests/control-flow.mlir +++ b/mlir-tensorrt/executor/test/IntegrationTests/control-flow.mlir @@ -1,6 +1,6 @@ // RUN: executor-opt %s -executor-lowering-pipeline \ // RUN: | executor-translate -mlir-to-lua \ -// RUN: | executor-runner -input-type=lua | FileCheck %s +// RUN: | executor-runner -input-type=lua -features=core | FileCheck %s func.func @test_for(%lb: index, %ub: index, %step: index) { %c0 = executor.constant 0 : index diff --git a/mlir-tensorrt/executor/test/IntegrationTests/coroutine.mlir b/mlir-tensorrt/executor/test/IntegrationTests/coroutine.mlir index 276ccac7d..dcd02cb31 100644 --- a/mlir-tensorrt/executor/test/IntegrationTests/coroutine.mlir +++ b/mlir-tensorrt/executor/test/IntegrationTests/coroutine.mlir @@ -1,6 +1,6 @@ // RUN: executor-opt %s -split-input-file -executor-lowering-pipeline | \ // RUN: executor-translate -mlir-to-lua -split-input-file --output-split-marker="-- -----" \ -// RUN: | executor-runner -input-type=lua -split-input-file="-- -----" | FileCheck %s +// RUN: | executor-runner -features=core -input-type=lua -split-input-file="-- -----" | FileCheck %s func.func @coro(%arg0: i32, %arg1: i32) -> (i32, i32) { %start = arith.index_cast %arg0 : i32 to index diff --git a/mlir-tensorrt/executor/test/IntegrationTests/fill-device-f32.mlir b/mlir-tensorrt/executor/test/IntegrationTests/fill-device-f32.mlir index 876285bed..0ef81b43c 100644 --- a/mlir-tensorrt/executor/test/IntegrationTests/fill-device-f32.mlir +++ b/mlir-tensorrt/executor/test/IntegrationTests/fill-device-f32.mlir @@ -1,6 +1,6 @@ // RUN: executor-opt %s -inline -executor-lowering-pipeline \ // RUN: | executor-translate -mlir-to-runtime-executable \ -// RUN: | executor-runner -input-type=rtexe | FileCheck %s +// RUN: | executor-runner -input-type=rtexe -features=core,cuda | FileCheck %s !scalar_type = f32 !host_memref_type = memref<4x!scalar_type, #executor.memory_type> diff --git a/mlir-tensorrt/executor/test/IntegrationTests/fill-f32.mlir b/mlir-tensorrt/executor/test/IntegrationTests/fill-f32.mlir index 975d91ed9..ed5349260 100644 --- a/mlir-tensorrt/executor/test/IntegrationTests/fill-f32.mlir +++ b/mlir-tensorrt/executor/test/IntegrationTests/fill-f32.mlir @@ -1,6 +1,6 @@ // RUN: executor-opt %s -test-executor-bufferization-pipeline -inline -executor-lowering-pipeline \ // RUN: | executor-translate -mlir-to-runtime-executable \ -// RUN: | executor-runner -input-type=rtexe | FileCheck %s +// RUN: | executor-runner -input-type=rtexe -features=core | FileCheck %s !scalar_type = f32 diff --git a/mlir-tensorrt/executor/test/IntegrationTests/fill-i1.mlir b/mlir-tensorrt/executor/test/IntegrationTests/fill-i1.mlir index 1eb8f9423..2a0aebb3a 100644 --- a/mlir-tensorrt/executor/test/IntegrationTests/fill-i1.mlir +++ b/mlir-tensorrt/executor/test/IntegrationTests/fill-i1.mlir @@ -1,6 +1,6 @@ // RUN: executor-opt %s -test-executor-bufferization-pipeline -inline -executor-lowering-pipeline \ // RUN: | executor-translate -mlir-to-runtime-executable \ -// RUN: | executor-runner -input-type=rtexe | FileCheck %s +// RUN: | executor-runner -input-type=rtexe -features=core | FileCheck %s !scalar_type = i1 diff --git a/mlir-tensorrt/executor/test/IntegrationTests/host-buffer-c32.mlir b/mlir-tensorrt/executor/test/IntegrationTests/host-buffer-c32.mlir index 296547439..27f691629 100644 --- a/mlir-tensorrt/executor/test/IntegrationTests/host-buffer-c32.mlir +++ b/mlir-tensorrt/executor/test/IntegrationTests/host-buffer-c32.mlir @@ -1,6 +1,6 @@ // RUN: executor-opt %s -test-executor-bufferization-pipeline -inline -executor-lowering-pipeline \ // RUN: | executor-translate -mlir-to-runtime-executable \ -// RUN: | executor-runner -input-type=rtexe | FileCheck %s +// RUN: | executor-runner -input-type=rtexe -features=core | FileCheck %s !memref_type = memref<4xcomplex, strided<[?], offset: ?>, #executor.memory_type> diff --git a/mlir-tensorrt/executor/test/IntegrationTests/host-buffer-i4.mlir b/mlir-tensorrt/executor/test/IntegrationTests/host-buffer-i4.mlir index 088466c68..a105b113d 100644 --- a/mlir-tensorrt/executor/test/IntegrationTests/host-buffer-i4.mlir +++ b/mlir-tensorrt/executor/test/IntegrationTests/host-buffer-i4.mlir @@ -1,6 +1,6 @@ // RUN: executor-opt %s -test-executor-bufferization-pipeline -executor-lowering-pipeline \ // RUN: | executor-translate -mlir-to-runtime-executable \ -// RUN: | executor-runner -input-type=rtexe | FileCheck %s +// RUN: | executor-runner -input-type=rtexe -features=core | FileCheck %s !memref_type = memref<4xi4, strided<[?], offset: ?>, #executor.memory_type> diff --git a/mlir-tensorrt/executor/test/IntegrationTests/load-globals.mlir b/mlir-tensorrt/executor/test/IntegrationTests/load-globals.mlir index d067190c7..30257a3dd 100644 --- a/mlir-tensorrt/executor/test/IntegrationTests/load-globals.mlir +++ b/mlir-tensorrt/executor/test/IntegrationTests/load-globals.mlir @@ -1,7 +1,7 @@ // RUN: executor-opt %s \ // RUN: -executor-lower-to-runtime-builtins | \ // RUN: executor-translate -mlir-to-runtime-executable |\ -// RUN: executor-runner -input-type=rtexe | FileCheck %s +// RUN: executor-runner -input-type=rtexe -features=core,cuda | FileCheck %s executor.data_segment @dense_i32 constant dense<[32, 33]> : tensor<2xi32> executor.data_segment @device_i32 constant address_space dense<[99, 101]> : tensor<2xi32> diff --git a/mlir-tensorrt/executor/test/IntegrationTests/pointer-cast-ops.mlir b/mlir-tensorrt/executor/test/IntegrationTests/pointer-cast-ops.mlir index 46003575c..80ae01ed2 100644 --- a/mlir-tensorrt/executor/test/IntegrationTests/pointer-cast-ops.mlir +++ b/mlir-tensorrt/executor/test/IntegrationTests/pointer-cast-ops.mlir @@ -1,6 +1,6 @@ // RUN: executor-opt %s -executor-lower-to-runtime-builtins \ // RUN: | executor-translate -mlir-to-lua \ -// RUN: | executor-runner -input-type=lua | FileCheck %s +// RUN: | executor-runner -input-type=lua -features=core | FileCheck %s func.func @main() -> i32 { %cst0 = executor.constant 123 : i32 diff --git a/mlir-tensorrt/executor/test/IntegrationTests/ptr-to-int.mlir b/mlir-tensorrt/executor/test/IntegrationTests/ptr-to-int.mlir new file mode 100644 index 000000000..aee3a6d2f --- /dev/null +++ b/mlir-tensorrt/executor/test/IntegrationTests/ptr-to-int.mlir @@ -0,0 +1,32 @@ +// RUN: executor-opt %s -executor-lowering-pipeline | \ +// RUN: executor-translate -mlir-to-lua | \ +// RUN: executor-runner -input-type=lua | FileCheck %s + +func.func @host_ptr_to_int(%arg0: !executor.ptr) -> i64 + attributes {no_inline} { + %0 = executor.ptrtoint %arg0 : (!executor.ptr) -> i64 + return %0 : i64 +} + +func.func @device_ptr_to_int(%arg0: !executor.ptr) -> i64 + attributes {no_inline} { + %0 = executor.ptrtoint %arg0 : (!executor.ptr) -> i64 + return %0 : i64 +} + +func.func @main() -> i32 { + %c0_i32 = executor.constant 0 : i32 + %c0 = executor.constant 0 : i64 + %c1 = executor.constant 1 : i64 + %1 = executor.inttoptr %c0 : (i64) -> !executor.ptr + %2 = func.call @host_ptr_to_int(%1) : (!executor.ptr) -> i64 + executor.print "host pointer as i64 = %d"(%2 : i64) + + %3 = executor.inttoptr %c1 : (i64) -> !executor.ptr + %4 = func.call @device_ptr_to_int(%3) : (!executor.ptr) -> i64 + executor.print "device pointer as i64 = %d"(%4 : i64) + return %c0_i32 : i32 +} + +// CHECK: host pointer as i64 = 0 +// CHECK: device pointer as i64 = 1 diff --git a/mlir-tensorrt/executor/test/IntegrationTests/stream.mlir b/mlir-tensorrt/executor/test/IntegrationTests/stream.mlir index 4e89cf7a2..0444cc43e 100644 --- a/mlir-tensorrt/executor/test/IntegrationTests/stream.mlir +++ b/mlir-tensorrt/executor/test/IntegrationTests/stream.mlir @@ -1,6 +1,6 @@ // RUN: executor-opt %s -executor-lowering-pipeline \ // RUN: | executor-translate -mlir-to-lua \ -// RUN: | executor-runner -input-type=lua | FileCheck %s +// RUN: | executor-runner -input-type=lua -features=core,cuda | FileCheck %s executor.func private @__cuda_stream_create() -> (!executor.ptr) executor.func private @__cuda_stream_sync(!executor.ptr) -> () diff --git a/mlir-tensorrt/executor/test/Unit/Runtime/LuaRuntime/ExecuteFunctionWithLuaBackendTests.cpp b/mlir-tensorrt/executor/test/Unit/Runtime/LuaRuntime/ExecuteFunctionWithLuaBackendTests.cpp index 64f7d4c56..03e23e7df 100644 --- a/mlir-tensorrt/executor/test/Unit/Runtime/LuaRuntime/ExecuteFunctionWithLuaBackendTests.cpp +++ b/mlir-tensorrt/executor/test/Unit/Runtime/LuaRuntime/ExecuteFunctionWithLuaBackendTests.cpp @@ -11,6 +11,7 @@ #include "mlir-executor/InitAllDialects.h" #include "mlir-executor/Runtime/API/API.h" +#include "mlir-executor/Runtime/Backend/Lua/LuaExtensions.h" #include "mlir-executor/Runtime/Backend/Lua/LuaRuntime.h" #include "mlir-executor/Target/Lua/TranslateToRuntimeExecutable.h" #include "mlir/IR/AsmState.h" @@ -30,6 +31,7 @@ class TestRuntime : public ::testing::Test { protected: void SetUp() override { mlir::executor::registerAllRequiredDialects(registry); + mlirtrt::runtime::registerLuaRuntimeExtensions(); context = std::make_unique(registry); } @@ -44,8 +46,9 @@ class TestRuntime : public ::testing::Test { StatusOr> createLuaRuntimeSession( const std::unique_ptr &executable) { - return LuaRuntimeSession::create(RuntimeSessionOptions(), - executable->getView(), {}); + RuntimeSessionOptions options; + options.enableFeatures({"core"}); + return LuaRuntimeSession::create(options, executable->getView(), {}); } void assertScalarValuesEqual(const ScalarValue *result, @@ -95,12 +98,14 @@ TEST_F(TestRuntime, TestRuntimeExecution) { std::make_unique(std::move(*exeStorage)); auto session = createLuaRuntimeSession(executable); - ASSERT_TRUE(session.isOk()); + ASSERT_TRUE(session.isOk()) << session.getString(); - std::vector scalarValues = {{1, ScalarTypeCode::i32}, - {2, ScalarTypeCode::i32}, - {3, ScalarTypeCode::i32}, - {4, ScalarTypeCode::i32}}; + std::vector scalarValues; + scalarValues.reserve(4); + scalarValues.emplace_back(1, ScalarTypeCode::i32); + scalarValues.emplace_back(2, ScalarTypeCode::i32); + scalarValues.emplace_back(3, ScalarTypeCode::i32); + scalarValues.emplace_back(4, ScalarTypeCode::i32); llvm::SmallVector reference = { &scalarValues[2], &scalarValues[3], &scalarValues[0], &scalarValues[1]}; @@ -109,12 +114,12 @@ TEST_F(TestRuntime, TestRuntimeExecution) { llvm::SmallVector outputArgs; auto client = createRuntimeClient(); - ASSERT_TRUE(client.isOk()); + ASSERT_TRUE(client.isOk()) << client.getString(); auto result = executeFunctionWithLuaBackend( *(*session).get(), "main", inputArgs, outputArgs, std::nullopt, std::optional((*client).get())); - ASSERT_TRUE(result.isOk()); + ASSERT_TRUE(result.isOk()) << result.getString(); ASSERT_EQ((*result).size(), reference.size()) << "Vector sizes don't match"; diff --git a/mlir-tensorrt/executor/test/executor-runner/invalid.mlir b/mlir-tensorrt/executor/test/executor-runner/invalid.mlir index 6e07fadb2..cc365bf33 100644 --- a/mlir-tensorrt/executor/test/executor-runner/invalid.mlir +++ b/mlir-tensorrt/executor/test/executor-runner/invalid.mlir @@ -1,3 +1,3 @@ -// RUN: not executor-runner %s -input-type=rtexe || FileCheck %s +// RUN: not executor-runner %s -input-type=rtexe -modules=core || FileCheck %s // CHECK: error: failed to load executable from buffer: InvalidArgument: failed to verify that the provided buffer contains a valid MLIR-TRT Executable diff --git a/mlir-tensorrt/integrations/python/bindings/Runtime/RuntimePyBind.cpp b/mlir-tensorrt/integrations/python/bindings/Runtime/RuntimePyBind.cpp index eb44a9277..b82ff7dfa 100644 --- a/mlir-tensorrt/integrations/python/bindings/Runtime/RuntimePyBind.cpp +++ b/mlir-tensorrt/integrations/python/bindings/Runtime/RuntimePyBind.cpp @@ -1118,16 +1118,32 @@ PYBIND11_MODULE(_api, m) { py::class_(m, "RuntimeSessionOptions", py::module_local()) .def(py::init<>([](int32_t numDevices, int32_t deviceId, - std::string ncclUuid) -> PyRuntimeSessionOptions * { + std::string ncclUuid, + std::optional> features) + -> PyRuntimeSessionOptions * { MTRT_RuntimeSessionOptions options; MTRT_Status s = mtrtRuntimeSessionOptionsCreate( numDevices, deviceId, MTRT_StringView{ncclUuid.data(), ncclUuid.size()}, &options); THROW_IF_MTRT_ERROR(s); + + if (features) { + for (const std::string &feature : *features) + mtrtRuntimeSessionOptionsEnableFeature( + options, MTRT_StringView{feature.data(), feature.size()}); + } else { + std::array defaultFeatures = {"core", "cuda", + "tensorrt"}; + // Enable all the default features. + for (const auto &feature : defaultFeatures) + mtrtRuntimeSessionOptionsEnableFeature( + options, MTRT_StringView{feature.data(), feature.size()}); + } return new PyRuntimeSessionOptions(options); }), py::arg("num_devices") = 1, py::arg("device_id") = 0, - py::arg("nccl_uuid") = py::str("")); + py::arg("nccl_uuid") = py::str(""), + py::arg("features") = py::none()); py::class_>( m, "RuntimeSession", py::module_local()) @@ -1135,6 +1151,7 @@ PYBIND11_MODULE(_api, m) { MTRT_RuntimeSession session; MTRT_Status s = mtrtRuntimeSessionCreate(options, exe, &session); THROW_IF_MTRT_ERROR(s); + return std::make_shared(session); }), py::arg("options"), py::arg("executable")) diff --git a/mlir-tensorrt/integrations/python/mlir_tensorrt_runtime/mlir_tensorrt/runtime/_mlir_libs/_api.pyi b/mlir-tensorrt/integrations/python/mlir_tensorrt_runtime/mlir_tensorrt/runtime/_mlir_libs/_api.pyi index 4ab056088..4be1a590a 100644 --- a/mlir-tensorrt/integrations/python/mlir_tensorrt_runtime/mlir_tensorrt/runtime/_mlir_libs/_api.pyi +++ b/mlir-tensorrt/integrations/python/mlir_tensorrt_runtime/mlir_tensorrt/runtime/_mlir_libs/_api.pyi @@ -226,7 +226,11 @@ class RuntimeSession: class RuntimeSessionOptions: def __init__( - self, num_devices: int = 1, device_id: int = 0, nccl_uuid: str = "" + self, + num_devices: int = 1, + device_id: int = 0, + nccl_uuid: str = "", + features: list[str] | None = None, ) -> None: ... class RuntimeValue: diff --git a/mlir-tensorrt/integrations/python/mlir_tensorrt_tools/mlir_tensorrt/tools/gpu_tools.py b/mlir-tensorrt/integrations/python/mlir_tensorrt_tools/mlir_tensorrt/tools/gpu_tools.py index 553a72d47..ea18794b1 100644 --- a/mlir-tensorrt/integrations/python/mlir_tensorrt_tools/mlir_tensorrt/tools/gpu_tools.py +++ b/mlir-tensorrt/integrations/python/mlir_tensorrt_tools/mlir_tensorrt/tools/gpu_tools.py @@ -9,9 +9,9 @@ from contextlib import contextmanager from typing import List, Optional, Tuple +from pynvml import * import click import numpy as np -from pynvml import * def get_uniform_devices() -> List[int]: @@ -134,7 +134,28 @@ def cli(): required=False, type=click.FLOAT, ) -def pick_device(required_memory: Optional[float]): +@click.option( + "--required-host-memory", + help="causes the command to block until the specified amount of host memory (in gigabytes) is available", + required=False, + type=click.FLOAT, + default=2.0, +) +def pick_device(required_memory: Optional[float], required_host_memory: float): + try: + import psutil + + while True: + # Force to wait until at least 10GB of host memory is available. + required_host_memory = max(required_host_memory, 10.0) + gb_host_mem_avail = psutil.virtual_memory().available / (1024**3) + if gb_host_mem_avail < required_host_memory: + time.sleep(1.0) + continue + break + except: + pass + with nvml_context() as devices: if len(devices) == 0: return diff --git a/mlir-tensorrt/tensorrt/CMakeLists.txt b/mlir-tensorrt/tensorrt/CMakeLists.txt index 1f9082229..7576c0a88 100644 --- a/mlir-tensorrt/tensorrt/CMakeLists.txt +++ b/mlir-tensorrt/tensorrt/CMakeLists.txt @@ -35,6 +35,7 @@ include(cmake/TensorRTFunctions.cmake) #------------------------------------------------------------------------------- # Dependencies #------------------------------------------------------------------------------- +find_package(MLIRTensorRTCommon REQUIRED) if(NOT TARGET MLIRSupport) find_package(MLIR REQUIRED CONFIG) diff --git a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Target/TensorRTEncodingOpInterface/NetworkEncoder.h b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Target/TensorRTEncodingOpInterface/NetworkEncoder.h index 92f3bb211..76a93bcf6 100644 --- a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Target/TensorRTEncodingOpInterface/NetworkEncoder.h +++ b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Target/TensorRTEncodingOpInterface/NetworkEncoder.h @@ -29,6 +29,7 @@ #include "mlir-tensorrt-dialect/Utils/NvInferPluginUtils.h" #include "llvm/ADT/ScopedHashTable.h" #include "llvm/ADT/StringSet.h" +#include #if defined(__GNUC__) || defined(__clang__) #pragma GCC diagnostic push @@ -271,14 +272,17 @@ static nvinfer1::Dims getNvInferDims(ArrayRef arrayRef) { "input array exceeds max dims"); nvinfer1::Dims dims; dims.nbDims = arrayRef.size(); + + using NvInferDimType = std::remove_reference_t; + llvm::copy(llvm::map_range(arrayRef, - [](auto x) { - if (static_cast(x) == - ShapedType::kDynamic) + [](auto x) -> NvInferDimType { + if (ShapedType::isDynamic(x)) return -1; - return static_cast(x); + return static_cast(x); }), dims.d); + return dims; } diff --git a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td index d8c378e93..0b157b3e6 100644 --- a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td +++ b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td @@ -1117,7 +1117,7 @@ def TensorRT_LinspaceOp : TensorRT_Op<"linspace", [ }]; let arguments = (ins - Optional<1DTensorOf<[I32]>>:$shape, + Optional<1DTensorOf<[I32, I64]>>:$shape, Optional<0DTensorOf<[I32, I64, F32]>>:$start, Optional<1DTensorOf<[I32, I64, F32]>>:$step, OptionalAttr:$static_start, diff --git a/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp b/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp index b89fc5090..d59b2b15a 100644 --- a/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp +++ b/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp @@ -212,11 +212,12 @@ nvinfer1::Permutation tensorrt::getNvInferPermutation(ArrayRef array) { static std::string getUniqueName(NvInferNetworkEncoder::NamesSet &names, std::string name) { - unsigned i = 0; - while (names.contains(name)) - name = name + "_" + std::to_string(i++); - names.insert(name); - return name; + static unsigned i = 0; + std::string uniqueName = name; + while (names.contains(uniqueName)) + uniqueName = name + "_" + std::to_string(i++); + names.insert(uniqueName); + return uniqueName; } /// Print a representation of the given location to the string. Since MLIR has @@ -267,6 +268,7 @@ static std::string createName(NvInferNetworkEncoder::NamesSet &names, static constexpr size_t kLayerNameSizeLimit = 2048; if (name.size() > kLayerNameSizeLimit) name = name.substr(0, kLayerNameSizeLimit); + // TRT name does not allow nested quotations. name = llvm::join(llvm::split(name, "\""), ""); return getUniqueName(names, name); @@ -761,6 +763,16 @@ static void setProfileDimensions(nvinfer1::IOptimizationProfile *profile, getNvInferDims(maxShape)); } +using NvInferShapeValueType = std::remove_const_t().getShapeValues( + "arg0", nvinfer1::OptProfileSelector::kMIN)[0])>>; + +static NvInferShapeValueType clampToNvInferShapeValueType(int64_t x) { + return static_cast(std::max( + std::min(x, std::numeric_limits::max()), + std::numeric_limits::min())); +} + /// Add the argument and shape tensor bounds information to the optimization /// profile. static void setShapeTensorInputProfile(nvinfer1::IOptimizationProfile *profile, @@ -775,9 +787,10 @@ static void setShapeTensorInputProfile(nvinfer1::IOptimizationProfile *profile, SmallVector> shapes{minShape, optShape, maxShape}; for (auto [kind, shape] : llvm::zip(profiles, shapes)) { nvinfer1::Dims dims = getNvInferDims(shape); - SmallVector shapeValues; + SmallVector shapeValues; shapeValues.reserve(dims.nbDims); - std::copy_n(dims.d, dims.nbDims, std::back_inserter(shapeValues)); + for (int32_t i = 0, e = dims.nbDims; i < e; ++i) + shapeValues.push_back(clampToNvInferShapeValueType(dims.d[i])); profile->setShapeValues(argName.c_str(), kind, shapeValues.data(), shapeValues.size()); } diff --git a/mlir-tensorrt/tensorrt/test/lit.cfg.py b/mlir-tensorrt/tensorrt/test/lit.cfg.py index 32e80a712..b2f387d39 100644 --- a/mlir-tensorrt/tensorrt/test/lit.cfg.py +++ b/mlir-tensorrt/tensorrt/test/lit.cfg.py @@ -137,4 +137,4 @@ def estimate_paralllelism( lit_config.parallelism_groups["translation-tests"] = estimate_paralllelism( 8.0, gb_sys_mem_required=3.0 ) -lit_config.parallelism_group = None +config.parallelism_group = None diff --git a/mlir-tensorrt/third_party/torch-mlir-cmake/CMakeLists.txt b/mlir-tensorrt/third_party/torch-mlir-cmake/CMakeLists.txt new file mode 100644 index 000000000..3225d1e61 --- /dev/null +++ b/mlir-tensorrt/third_party/torch-mlir-cmake/CMakeLists.txt @@ -0,0 +1,262 @@ +#------------------------------------------------------------------------------------- +# This file is used to declare Torch-MLIR CMake targets in lieu of using the CMake +# code that comes with Torch-MLIR. +#------------------------------------------------------------------------------------- + +message(STATUS "Adding Torch-MLIR CMake targets") +message(STATUS "torch-mlir source dir: ${torch_mlir_SOURCE_DIR}") +message(STATUS "torch-mlir binary dir: ${torch_mlir_BINARY_DIR}") + +find_package(MLIR REQUIRED CONFIG) +include(TableGen) +include(AddLLVM) +include(AddMLIR) +include_directories(${LLVM_INCLUDE_DIRS}) +include_directories(${MLIR_INCLUDE_DIRS}) + +include_directories( + "${torch_mlir_SOURCE_DIR}/include" + "${torch_mlir_BINARY_DIR}/include" + ) + +set(TORCH_MLIR_TABLEGEN_FLAGS "") +set(TORCH_MLIR_ENABLE_STABLEHLO "${MLIR_TRT_ENABLE_HLO}") +set(TORCH_MLIR_ENABLE_REFBACKEND OFF) +set(TORCH_MLIR_ENABLE_TOSA OFF) +set(TORCH_MLIR_ENABLE_ONNX_C_IMPORTER OFF) + + +if(TORCH_MLIR_ENABLE_STABLEHLO) + add_definitions(-DTORCH_MLIR_ENABLE_STABLEHLO) + list(APPEND TORCH_MLIR_TABLEGEN_FLAGS "-DTORCH_MLIR_ENABLE_STABLEHLO") +endif() + +add_subdirectory( + "${torch_mlir_SOURCE_DIR}/include" + "${CMAKE_CURRENT_BINARY_DIR}/include" + ) + +set(torch_mlir_core_source_files + lib/Conversion/TorchToSCF/TorchToSCF.cpp + lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp + lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp + lib/Conversion/TorchToArith/TorchToArith.cpp + lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp + lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp + lib/Conversion/TorchOnnxToTorch/Patterns.cpp + lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp + lib/Conversion/TorchOnnxToTorch/OnnxRecurrentLayerOpExpanders.cpp + lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp + lib/Conversion/TorchOnnxToTorch/Passes.cpp + lib/Conversion/TorchOnnxToTorch/Utils.cpp + lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp + lib/Conversion/Passes.cpp + lib/Conversion/Utils/Utils.cpp + lib/Conversion/TorchToLinalg/TorchToLinalg.cpp + lib/Conversion/TorchToLinalg/Pooling.cpp + lib/Conversion/TorchToLinalg/Reduction.cpp + lib/Conversion/TorchToLinalg/Linear.cpp + lib/Conversion/TorchToLinalg/Uncategorized.cpp + lib/Conversion/TorchToLinalg/DataMovement.cpp + lib/Conversion/TorchToLinalg/TensorConstructors.cpp + lib/Conversion/TorchToLinalg/Utils.cpp + lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp + lib/Conversion/TorchToLinalg/Random.cpp + lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp + lib/Conversion/TorchToTensor/TorchToTensor.cpp + lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp + lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp + lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp + lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp + lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp + lib/Dialect/TorchConversion/Transforms/Passes.cpp + lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp + lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp + lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp + lib/Dialect/Torch/IR/TorchOps.cpp + lib/Dialect/Torch/IR/TorchTypes.cpp + lib/Dialect/Torch/IR/TorchDialect.cpp + lib/Dialect/Torch/IR/UtilsForODSGenerated.cpp + lib/Dialect/Torch/IR/TorchOpsODSGenerated.cpp + lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp + lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp + lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp + lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp + lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp + lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp + lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp + lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp + lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp + lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp + lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp + lib/Dialect/Torch/Transforms/DropAbstractInterpCalculations.cpp + lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp + lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp + lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp + lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp + lib/Dialect/Torch/Transforms/Passes.cpp + lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp + lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp + lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp + lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp + lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp + lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp + lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp + lib/Dialect/Torch/Utils/SparsityUtils.cpp + lib/Dialect/Torch/Utils/TorchUpstream.cpp + lib/Dialect/Torch/Utils/Utils.cpp + lib/Dialect/TMTensor/IR/TMTensorInterfaces.cpp + lib/Dialect/TMTensor/IR/TMTensorOps.cpp + lib/Dialect/TMTensor/IR/TMTensorDialect.cpp + lib/Dialect/TMTensor/IR/ScalarLoopOpInterface.cpp + lib/Dialect/TMTensor/Transforms/Bufferize.cpp + lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp + lib/Dialect/TMTensor/Transforms/Passes.cpp +) + +if(TORCH_MLIR_ENABLE_STABLEHLO) + list(APPEND torch_mlir_core_source_files + lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp + lib/Conversion/TorchToStablehlo/ViewLike.cpp + lib/Conversion/TorchToStablehlo/Pooling.cpp + lib/Conversion/TorchToStablehlo/Rng.cpp + lib/Conversion/TorchToStablehlo/Reduction.cpp + lib/Conversion/TorchToStablehlo/Linear.cpp + lib/Conversion/TorchToStablehlo/Basic.cpp + lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp + lib/Conversion/TorchToStablehlo/GatherScatter.cpp + lib/Conversion/TorchToStablehlo/Utils.cpp + lib/Conversion/TorchToStablehlo/Uncategorized.cpp + lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp + ) +endif() + +list(TRANSFORM torch_mlir_core_source_files PREPEND ${torch_mlir_SOURCE_DIR}/) + +set(torch_mlir_optional_deps) +if(TORCH_MLIR_ENABLE_STABLEHLO) + list(APPEND torch_mlir_optional_deps + ChloOps + StablehloLinalgTransforms + StablehloOps + StablehloOptimizationPasses + StablehloPasses + ) +endif() + +add_mlir_tensorrt_library(MLIRTensorRTTorchMLIR + ${torch_mlir_core_source_files} + + PARTIAL_SOURCES_INTENDED + + DEPENDS + MLIRTorchConversionOpsIncGen + MLIRTorchOpsIncGen + MLIRTorchTypesIncGen + TorchMLIRConversionPassIncGen + TorchMLIRConversionTorchOnnxToTorchPassIncGen + TorchMLIRTMTensorOpsIncGen + TorchMLIRTMTensorTransformsPassesIncGen + TorchMLIRTorchConversionPassIncGen + TorchMLIRTorchPassIncGen + + LINK_LIBS PUBLIC + ${torch_mlir_optional_deps} + MLIRArithDialect + MLIRFuncInlinerExtension + MLIRFuncTransforms + MLIRIR + MLIRLinalgDialect + MLIRMemRefDialect + MLIRMemRefTransforms + MLIRMLProgramDialect + MLIRSCFDialect + MLIRTensorDialect + MLIRTensorInferTypeOpInterfaceImpl +) + +target_include_directories( + obj.MLIRTensorRTTorchMLIR + INTERFACE + $ + $ +) + +target_include_directories( + MLIRTensorRTTorchMLIR + INTERFACE + $ + $ +) + + +add_mlir_public_c_api_library(TorchMLIRCAPI + ${torch_mlir_SOURCE_DIR}/lib/CAPI/Dialects.cpp + ${torch_mlir_SOURCE_DIR}/lib/CAPI/TorchOps.cpp + ${torch_mlir_SOURCE_DIR}/lib/CAPI/TorchTypes.cpp + ${torch_mlir_SOURCE_DIR}/lib/CAPI/Transforms.cpp + + PARTIAL_SOURCES_INTENDED + + ADDITIONAL_HEADER_DIRS + ${torch_mlir_SOURCE_DIR}/include/torch-mlir-c/ + + ENABLE_AGGREGATION + + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRIR + MLIRSupport + MLIRTensorRTTorchMLIR +) + +declare_mlir_python_sources(TorchMLIRPythonSources) +declare_mlir_python_sources(TorchMLIRPythonExtensions) + +set(TORCH_MLIR_PYTHON_ROOT_DIR "${torch_mlir_SOURCE_DIR}/python/torch_mlir") + +declare_mlir_python_sources(TorchMLIRPythonSources.Dialects + ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" + ADD_TO_PARENT TorchMLIRPythonSources +) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT TorchMLIRPythonSources.Dialects + ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" + TD_FILE dialects/TorchBinding.td + SOURCES dialects/torch/__init__.py + DIALECT_NAME torch +) + +declare_mlir_python_sources(TorchMLIRPythonSources.Importers + ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" + ADD_TO_PARENT TorchMLIRPythonSources + SOURCES + extras/fx_importer.py +) + +declare_mlir_python_sources(TorchMLIRPythonSources.PublicAPI + ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" + ADD_TO_PARENT TorchMLIRPythonSources + SOURCES + compiler_utils.py + fx.py + extras/fx_decomp_util.py +) + +declare_mlir_python_sources(TorchMLIRPythonSources.Tools + ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" + ADD_TO_PARENT TorchMLIRPythonSources + SOURCES + tools/opt/__main__.py +) + +declare_mlir_python_extension(TorchMLIRPythonExtensions.Main + MODULE_NAME _torchMlir + ADD_TO_PARENT TorchMLIRPythonExtensions + SOURCES TorchMLIRModule.cpp + EMBED_CAPI_LINK_LIBS + TorchMLIRCAPI + PRIVATE_LINK_LIBS + LLVMSupport +) diff --git a/mlir-tensorrt/third_party/torch-mlir-cmake/TorchMLIRModule.cpp b/mlir-tensorrt/third_party/torch-mlir-cmake/TorchMLIRModule.cpp new file mode 100644 index 000000000..d83b40326 --- /dev/null +++ b/mlir-tensorrt/third_party/torch-mlir-cmake/TorchMLIRModule.cpp @@ -0,0 +1,33 @@ +//===-- TorchBind.td - Torch dialect bind ------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "torch-mlir-c/Dialects.h" +#include "torch-mlir-c/Registration.h" + +namespace py = pybind11; + +PYBIND11_MODULE(_torchMlir, m) { + m.doc() = "torch-mlir main python extension"; + + m.def( + "register_dialect", + [](MlirContext context, bool load) { + MlirDialectHandle handle = mlirGetDialectHandle__torch__(); + mlirDialectHandleRegisterDialect(handle, context); + if (load) { + mlirDialectHandleLoadDialect(handle, context); + } + }, + py::arg("context"), py::arg("load") = true); + + m.def("get_int64_max", []() { return INT64_MAX; }); + + m.def("get_int64_min", []() { return INT64_MIN; }); +}