Skip to content

Commit 6e762fe

Browse files
authored
[CUDA] Migrate conv code to new cuDNN APIs (#2847)
1 parent 2b95d0c commit 6e762fe

File tree

8 files changed

+345
-608
lines changed

8 files changed

+345
-608
lines changed

mlx/backend/cuda/conv.cpp

Lines changed: 89 additions & 103 deletions
Original file line number | Diff line number | Diff line change
@@ -15,19 +15,16 @@ namespace mlx::core {
1515

1616
namespace {
1717

18-
// Alias for better readability.
19-
#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
20-
#define CONV_BACKWARD_INPUT \
21-
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
22-
#define CONV_BACKWARD_WEIGHT \
23-
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
24-
25-
// Custom placeholder representing fallback kernel.
26-
#define CONV_FALLBACK static_cast<cudnnBackendDescriptorType_t>(-1)
18+
enum ConvBackendType {
19+
CONV_FALLBACK,
20+
CONV_FORWARD,
21+
CONV_BACKWARD_INPUT,
22+
CONV_BACKWARD_WEIGHT,
23+
};
2724

2825
struct ConvCacheKey {
2926
int device_id;
30-
cudnnDataType_t cudnn_dtype;
27+
fe::DataType_t cudnn_dtype;
3128
std::array<int, MAX_NDIM> input_shape;
3229
std::array<int, MAX_NDIM> weight_shape;
3330
std::array<int, MAX_NDIM> stride;
@@ -44,15 +41,13 @@ struct ConvCacheKey {
4441
auto& conv_cache() {
4542
static LRUBytesKeyCache<
4643
ConvCacheKey,
47-
std::pair<
48-
cudnnBackendDescriptorType_t,
49-
std::optional<cudnn_frontend::ExecutionPlan>>>
44+
std::pair<ConvBackendType, std::optional<DnnGraph>>>
5045
cache("MLX_CUDA_CONV_CACHE_SIZE", /* default_capacity */ 128);
5146
return cache;
5247
}
5348

54-
auto get_conv_op_settings(
55-
cudnnBackendDescriptorType_t backend_type,
49+
auto get_conv_settings(
50+
ConvBackendType backend_type,
5651
array& x,
5752
array& w,
5853
array& y,
@@ -68,8 +63,8 @@ auto get_conv_op_settings(
6863
for (int i = 0; i < padding_lo.size(); ++i) {
6964
int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1);
7065
padding_lo[i] = wt_size - padding_lo[i] - 1;
71-
int in_size = 1 + kernel_strides[i] * (x.shape(1 + i) - 1);
72-
int out_size = 1 + input_dilation[i] * (y.shape(1 + i) - 1);
66+
int in_size = 1 + kernel_strides[i] * (y.shape(1 + i) - 1);
67+
int out_size = 1 + input_dilation[i] * (x.shape(1 + i) - 1);
7368
padding_hi[i] = out_size - in_size + padding_hi[i];
7469
}
7570
return std::make_tuple(
@@ -95,49 +90,57 @@ auto get_conv_op_settings(
9590
}
9691
}
9792

98-
std::optional<cudnn_frontend::OperationGraph> build_conv_op_graph(
93+
std::optional<DnnGraph> build_conv_graph(
9994
cu::CommandEncoder& encoder,
100-
cudnnBackendDescriptorType_t backend_type,
95+
ConvBackendType backend_type,
10196
Dtype dtype,
10297
array& x,
10398
array& w,
10499
array& y,
105-
const SmallVector<int64_t>& stride,
106-
const SmallVector<int64_t>& padding_lo,
107-
const SmallVector<int64_t>& padding_hi,
108-
const SmallVector<int64_t>& dilation) {
109-
try {
110-
auto compute_dtype = (dtype == float16 || dtype == bfloat16)
111-
? CUDNN_DATA_FLOAT
112-
: dtype_to_cudnn_type(dtype);
113-
auto conv_desc = cudnn_frontend::ConvDescBuilder()
114-
.setDataType(compute_dtype)
115-
.setMathMode(CUDNN_CROSS_CORRELATION)
116-
.setNDims(stride.size())
117-
.setStrides(stride.size(), stride.data())
118-
.setPrePadding(padding_lo.size(), padding_lo.data())
119-
.setPostPadding(padding_hi.size(), padding_hi.data())
120-
.setDilation(dilation.size(), dilation.data())
121-
.build();
122-
123-
auto op = cudnn_frontend::OperationBuilder(backend_type)
124-
.setxDesc(build_cudnn_tensor_nchw('x', x))
125-
.setwDesc(build_cudnn_tensor_nchw('w', w))
126-
.setyDesc(build_cudnn_tensor_nchw('y', y))
127-
.setcDesc(conv_desc)
128-
.build();
100+
const std::vector<int64_t>& stride,
101+
const std::vector<int64_t>& padding_lo,
102+
const std::vector<int64_t>& padding_hi,
103+
const std::vector<int64_t>& dilation) {
104+
auto compute_dtype =
105+
(dtype == float16 || dtype == bfloat16) ? float32 : dtype;
106+
DnnGraph graph(encoder.device().cudnn_handle(), dtype, compute_dtype);
107+
auto x_ = graph.tensor_nchw("X", 'x', x);
108+
auto w_ = graph.tensor_nchw("W", 'w', w);
109+
110+
auto set_options = [&](auto& options) {
111+
options.set_compute_data_type(dtype_to_cudnn_type(compute_dtype))
112+
.set_convolution_mode(fe::ConvolutionMode_t::CROSS_CORRELATION)
113+
.set_stride(stride)
114+
.set_pre_padding(padding_lo)
115+
.set_post_padding(padding_hi)
116+
.set_dilation(dilation);
117+
};
118+
119+
std::shared_ptr<fe::graph::Tensor_attributes> y_;
120+
if (backend_type == CONV_FORWARD) {
121+
auto options = fe::graph::Conv_fprop_attributes();
122+
set_options(options);
123+
y_ = graph.conv_fprop(x_, w_, options);
124+
} else if (backend_type == CONV_BACKWARD_INPUT) {
125+
auto options = fe::graph::Conv_dgrad_attributes();
126+
set_options(options);
127+
y_ = graph.conv_dgrad(x_, w_, options);
128+
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
129+
auto options = fe::graph::Conv_wgrad_attributes();
130+
set_options(options);
131+
y_ = graph.conv_wgrad(w_, x_, options);
132+
}
133+
graph.tensor_nchw(y_, 'y', y)->set_output(true);
129134

130-
std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
131-
return cudnn_frontend::OperationGraphBuilder()
132-
.setHandle(encoder.device().cudnn_handle())
133-
.setOperationGraph(ops.size(), ops.data())
134-
.build();
135-
} catch (cudnn_frontend::cudnnException& error) {
136-
if (error.getCudnnStatus() != CUDNN_STATUS_BAD_PARAM) {
137-
throw;
138-
}
135+
if (graph.prepare().is_bad()) {
139136
return std::nullopt;
140137
}
138+
graph.deselect_numeric_notes({fe::NumericalNote_t::DOWN_CONVERT_INPUTS});
139+
if (dtype == float32 && !env::enable_tf32()) {
140+
graph.deselect_numeric_notes({fe::NumericalNote_t::TENSOR_CORE});
141+
}
142+
CHECK_CUDNN_FE_ERROR(graph.build());
143+
return graph;
141144
}
142145

143146
// Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups).
@@ -181,7 +184,7 @@ array group_transpose(
181184
// eval_gpu, with cost of possible redundant copies.
182185
std::tuple<array, array, array> prepare_args(
183186
cu::CommandEncoder& encoder,
184-
cudnnBackendDescriptorType_t backend_type,
187+
ConvBackendType backend_type,
185188
array in,
186189
array wt,
187190
array out,
@@ -221,27 +224,11 @@ std::tuple<array, array, array> prepare_args(
221224
return {std::move(in), std::move(wt), std::move(out)};
222225
}
223226

224-
// Get the x/w/y args from the in/wt/out args depending on backend type.
225-
inline std::tuple<array&, array&, array&> dispatch_args(
226-
cudnnBackendDescriptorType_t backend_type,
227-
array& in,
228-
array& wt,
229-
array& out) {
230-
switch (backend_type) {
231-
case CONV_BACKWARD_INPUT:
232-
return {out, wt, in};
233-
case CONV_BACKWARD_WEIGHT:
234-
return {in, out, wt};
235-
default:
236-
return {in, wt, out};
237-
}
238-
}
239-
240227
// Register inputs and outputs before actually running conv op. Can only be
241228
// called once per eval_gpu.
242229
void register_args(
243230
cu::CommandEncoder& encoder,
244-
cudnnBackendDescriptorType_t backend_type,
231+
ConvBackendType backend_type,
245232
array& in,
246233
array& wt,
247234
array& intermediate_out,
@@ -297,16 +284,19 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
297284
get_alignment(wt),
298285
get_alignment(out)};
299286
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
300-
auto& [backend_type, plan] = it->second;
301-
if (plan) {
302-
// Run cached plan.
287+
auto& [backend_type, graph] = it->second;
288+
if (graph) {
289+
// Run cached graph.
303290
std::tie(in, wt, out) =
304291
prepare_args(encoder, backend_type, in, wt, out, groups_, s);
305292
register_args(encoder, backend_type, in, wt, out, out_);
306-
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
307-
if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
308-
throw std::runtime_error("[conv] Cached plan failed to execute.");
309-
}
293+
CHECK_CUDNN_FE_ERROR(graph->encode_capturing(
294+
encoder,
295+
{
296+
{'x', gpu_ptr<void>(in)},
297+
{'w', gpu_ptr<void>(wt)},
298+
{'y', gpu_ptr<void>(out)},
299+
}));
310300
} else {
311301
// Run fallback kernel.
312302
gemm_conv(
@@ -327,7 +317,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
327317

328318
// There is no reliable way to deduce the proper cuDNN backend for the
329319
// convolution, so we make a best guess and then try.
330-
SmallVector<cudnnBackendDescriptorType_t, 2> try_backends;
320+
SmallVector<ConvBackendType, 2> try_backends;
331321
if (flip_) {
332322
// When weight is flipped, we assume it is backward input convolution.
333323
try_backends.push_back(CONV_BACKWARD_INPUT);
@@ -345,13 +335,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
345335
}
346336

347337
// Try to build op graph.
348-
cudnnBackendDescriptorType_t backend_type;
349-
std::optional<cudnn_frontend::OperationGraph> op_graph;
338+
ConvBackendType backend_type;
339+
std::optional<DnnGraph> graph;
350340
for (auto try_backend : try_backends) {
351-
auto [in_copy, wt_copy, out_copy] =
341+
auto [x, w, y] =
352342
prepare_args(encoder, try_backend, in, wt, out, groups_, s);
353-
auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
354-
auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
343+
auto [stride, padding_lo, padding_hi, dilation] = get_conv_settings(
355344
try_backend,
356345
x,
357346
w,
@@ -361,7 +350,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
361350
padding_hi_,
362351
kernel_dilation_,
363352
input_dilation_);
364-
op_graph = build_conv_op_graph(
353+
graph = build_conv_graph(
365354
encoder,
366355
try_backend,
367356
dtype,
@@ -372,30 +361,27 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
372361
padding_lo,
373362
padding_hi,
374363
dilation);
375-
if (op_graph) {
364+
if (graph) {
376365
backend_type = try_backend;
377-
in = std::move(in_copy);
378-
wt = std::move(wt_copy);
379-
out = std::move(out_copy);
366+
in = std::move(x);
367+
wt = std::move(w);
368+
out = std::move(y);
380369
break;
381370
}
382371
}
383372

384-
if (op_graph) {
385-
// Find a plan for the graph and execute it.
386-
auto plan = find_cudnn_plan_from_op_graph(
387-
encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
388-
if (plan) {
389-
// Setup inputs and outputs.
390-
register_args(encoder, backend_type, in, wt, out, out_);
391-
392-
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
393-
if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
394-
conv_cache().emplace(
395-
cache_key, std::make_pair(backend_type, std::move(*plan)));
396-
return;
397-
}
398-
}
373+
if (graph) {
374+
register_args(encoder, backend_type, in, wt, out, out_);
375+
CHECK_CUDNN_FE_ERROR(graph->encode_capturing(
376+
encoder,
377+
{
378+
{'x', gpu_ptr<void>(in)},
379+
{'w', gpu_ptr<void>(wt)},
380+
{'y', gpu_ptr<void>(out)},
381+
}));
382+
conv_cache().emplace(
383+
cache_key, std::make_pair(backend_type, std::move(*graph)));
384+
return;
399385
}
400386

401387
// Use fallback kernel for settings not supported by cuDNN.

mlx/backend/cuda/cuda_utils.h

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,17 +5,20 @@
55
#include <cublasLt.h>
66
#include <cuda.h>
77
#include <cuda_runtime.h>
8+
#include <cudnn.h>
89

910
namespace mlx::core {
1011

1112
// Throw exception if the cuda API does not succeed.
1213
void check_cublas_error(const char* name, cublasStatus_t err);
1314
void check_cuda_error(const char* name, cudaError_t err);
1415
void check_cuda_error(const char* name, CUresult err);
16+
void check_cudnn_error(const char* name, cudnnStatus_t err);
1517

1618
// The macro version that prints the command that failed.
1719
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
1820
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
21+
#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
1922

2023
// Base class for RAII managed CUDA resources.
2124
template <typename Handle, cudaError_t (*Destroy)(Handle)>

0 commit comments

Comments
 (0)