Skip to content

Commit 13a941a

Browse files
ezhulenev authored and copybara-github committed
[xla:gpu] Add support for capturing lmhlo_gpu.conv operations into cuda graphs
PiperOrigin-RevId: 501950550
1 parent 272ce1a commit 13a941a

File tree

4 files changed

+30
-12
lines changed

4 files changed

+30
-12
lines changed

xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ limitations under the License.
3737
#include "xla/mlir/runtime/ir/rt_dialect.h"
3838
#include "xla/mlir/runtime/ir/rt_ops.h"
3939
#include "xla/mlir/runtime/utils/custom_calls.h"
40+
#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h"
4041

4142
namespace xla {
4243
namespace gpu {
@@ -80,23 +81,29 @@ using CaptureSequence =
8081

8182
//===----------------------------------------------------------------------===//
8283

83-
template <typename T, OpCapturePattern::Capture capture>
84+
template <OpCapturePattern::Capture capture, typename T, typename... Ts>
8485
struct OpCapture : public OpCapturePattern {
8586
FailureOr<OpCapturePattern::Capture> match(Operation* op) final {
86-
if (isa<T>(op)) return capture;
87+
if (isa<T, Ts...>(op)) return capture;
8788
return failure();
8889
}
8990
};
9091

9192
static constexpr auto kMove = OpCapturePattern::Capture::kMove;
9293
static constexpr auto kClone = OpCapturePattern::Capture::kClone;
9394

95+
template <typename T, typename... Ts>
96+
using MoveOp = OpCapture<kMove, T, Ts...>;
97+
template <typename T, typename... Ts>
98+
using CloneOp = OpCapture<kClone, T, Ts...>;
99+
94100
// Capture gpu operations by moving them into graph capture function.
95-
struct LaunchFuncOpCapture : public OpCapture<LaunchFuncOp, kMove> {};
101+
struct LaunchFuncOpCapture : public MoveOp<LaunchFuncOp> {};
102+
struct ConvOpCapture : public MoveOp<lmhlo_gpu::ConvForwardFusedOp> {};
96103

97104
// Capture pure operations by cloning them into graph capture function.
98-
struct ConstantOpCapture : public OpCapture<arith::ConstantOp, kClone> {};
99-
struct ViewOpCapture : public OpCapture<memref::ViewOp, kClone> {};
105+
struct ConstantOpCapture : public CloneOp<arith::ConstantOp> {};
106+
struct ViewOpCapture : public CloneOp<memref::ViewOp> {};
100107

101108
//===----------------------------------------------------------------------===//
102109

@@ -320,6 +327,7 @@ void OutlineCudaGraphsPass::runOnOperation() {
320327

321328
OpCapturePatternSet patterns;
322329
patterns.emplace_back(new LaunchFuncOpCapture());
330+
patterns.emplace_back(new ConvOpCapture());
323331
patterns.emplace_back(new ConstantOpCapture());
324332
patterns.emplace_back(new ViewOpCapture());
325333

xla/mlir/backends/gpu/transforms/passes.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,22 @@ void populateXlaGpuRuntimePasses(mlir::OpPassManager& pm,
3232
// Lower operations with registered IR emitters to Gpu launches.
3333
pm.addPass(createConvertLmhloToGpuLaunchPass(thunk_sequence));
3434

35+
// Clean up IR before converting it to the runtime operations.
36+
pm.addPass(createCSEPass());
37+
pm.addPass(createCanonicalizerPass());
38+
3539
// Convert global memrefs corresponding to constant arguments.
3640
pm.addPass(createConvertMemrefGetGlobalToArgPass());
3741
pm.addPass(createSymbolDCEPass()); // Clean up unused global constants.
3842

39-
// Lower all Gpu operations to the XLA Gpu runtime custom calls.
40-
pm.addPass(createConvertLmhloGpuToGpuRuntimePass());
41-
pm.addPass(createConvertLmhloToGpuRuntimePass());
42-
43+
// Outline CUDA-Graph-compatible operations into graph capture functions.
4344
if (opts.enable_cuda_graphs) {
4445
pm.addPass(createOutlineCudaGraphsPass());
4546
}
4647

48+
// Lower all Gpu operations to the XLA Gpu runtime custom calls.
49+
pm.addPass(createConvertLmhloGpuToGpuRuntimePass());
50+
pm.addPass(createConvertLmhloToGpuRuntimePass());
4751
pm.addPass(createConvertGpuToGpuRuntimePass());
4852

4953
// Add performance tracing annotations.

xla/service/gpu/runtime/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ cc_library(
165165
srcs = ["graph_launch.cc"],
166166
hdrs = ["graph_launch.h"],
167167
deps = [
168+
":conv",
168169
":kernel_launch",
169170
":support",
170171
"//xla:types",

xla/service/gpu/runtime/graph_launch.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ limitations under the License.
2525

2626
#include "xla/runtime/custom_call.h"
2727
#include "xla/runtime/executable.h"
28+
#include "xla/service/gpu/runtime/conv.h"
2829
#include "xla/service/gpu/runtime/kernel_launch.h"
2930
#include "xla/service/gpu/runtime/support.h"
3031
#include "xla/service/service_executable_run_options.h"
@@ -204,9 +205,11 @@ static absl::StatusOr<OwnedGraph> CaptureGraph(
204205
//===----------------------------------------------------------------------===//
205206

206207
static absl::Status LaunchGraph(
207-
const ServiceExecutableRunOptions* run_options, const std::string* ptx,
208+
const ServiceExecutableRunOptions* run_options,
209+
const DebugOptions* debug_options, const std::string* ptx,
208210
const std::vector<uint8_t>* cubin, se::DeviceMemoryBase* temp_buffer,
209211
StreamExecutorKernels::Snapshot* kernels,
212+
StreamExecutorConvRunners::Snapshot* convs,
210213
GraphInstances::Snapshot* instances, runtime::Executable* executable,
211214
CustomCall::RemainingArgs fwd_args, CustomCall::FunctionOrdinal capture) {
212215
#if GOOGLE_CUDA
@@ -220,8 +223,8 @@ static absl::Status LaunchGraph(
220223

221224
// Forwards user data required for launching kernels.
222225
auto user_data = [&] {
223-
return CustomCall::UserData(run_options, ptx, cubin, temp_buffer, kernels,
224-
executable);
226+
return CustomCall::UserData(run_options, debug_options, ptx, cubin,
227+
temp_buffer, kernels, convs, executable);
225228
};
226229

227230
absl::StatusOr<GraphInstance*> instance = instances->GetOrCreate(
@@ -294,10 +297,12 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL(
294297
Launch, FunctionWrapper<LaunchGraph>(), checks,
295298
CustomCall::Bind("xla.gpu.cuda.graph.launch")
296299
.UserData<const ServiceExecutableRunOptions*>()
300+
.UserData<const DebugOptions*>()
297301
.UserData<const std::string*>()
298302
.UserData<const std::vector<uint8_t>*>()
299303
.UserData<se::DeviceMemoryBase*>()
300304
.UserData<StreamExecutorKernels::Snapshot*>()
305+
.UserData<StreamExecutorConvRunners::Snapshot*>()
301306
.UserData<GraphInstances::Snapshot*>()
302307
.UserData<Executable*>()
303308
.RemainingArgs()

0 commit comments

Comments
 (0)