@@ -15,19 +15,16 @@ namespace mlx::core {
 
 namespace {
 
-// Alias for better readability.
-#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
-#define CONV_BACKWARD_INPUT \
-  CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
-#define CONV_BACKWARD_WEIGHT \
-  CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
-
-// Custom placeholder representing fallback kernel.
-#define CONV_FALLBACK static_cast<cudnnBackendDescriptorType_t>(-1)
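+// Backend kinds for a conv op; CONV_FALLBACK marks the non-cuDNN fallback
+// kernel.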
+enum ConvBackendType {
+  CONV_FALLBACK,
+  CONV_FORWARD,
+  CONV_BACKWARD_INPUT,
+  CONV_BACKWARD_WEIGHT,
+};
 
 struct ConvCacheKey {
   int device_id;
-  cudnnDataType_t cudnn_dtype;
+  fe::DataType_t cudnn_dtype;
   std::array<int, MAX_NDIM> input_shape;
   std::array<int, MAX_NDIM> weight_shape;
   std::array<int, MAX_NDIM> stride;
@@ -44,15 +41,13 @@ struct ConvCacheKey {
 auto& conv_cache() {
   static LRUBytesKeyCache<
       ConvCacheKey,
-      std::pair<
-          cudnnBackendDescriptorType_t,
-          std::optional<cudnn_frontend::ExecutionPlan>>>
+      std::pair<ConvBackendType, std::optional<DnnGraph>>>
       cache("MLX_CUDA_CONV_CACHE_SIZE", /* default_capacity */ 128);
   return cache;
 }
 
-auto get_conv_op_settings(
-    cudnnBackendDescriptorType_t backend_type,
+auto get_conv_settings(
+    ConvBackendType backend_type,
     array& x,
     array& w,
     array& y,
@@ -68,8 +63,8 @@ auto get_conv_op_settings(
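+    // Express the flipped-weight conv as a backward-input conv: flip the
+    // pre-padding and adjust the post-padding so spatial sizes line up.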
     for (int i = 0; i < padding_lo.size(); ++i) {
       int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1);
       padding_lo[i] = wt_size - padding_lo[i] - 1;
-      int in_size = 1 + kernel_strides[i] * (x.shape(1 + i) - 1);
-      int out_size = 1 + input_dilation[i] * (y.shape(1 + i) - 1);
+      int in_size = 1 + kernel_strides[i] * (y.shape(1 + i) - 1);
+      int out_size = 1 + input_dilation[i] * (x.shape(1 + i) - 1);
       padding_hi[i] = out_size - in_size + padding_hi[i];
     }
     return std::make_tuple(
@@ -95,49 +90,57 @@ auto get_conv_op_settings(
   }
 }
 
-std::optional<cudnn_frontend::OperationGraph> build_conv_op_graph(
+std::optional<DnnGraph> build_conv_graph(
     cu::CommandEncoder& encoder,
-    cudnnBackendDescriptorType_t backend_type,
+    ConvBackendType backend_type,
     Dtype dtype,
     array& x,
     array& w,
     array& y,
-    const SmallVector<int64_t>& stride,
-    const SmallVector<int64_t>& padding_lo,
-    const SmallVector<int64_t>& padding_hi,
-    const SmallVector<int64_t>& dilation) {
-  try {
-    auto compute_dtype = (dtype == float16 || dtype == bfloat16)
-        ? CUDNN_DATA_FLOAT
-        : dtype_to_cudnn_type(dtype);
-    auto conv_desc = cudnn_frontend::ConvDescBuilder()
-                         .setDataType(compute_dtype)
-                         .setMathMode(CUDNN_CROSS_CORRELATION)
-                         .setNDims(stride.size())
-                         .setStrides(stride.size(), stride.data())
-                         .setPrePadding(padding_lo.size(), padding_lo.data())
-                         .setPostPadding(padding_hi.size(), padding_hi.data())
-                         .setDilation(dilation.size(), dilation.data())
-                         .build();
-
-    auto op = cudnn_frontend::OperationBuilder(backend_type)
-                  .setxDesc(build_cudnn_tensor_nchw('x', x))
-                  .setwDesc(build_cudnn_tensor_nchw('w', w))
-                  .setyDesc(build_cudnn_tensor_nchw('y', y))
-                  .setcDesc(conv_desc)
-                  .build();
+    const std::vector<int64_t>& stride,
+    const std::vector<int64_t>& padding_lo,
+    const std::vector<int64_t>& padding_hi,
+    const std::vector<int64_t>& dilation) {
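+  // Half-precision convs accumulate in float32; other dtypes compute in
+  // their own precision.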
+  auto compute_dtype =
+      (dtype == float16 || dtype == bfloat16) ? float32 : dtype;
+  DnnGraph graph(encoder.device().cudnn_handle(), dtype, compute_dtype);
+  auto x_ = graph.tensor_nchw("X", 'x', x);
+  auto w_ = graph.tensor_nchw("W", 'w', w);
+
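+  // Convolution attributes shared by the fprop/dgrad/wgrad variants.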
+  auto set_options = [&](auto& options) {
+    options.set_compute_data_type(dtype_to_cudnn_type(compute_dtype))
+        .set_convolution_mode(fe::ConvolutionMode_t::CROSS_CORRELATION)
+        .set_stride(stride)
+        .set_pre_padding(padding_lo)
+        .set_post_padding(padding_hi)
+        .set_dilation(dilation);
+  };
+
+  std::shared_ptr<fe::graph::Tensor_attributes> y_;
+  if (backend_type == CONV_FORWARD) {
+    auto options = fe::graph::Conv_fprop_attributes();
+    set_options(options);
+    y_ = graph.conv_fprop(x_, w_, options);
+  } else if (backend_type == CONV_BACKWARD_INPUT) {
+    auto options = fe::graph::Conv_dgrad_attributes();
+    set_options(options);
+    y_ = graph.conv_dgrad(x_, w_, options);
+  } else if (backend_type == CONV_BACKWARD_WEIGHT) {
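+    // cudnn_frontend's conv_wgrad takes (dy, x): the 'w'-tagged tensor
+    // carries dy here and 'x' the forward input.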
+    auto options = fe::graph::Conv_wgrad_attributes();
+    set_options(options);
+    y_ = graph.conv_wgrad(w_, x_, options);
+  }
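+  // Describe the result as an NCHW tensor with uid 'y' and mark it as a
+  // graph output.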
+  graph.tensor_nchw(y_, 'y', y)->set_output(true);
 
-    std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
-    return cudnn_frontend::OperationGraphBuilder()
-        .setHandle(encoder.device().cudnn_handle())
-        .setOperationGraph(ops.size(), ops.data())
-        .build();
-  } catch (cudnn_frontend::cudnnException& error) {
-    if (error.getCudnnStatus() != CUDNN_STATUS_BAD_PARAM) {
-      throw;
-    }
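+  // If cuDNN rejects the graph, report failure so the caller can try the
+  // next backend type or fall back to the gemm-based kernel.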
+  if (graph.prepare().is_bad()) {
     return std::nullopt;
   }
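+  // Avoid engines that down-convert inputs; for float32 without TF32
+  // enabled, also avoid tensor core engines.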
+  graph.deselect_numeric_notes({fe::NumericalNote_t::DOWN_CONVERT_INPUTS});
+  if (dtype == float32 && !env::enable_tf32()) {
+    graph.deselect_numeric_notes({fe::NumericalNote_t::TENSOR_CORE});
+  }
+  CHECK_CUDNN_FE_ERROR(graph.build());
+  return graph;
 }
 
 // Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups).
@@ -181,7 +184,7 @@ array group_transpose(
 // eval_gpu, at the cost of possible redundant copies.
 std::tuple<array, array, array> prepare_args(
     cu::CommandEncoder& encoder,
-    cudnnBackendDescriptorType_t backend_type,
+    ConvBackendType backend_type,
     array in,
     array wt,
     array out,
@@ -221,27 +224,11 @@ std::tuple<array, array, array> prepare_args(
   return {std::move(in), std::move(wt), std::move(out)};
 }
 
-// Get the x/w/y args from the in/wt/out args depending on backend type.
-inline std::tuple<array&, array&, array&> dispatch_args(
-    cudnnBackendDescriptorType_t backend_type,
-    array& in,
-    array& wt,
-    array& out) {
-  switch (backend_type) {
-    case CONV_BACKWARD_INPUT:
-      return {out, wt, in};
-    case CONV_BACKWARD_WEIGHT:
-      return {in, out, wt};
-    default:
-      return {in, wt, out};
-  }
-}
-
 // Register inputs and outputs before actually running the conv op. Can only
 // be called once per eval_gpu.
 void register_args(
     cu::CommandEncoder& encoder,
-    cudnnBackendDescriptorType_t backend_type,
+    ConvBackendType backend_type,
     array& in,
     array& wt,
     array& intermediate_out,
@@ -297,16 +284,19 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
       get_alignment(wt),
       get_alignment(out)};
   if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
-    auto& [backend_type, plan] = it->second;
-    if (plan) {
-      // Run cached plan.
+    auto& [backend_type, graph] = it->second;
+    if (graph) {
+      // Run cached graph.
       std::tie(in, wt, out) =
           prepare_args(encoder, backend_type, in, wt, out, groups_, s);
       register_args(encoder, backend_type, in, wt, out, out_);
-      auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
-      if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
-        throw std::runtime_error("[conv] Cached plan failed to execute.");
-      }
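+      // Encode the cached graph, binding device pointers to the tensor
+      // uids 'x'/'w'/'y'.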
+      CHECK_CUDNN_FE_ERROR(graph->encode_capturing(
+          encoder,
+          {
+              {'x', gpu_ptr<void>(in)},
+              {'w', gpu_ptr<void>(wt)},
+              {'y', gpu_ptr<void>(out)},
+          }));
     } else {
       // Run fallback kernel.
       gemm_conv(
@@ -327,7 +317,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
 
   // There is no reliable way to deduce the proper cuDNN backend for the
   // convolution, so we make a best guess and then try.
-  SmallVector<cudnnBackendDescriptorType_t, 2> try_backends;
+  SmallVector<ConvBackendType, 2> try_backends;
   if (flip_) {
     // When weight is flipped, we assume it is backward input convolution.
     try_backends.push_back(CONV_BACKWARD_INPUT);
@@ -345,13 +335,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
   }
 
   // Try to build op graph.
-  cudnnBackendDescriptorType_t backend_type;
-  std::optional<cudnn_frontend::OperationGraph> op_graph;
+  ConvBackendType backend_type;
+  std::optional<DnnGraph> graph;
   for (auto try_backend : try_backends) {
-    auto [in_copy, wt_copy, out_copy] =
+    auto [x, w, y] =
         prepare_args(encoder, try_backend, in, wt, out, groups_, s);
-    auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
-    auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
+    auto [stride, padding_lo, padding_hi, dilation] = get_conv_settings(
         try_backend,
         x,
         w,
@@ -361,7 +350,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
         padding_hi_,
         kernel_dilation_,
         input_dilation_);
-    op_graph = build_conv_op_graph(
+    graph = build_conv_graph(
         encoder,
         try_backend,
         dtype,
@@ -372,30 +361,27 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
         padding_lo,
         padding_hi,
         dilation);
-    if (op_graph) {
+    if (graph) {
       backend_type = try_backend;
-      in = std::move(in_copy);
-      wt = std::move(wt_copy);
-      out = std::move(out_copy);
+      in = std::move(x);
+      wt = std::move(w);
+      out = std::move(y);
       break;
     }
   }
 
-  if (op_graph) {
-    // Find a plan for the graph and execute it.
-    auto plan = find_cudnn_plan_from_op_graph(
-        encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
-    if (plan) {
-      // Setup inputs and outputs.
-      register_args(encoder, backend_type, in, wt, out, out_);
-
-      auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
-      if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
-        conv_cache().emplace(
-            cache_key, std::make_pair(backend_type, std::move(*plan)));
-        return;
-      }
-    }
+  if (graph) {
+    register_args(encoder, backend_type, in, wt, out, out_);
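+    // Encode the graph into the stream, binding device pointers by uid,
+    // then cache it for subsequent calls.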
+    CHECK_CUDNN_FE_ERROR(graph->encode_capturing(
+        encoder,
+        {
+            {'x', gpu_ptr<void>(in)},
+            {'w', gpu_ptr<void>(wt)},
+            {'y', gpu_ptr<void>(out)},
+        }));
+    conv_cache().emplace(
+        cache_key, std::make_pair(backend_type, std::move(*graph)));
+    return;
   }
 
   // Use fallback kernel for settings not supported by cuDNN.