Skip to content

Commit 70582e0

Browse files
[mlir-tensorrt] Integrate internal changes
--- [compiler] Add cuda.get_program_device op Introduce `cuda.get_program_device` as a pure/speculatable way to map a program logical device id (i32) to a CUDA device ordinal (i32). GitOrigin-RevId: 00512cc5a9e9c61023e1d9de734b2383da369bcf --- [compiler] Refactor device management and stream creation utilities This commit introduces a new device management model to support multi-device SPMD and MPMD programs and refactors stream creation to use reusable utility functions. The primary motivation is to enable more flexible device assignment where programs can be assigned to specific CUDA ordinals via logical device IDs, laying the groundwork for better multi-device support. GitOrigin-RevId: 447b72743e64f394671f866fcdfdb0d6f0f3d579 ---[compiler|executor] Refactor plugin call stream handling This change refactors how CUDA streams are handled for plugin calls in the executor dialect. Previously, when no stream was provided to a CallPluginOp, the lowering would create and use a global CUDA stream (stream0). This approach had several issues: 1. It tightly coupled the executor dialect to CUDA-specific stream creation 2. It required maintaining global stream state across compilation 3. It made the stream handling implicit and harder to reason about The new approach uses null streams (nullptr) when no explicit stream is provided. This is the standard CUDA convention where a null stream represents the default stream. The changes include: - Modified `executor.call_plugin` op to accept any type for the stream operand (not just `!executor.ptr<host>`), allowing frontend dialects to pass their own stream representations (e.g. 
`!cuda.stream`) - Updated the assembly format to print the stream type for clarity - Removed `getGlobalCudaStream` helper method from ConvertToExecutorPattern - Changed CallPluginConversionPattern to create a null pointer (inttoptr 0) when no stream is provided instead of creating a global stream - Updated StablehloToPlan conversion to use `cuda::getOrCreateDefaultStream0` to explicitly create CUDA streams when converting TVM FFI custom calls - Added CUDADialect dependency to StablehloToPlan pass and CMakeLists This makes stream handling more explicit and flexible, allowing different frontend dialects to manage their own stream creation while falling back to null streams (CUDA default stream) when appropriate. GitOrigin-RevId: 764238bc58308d5d284f8e32da91c7e5f90fdf0c
1 parent ef65735 commit 70582e0

File tree

71 files changed

+774
-433
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

71 files changed

+774
-433
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//===- PassManagerUtils.h --------------------------------------*- C++ -*-===//
2+
//
3+
// SPDX-FileCopyrightText: Copyright 2026 NVIDIA CORPORATION & AFFILIATES.
4+
// All rights reserved.
5+
// SPDX-License-Identifier: Apache-2.0
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===----------------------------------------------------------------------===//
20+
///
21+
/// Utilities for mlir::OpPassManager.
22+
///
23+
//===----------------------------------------------------------------------===//
24+
#ifndef MLIR_TENSORRT_COMMON_UTILS_PASSMANAGERUTILS
25+
#define MLIR_TENSORRT_COMMON_UTILS_PASSMANAGERUTILS
26+
27+
#include "mlir/Pass/PassManager.h"
28+
29+
namespace mlir {
30+
31+
/// Add nested passes to the given pass manager for the given operation type.
32+
template <typename OpT>
33+
static void
34+
addNestedPasses(OpPassManager &pm,
35+
llvm::function_ref<void(OpPassManager &)> addPasses) {
36+
auto &nestedPM = pm.nest<OpT>();
37+
addPasses(nestedPM);
38+
}
39+
40+
} // namespace mlir
41+
42+
#endif // MLIR_TENSORRT_COMMON_UTILS_PASSMANAGERUTILS

mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/Passes.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ def ConvertStablehloToPlanPass : Pass<"convert-stablehlo-to-plan", "::mlir::Modu
4444
"::mlir::cf::ControlFlowDialect",
4545
"::mlir::plan::PlanDialect",
4646
"::mlir::executor::ExecutorDialect",
47-
"::mlir::tensor::TensorDialect"
47+
"::mlir::tensor::TensorDialect",
48+
"::mlir::cuda::CUDADialect"
4849
];
4950
}
5051

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/CUDA/IR/CUDAOps.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,23 @@ def CUDA_GetActiveDeviceOp : CUDA_Op<"get_active_device", [
9393
let assemblyFormat = "attr-dict";
9494
}
9595

96+
def CUDA_GetProgramDeviceOp : CUDA_Op<"get_program_device", [
97+
Pure, AlwaysSpeculatable]> {
98+
let summary = "returns the CUDA device ordinal associated with a program logical device";
99+
let description = [{
100+
Returns the CUDA device ordinal for the given program "logical device"
101+
identifier.
102+
103+
This operation is intended to support compilation modes where device
104+
selection is modeled explicitly and provided by the runtime when the
105+
program is loaded (e.g. via a constant mapping table). As a result, this
106+
operation is marked as being speculatable and side-effect free.
107+
}];
108+
let arguments = (ins I32:$logicalDevice);
109+
let results = (outs I32:$result);
110+
let assemblyFormat = "attr-dict $logicalDevice `:` type($result)";
111+
}
112+
96113
def CUDA_SetActiveDeviceOp : CUDA_Op<"set_active_device", [
97114
MemoryEffects<[MemWrite]>]> {
98115
let summary = "sets the active CUDA device context";
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
//===- CUDAUtils.h ----------------------------------------------*- C++ -*-===//
2+
//
3+
// SPDX-FileCopyrightText: Copyright 2026 NVIDIA CORPORATION & AFFILIATES.
4+
// All rights reserved.
5+
// SPDX-License-Identifier: Apache-2.0
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===----------------------------------------------------------------------===//
20+
///
21+
/// Utility functions for the CUDA dialect.
22+
///
23+
//===----------------------------------------------------------------------===//
24+
#ifndef MLIR_TENSORRT_DIALECT_CUDA_UTILS_CUDAUTILS_H
25+
#define MLIR_TENSORRT_DIALECT_CUDA_UTILS_CUDAUTILS_H
26+
27+
#include "mlir/IR/Block.h"
28+
namespace mlir {
29+
class Operation;
30+
class Value;
31+
class Location;
32+
class RewriterBase;
33+
class PatternRewriter;
34+
35+
namespace cuda {
36+
37+
/// Create a default stream (stream 0) on device 0. This creates:
38+
/// - A constant 0 for the device index
39+
/// - A cuda.get_program_device operation
40+
/// - A cuda.stream.create operation with index 0
41+
Value createDefaultStream0(RewriterBase &rewriter, Location loc);
42+
43+
/// Go over the operations in Block (containing `anchor`) from the first operation
44+
/// in the Block to the point before `anchor`. If we find a `cuda.stream.create`
45+
/// operation matching the pattern produced by `createDefaultStream0`, return
46+
/// the result of that operation. Otherwise, call createDefaultStream0 to create
47+
/// a new stream at the beginning of the block.
48+
Value getOrCreateDefaultStream0(RewriterBase &rewriter, Operation *anchor);
49+
50+
/// Go over the operations in Block (containing `anchorPoint`) from the first
51+
/// operation in the Block to the point before `anchorPoint`. If we find a
52+
/// `cuda.stream.create` operation matching the pattern produced by
53+
/// `createDefaultStream0`, return the result of that operation. Otherwise, call
54+
/// createDefaultStream0 to create a new stream at the beginning of the block.
55+
Value getOrCreateDefaultStream0(RewriterBase &rewriter, Location loc,
56+
Block::iterator anchorPoint);
57+
58+
} // namespace cuda
59+
} // namespace mlir
60+
61+
#endif // MLIR_TENSORRT_DIALECT_CUDA_UTILS_CUDAUTILS_H

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/Transforms/Passes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ inline llvm::cl::ValuesClass createInputKindClOptions() {
4343
clEnumValN(InputKind::TensorRT, "tensorrt", "TensorRT IR"),
4444
clEnumValN(InputKind::Linalg, "linalg", "Linalg IR"));
4545
}
46+
4647
} // namespace detail
4748

4849
struct ClusterTargetOption;
@@ -83,7 +84,6 @@ struct PlanClusteringOptions : public mlir::OptionsGroup {
8384
void buildPlanSegmentationPipeline(OpPassManager &pm, int abiVersion,
8485
plan::InputKind inputKind,
8586
bool entrypointUsesAllocCConv,
86-
llvm::StringRef entrypoint,
8787
const plan::PlanClusteringOptions &opts);
8888

8989
struct PlanBufferizationOptions {

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/Transforms/Passes.td

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -327,19 +327,19 @@ def ClusteringPass : Pass<"plan-clustering", "::mlir::ModuleOp"> {
327327
is to achieve a coarse segmentation that specifies how clusters of
328328
operations will be compiled.
329329

330+
The pass processes all functions in the module that:
331+
- Are not declarations or external functions
332+
- Do not already have a `plan.cluster_kind` attribute
333+
- Are not private functions with `plan.decomposition` attribute
334+
330335
The kinds of clusters that can be formed and the specific rules for
331336
clustering are defined by the clustering configuration specified
332337
by the module's `plan.backends` attribute. This is an array of
333338
attributes which all implement the
334339
[CompilerBackendAttrInterface](../IR/PlanInterfaces.td).
335340
}];
336341

337-
let options =
338-
[Option<"entrypoint", "entrypoint", "std::string", "\"\"",
339-
"the name of the entrypoint function; if empty then the "
340-
"clustering runs"
341-
" on all functions">,
342-
InputKindOption];
342+
let options = [InputKindOption];
343343
}
344344

345345
//===----------------------------------------------------------------------===//

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/StablehloExt/Transforms/Passes.td

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ def RefineShapesPass : Pass<"stablehlo-ext-refine-shapes", "ModuleOp"> {
106106
`stablehlo-refine-shapes` patterns as well as some additional patterns
107107
for handling `tensor.cast` operations.
108108
}];
109+
110+
let options = [
111+
Option<"interprocedural", "interprocedural", "bool", "true",
112+
"whether to try to simplify function types">
113+
];
109114
}
110115

111116
//===----------------------------------------------------------------------===//
@@ -126,7 +131,9 @@ def CanonicalizeShapesPass : Pass<"stablehlo-ext-canonicalize-shapes", "ModuleOp
126131
let options = [
127132
Option<"maxIterations", "max-iterations", "int64_t", "8",
128133
"the maximum number of iterations to run the dynamism simplification and "
129-
"shape refinement if a fixed-point is not reached">
134+
"shape refinement if a fixed-point is not reached">,
135+
Option<"interprocedural", "interprocedural", "bool", "true",
136+
"whether to try to simplify function types">
130137
];
131138
}
132139

mlir-tensorrt/compiler/lib/Compiler/Extensions/KernelGenExtension/KernelGenExtension.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,6 @@ void KernelGenExtension::populatePasses(mlir::OpPassManager &pm,
214214
pm.addPass(createConvertKernelToCUDAPass());
215215
return;
216216
}
217-
218-
if (point == ExtensionPoint::ExecutorLowering) {
219-
return;
220-
}
221217
}
222218

223219
//===----------------------------------------------------------------------===//
Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//===- LinalgInputPipeline.cpp --------------------------------------------===//
22
//
3-
// SPDX-FileCopyrightText: Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES.
3+
// SPDX-FileCopyrightText: Copyright 2024-2026 NVIDIA CORPORATION & AFFILIATES.
44
// All rights reserved.
55
// SPDX-License-Identifier: Apache-2.0
66
//
@@ -10,11 +10,10 @@
1010
///
1111
//===----------------------------------------------------------------------===//
1212
#include "mlir-tensorrt/Compiler/InputPipelines/LinalgInputPipeline.h"
13+
#include "mlir-tensorrt-common/Utils/PassManagerUtils.h"
1314
#include "mlir-tensorrt/Transforms/Passes.h"
1415
#include "mlir/Dialect/Func/IR/FuncOps.h"
1516
#include "mlir/Dialect/Linalg/Passes.h"
16-
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17-
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
1817
#include "mlir/Pass/PassManager.h"
1918
#include "mlir/Pass/PassOptions.h"
2019
#include "mlir/Transforms/Passes.h"
@@ -28,12 +27,13 @@ llvm::cl::OptionCategory LinalgInputOptions::category = {
2827

2928
void mtrt::compiler::buildLinalgInputPipeline(OpPassManager &pm,
3029
const LinalgInputOptions &opts) {
31-
OpPassManager &funcPM = pm.nest<func::FuncOp>();
32-
funcPM.addPass(mlir::createLinalgGeneralizeNamedOpsPass());
33-
if (opts.enableLinalgElementwiseFusion)
34-
funcPM.addPass(mtrt::createLinalgElementwiseFusionPass());
35-
funcPM.addPass(mtrt::createLinalgSimplifyExtractSlicePass());
36-
funcPM.addPass(mtrt::createTensorExtPadToInsertSlicePass());
37-
funcPM.addPass(mlir::createCSEPass());
38-
funcPM.addPass(mlir::createCanonicalizerPass());
30+
addNestedPasses<func::FuncOp>(pm, [&opts](OpPassManager &funcPM) {
31+
funcPM.addPass(mlir::createLinalgGeneralizeNamedOpsPass());
32+
if (opts.enableLinalgElementwiseFusion)
33+
funcPM.addPass(mtrt::createLinalgElementwiseFusionPass());
34+
funcPM.addPass(mtrt::createLinalgSimplifyExtractSlicePass());
35+
funcPM.addPass(mtrt::createTensorExtPadToInsertSlicePass());
36+
funcPM.addPass(mlir::createCSEPass());
37+
funcPM.addPass(mlir::createCanonicalizerPass());
38+
});
3939
}

0 commit comments

Comments
 (0)