Skip to content

XLA Convolution algorithms fail on GPU #1166

@rknetemann

Description

@rknetemann

Dear Patrick,

It seems like I am having some issues with training convolutional neural networks due to a JaxRuntimeError that keeps arising when running it on the GPU. Changing the platform to CPU removes the error (but of course I want to be able to use GPU).

Errors:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/raymo/Projects/nlsid/pipeline/v1/3_train_neural_network/train_neural_network.py", line 495, in <module>
    model = train(
            ^^^^^^
  File "/home/raymo/Projects/nlsid/pipeline/v1/3_train_neural_network/train_neural_network.py", line 262, in train
    model, opt_state, train_loss = make_step(model, opt_state, simulations_batch)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raymo/Projects/nlsid/.venv/lib/python3.12/site-packages/equinox/_jit.py", line 209, in __call__
    return _call(self, False, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raymo/Projects/nlsid/.venv/lib/python3.12/site-packages/equinox/_jit.py", line 263, in _call
    marker, _, _ = out = jit_wrapper._cached(
                         ^^^^^^^^^^^^^^^^^^^^
jax.errors.JaxRuntimeError: INTERNAL: No reference output found!

Equinox is known to have some issues with jax>=0.7.0, so I also tried it with jax==0.6.2, which gave me a similar but more detailed traceback:

2026-01-13 15:49:12.945627: W external/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc:847] None of the algorithms provided by cuDNN heuristics worked; trying fallback algorithms.
2026-01-13 15:49:12.945647: W external/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc:850] Conv: %cudnn-conv-bw-input = (f32[512,2,502]{2,1,0}, u8[0]{0}) custom-call(%select.2, %Arg_2.3), window={size=10}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convBackwardInput", metadata={op_name="jit(make_step)/jit(main)/transpose(jvp(jit(loss)))/vmap(eqx.nn.Conv)/conv_general_dilated" source_file="/home/raymo/Projects/nlsid/.venv/lib/python3.12/site-packages/equinox/nn/_conv.py" source_line=239}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false,"reification_cost":[]}
2026-01-13 15:49:12.948158: W external/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc:1077] Failed to determine best cudnn convolution algorithm for:
%cudnn-conv-bw-input = (f32[512,2,502]{2,1,0}, u8[0]{0}) custom-call(%select.2, %Arg_2.3), window={size=10}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convBackwardInput", metadata={op_name="jit(make_step)/jit(main)/transpose(jvp(jit(loss)))/vmap(eqx.nn.Conv)/conv_general_dilated" source_file="/home/raymo/Projects/nlsid/.venv/lib/python3.12/site-packages/equinox/nn/_conv.py" source_line=239}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false,"reification_cost":[]}

Original error: INTERNAL: All algorithms tried for (f32[512,2,502]{2,1,0}, u8[0]{0}) custom-call(f32[512,1,493]{2,1,0}, f32[1,2,10]{2,1,0}), window={size=10}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convBackwardInput", backend_config={"cudnn_conv_backend_config":{"activation_mode":"kNone","conv_result_scale":1,"leakyrelu_alpha":0,"side_input_scale":0},"force_earliest_schedule":false,"operation_queue_id":"0","reification_cost":[],"wait_on_operation_queues":[]} failed. Falling back to default algorithm.  Per-algorithm errors:
  Profiling failure on cuDNN engine eng0{}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng25{k2=2,k3=0}: UNKNOWN: CUDNN_STATUS_INTERNAL_ERROR
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng1{k2=3,k3=0}: UNKNOWN: CUDNN_STATUS_INTERNAL_ERROR
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng25{k2=0,k3=0}: UNKNOWN: CUDNN_STATUS_INTERNAL_ERROR
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng1{k2=1,k3=0}: UNKNOWN: CUDNN_STATUS_INTERNAL_ERROR
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng1{}: UNKNOWN: CUDNN_STATUS_INTERNAL_ERROR
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng25{}: UNKNOWN: CUDNN_STATUS_INTERNAL_ERROR
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng0{}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'

As a result, convolution performance may be suboptimal.
2026-01-13 15:49:12.950834: W external/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc:847] None of the algorithms provided by cuDNN heuristics worked; trying fallback algorithms.
2026-01-13 15:49:12.950849: W external/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc:850] Conv: %cudnn-conv-bw-filter = (f32[2,2,100]{2,1,0}, u8[0]{0}) custom-call(%concatenate.2, %select.3), window={size=100}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convBackwardFilter", metadata={op_name="jit(make_step)/jit(main)/transpose(jvp(jit(loss)))/vmap(eqx.nn.Conv)/conv_general_dilated" source_file="/home/raymo/Projects/nlsid/.venv/lib/python3.12/site-packages/equinox/nn/_conv.py" source_line=239}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false,"reification_cost":[]}
2026-01-13 15:49:12.951836: W external/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc:1077] Failed to determine best cudnn convolution algorithm for:
%cudnn-conv-bw-filter = (f32[2,2,100]{2,1,0}, u8[0]{0}) custom-call(%concatenate.2, %select.3), window={size=100}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convBackwardFilter", metadata={op_name="jit(make_step)/jit(main)/transpose(jvp(jit(loss)))/vmap(eqx.nn.Conv)/conv_general_dilated" source_file="/home/raymo/Projects/nlsid/.venv/lib/python3.12/site-packages/equinox/nn/_conv.py" source_line=239}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false,"reification_cost":[]}

Original error: INTERNAL: All algorithms tried for (f32[2,2,100]{2,1,0}, u8[0]{0}) custom-call(f32[512,2,601]{2,1,0}, f32[512,2,502]{2,1,0}), window={size=100}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convBackwardFilter", backend_config={"cudnn_conv_backend_config":{"activation_mode":"kNone","conv_result_scale":1,"leakyrelu_alpha":0,"side_input_scale":0},"force_earliest_schedule":false,"operation_queue_id":"0","reification_cost":[],"wait_on_operation_queues":[]} failed. Falling back to default algorithm.  Per-algorithm errors:
  Profiling failure on cuDNN engine eng20{k2=6,k3=0}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng20{k2=7,k3=0}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng1{k2=6,k3=0}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng20{k2=2,k3=0}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng1{k2=1,k3=0}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng0{}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng1{}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng20{}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng0{}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng20{k2=5,k3=0}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'

As a result, convolution performance may be suboptimal.
2026-01-13 15:49:12.953197: W external/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc:847] None of the algorithms provided by cuDNN heuristics worked; trying fallback algorithms.
2026-01-13 15:49:12.953212: W external/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc:850] Conv: %cudnn-conv-bw-filter.1 = (f32[1,2,10]{2,1,0}, u8[0]{0}) custom-call(%maximum.3, %select.2), window={size=10}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convBackwardFilter", metadata={op_name="jit(make_step)/jit(main)/transpose(jvp(jit(loss)))/vmap(eqx.nn.Conv)/conv_general_dilated" source_file="/home/raymo/Projects/nlsid/.venv/lib/python3.12/site-packages/equinox/nn/_conv.py" source_line=239}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false,"reification_cost":[]}
2026-01-13 15:49:12.954353: W external/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc:1077] Failed to determine best cudnn convolution algorithm for:
%cudnn-conv-bw-filter.1 = (f32[1,2,10]{2,1,0}, u8[0]{0}) custom-call(%maximum.3, %select.2), window={size=10}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convBackwardFilter", metadata={op_name="jit(make_step)/jit(main)/transpose(jvp(jit(loss)))/vmap(eqx.nn.Conv)/conv_general_dilated" source_file="/home/raymo/Projects/nlsid/.venv/lib/python3.12/site-packages/equinox/nn/_conv.py" source_line=239}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false,"reification_cost":[]}

Original error: INTERNAL: All algorithms tried for (f32[1,2,10]{2,1,0}, u8[0]{0}) custom-call(f32[512,2,502]{2,1,0}, f32[512,1,493]{2,1,0}), window={size=10}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convBackwardFilter", backend_config={"cudnn_conv_backend_config":{"activation_mode":"kNone","conv_result_scale":1,"leakyrelu_alpha":0,"side_input_scale":0},"force_earliest_schedule":false,"operation_queue_id":"0","reification_cost":[],"wait_on_operation_queues":[]} failed. Falling back to default algorithm.  Per-algorithm errors:
  Profiling failure on cuDNN engine eng20{k2=6,k3=0}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng0{}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng20{k2=7,k3=0}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng1{k2=6,k3=0}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng20{k2=2,k3=0}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng1{}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng20{}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng0{}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng20{k2=5,k3=0}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng1{k2=1,k3=0}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'

As a result, convolution performance may be suboptimal.
2026-01-13 15:49:12.961294: W external/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc:847] None of the algorithms provided by cuDNN heuristics worked; trying fallback algorithms.
2026-01-13 15:49:12.961316: W external/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc:850] Conv: %cudnn-conv-bias-activation.2 = (f32[512,2,502]{2,1,0}, u8[0]{0}) custom-call(%concatenate.2, %Arg_0.1, %bitcast.2), window={size=100}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convBiasActivationForward", metadata={op_name="jit(make_step)/jit(main)/jvp(jit(loss))/vmap(eqx.nn.Conv)/conv_general_dilated" source_file="/home/raymo/Projects/nlsid/.venv/lib/python3.12/site-packages/equinox/nn/_conv.py" source_line=239}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false,"reification_cost":[]}
2026-01-13 15:49:12.962093: W external/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc:1077] Failed to determine best cudnn convolution algorithm for:
%cudnn-conv-bias-activation.2 = (f32[512,2,502]{2,1,0}, u8[0]{0}) custom-call(%concatenate.2, %Arg_0.1, %bitcast.2), window={size=100}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convBiasActivationForward", metadata={op_name="jit(make_step)/jit(main)/jvp(jit(loss))/vmap(eqx.nn.Conv)/conv_general_dilated" source_file="/home/raymo/Projects/nlsid/.venv/lib/python3.12/site-packages/equinox/nn/_conv.py" source_line=239}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false,"reification_cost":[]}

Original error: INTERNAL: All algorithms tried for (f32[512,2,502]{2,1,0}, u8[0]{0}) custom-call(f32[512,2,601]{2,1,0}, f32[2,2,100]{2,1,0}, f32[2]{0}), window={size=100}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convBiasActivationForward", backend_config={"cudnn_conv_backend_config":{"activation_mode":"kNone","conv_result_scale":1,"leakyrelu_alpha":0,"side_input_scale":0},"force_earliest_schedule":false,"operation_queue_id":"0","reification_cost":[],"wait_on_operation_queues":[]} failed. Falling back to default algorithm.  Per-algorithm errors:
  Profiling failure on cuDNN engine eng13{}: UNKNOWN: CUDNN_STATUS_INTERNAL_ERROR
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng35{k2=4,k3=0}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng11{k2=3,k3=0}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng35{k2=1,k3=0}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng35{k2=0,k3=0}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng11{k2=1,k3=0}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng11{k2=0,k3=0}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng36{k2=4,k3=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng36{k2=1,k3=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng0{}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng11{}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng0{}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'

As a result, convolution performance may be suboptimal.
2026-01-13 15:49:12.965211: W external/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc:847] None of the algorithms provided by cuDNN heuristics worked; trying fallback algorithms.
2026-01-13 15:49:12.965234: W external/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc:850] Conv: %cudnn-conv-bias-activation.5 = (f32[512,1,493]{2,1,0}, u8[0]{0}) custom-call(%maximum.3, %Arg_2.3, %bitcast.3), window={size=10}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convBiasActivationForward", metadata={op_name="jit(make_step)/jit(main)/jvp(jit(loss))/vmap(eqx.nn.Conv)/conv_general_dilated" source_file="/home/raymo/Projects/nlsid/.venv/lib/python3.12/site-packages/equinox/nn/_conv.py" source_line=239}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false,"reification_cost":[]}
2026-01-13 15:49:12.965988: W external/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc:1077] Failed to determine best cudnn convolution algorithm for:
%cudnn-conv-bias-activation.5 = (f32[512,1,493]{2,1,0}, u8[0]{0}) custom-call(%maximum.3, %Arg_2.3, %bitcast.3), window={size=10}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convBiasActivationForward", metadata={op_name="jit(make_step)/jit(main)/jvp(jit(loss))/vmap(eqx.nn.Conv)/conv_general_dilated" source_file="/home/raymo/Projects/nlsid/.venv/lib/python3.12/site-packages/equinox/nn/_conv.py" source_line=239}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false,"reification_cost":[]}

Original error: INTERNAL: All algorithms tried for (f32[512,1,493]{2,1,0}, u8[0]{0}) custom-call(f32[512,2,502]{2,1,0}, f32[1,2,10]{2,1,0}, f32[1]{0}), window={size=10}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convBiasActivationForward", backend_config={"cudnn_conv_backend_config":{"activation_mode":"kNone","conv_result_scale":1,"leakyrelu_alpha":0,"side_input_scale":0},"force_earliest_schedule":false,"operation_queue_id":"0","reification_cost":[],"wait_on_operation_queues":[]} failed. Falling back to default algorithm.  Per-algorithm errors:
  Profiling failure on cuDNN engine eng0{}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng11{k2=3,k3=0}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng36{k2=4,k3=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng35{k2=4,k3=0}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng11{k2=0,k3=0}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng35{k2=0,k3=0}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng11{k2=1,k3=0}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng35{k2=1,k3=0}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng11{}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
  Profiling failure on cuDNN engine eng0{}: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'

As a result, convolution performance may be suboptimal.
E0113 15:49:13.937752  153454 pjrt_stream_executor_client.cc:2916] Execution of replica 0 failed: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status'
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/raymo/Projects/nlsid/pipeline/v1/3_train_neural_network/train_neural_network.py", line 511, in <module>
    model = train(
            ^^^^^^
  File "/home/raymo/Projects/nlsid/pipeline/v1/3_train_neural_network/train_neural_network.py", line 278, in train
    model, opt_state, train_loss = make_step(model, opt_state, simulations_batch)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raymo/Projects/nlsid/.venv/lib/python3.12/site-packages/equinox/_jit.py", line 209, in __call__
    return _call(self, False, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raymo/Projects/nlsid/.venv/lib/python3.12/site-packages/equinox/_jit.py", line 263, in _call
    marker, _, _ = out = jit_wrapper._cached(
                         ^^^^^^^^^^^^^^^^^^^^
jaxlib._jax.XlaRuntimeError: UNKNOWN: <unknown cudnn status: 5003>
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(5888): 'status

Context:
The exception is being raised inside of this function:

@eqx.filter_jit
def make_step(
    model: MultiLayerPerceptron,
    opt_state: PyTree,
    simulation: PyTree[Float[Array, "..."]],
):
    loss_value, grads = eqx.filter_value_and_grad(loss)(model, normalizer, simulation)
    updates, opt_state = optim.update(grads, opt_state, eqx.filter(model, eqx.is_array))
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss_value

The function worked fine when my model was a simple MLP, but adding the Conv1d layers resulted in these errors.

Python: 3.12.11

To reproduce:
I created a simple script to reproduce the error, using the following packages:

  • jax==0.8.0
  • equinox==0.13.2
  • optax==0.2.6

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, PyTree
import optax


class MultiLayerPerceptron(eqx.Module):
    """Small conv + dense network that reproduces the cuDNN autotuning failure.

    Architecture: two 1D convolutions (each followed by ReLU), a flatten,
    then two fully-connected layers.
    """

    # Callables applied in sequence by __call__.
    layers: list

    def __init__(self, x_shape, y_shape, key):
        """Build the layer stack.

        Args:
            x_shape: shape of one input sample (only printed for debugging here).
            y_shape: shape of one target; y_shape[0] sizes the output layer.
            key: PRNG key, split across the four parameterized layers.
        """
        keys = jax.random.split(key, 4)

        print(x_shape)  # debug output retained from the original script
        self.layers = [
            eqx.nn.Conv1d(in_channels=2, out_channels=2, kernel_size=100, key=keys[0]),
            jax.nn.relu,
            eqx.nn.Conv1d(in_channels=2, out_channels=1, kernel_size=10, key=keys[1]),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(493, 493, key=keys[2]),
            jax.nn.relu,
            eqx.nn.Linear(493, y_shape[0], key=keys[3]),
        ]

    def __call__(self, x):
        """Apply every layer in order and return the final activation."""
        result = x
        for layer in self.layers:
            result = layer(result)
        return result
    
@eqx.filter_jit
def loss(
    model: MultiLayerPerceptron,
    x: Float[Array, "..."],
    y: Float[Array, "..."],
) -> Float[Array, ""]:
    """Return the mean-squared error between the model's prediction and y."""
    prediction = model(x)
    squared_error = (prediction - y) ** 2
    return jnp.mean(squared_error)

# Synthetic input/target pair matching the shapes from the failing pipeline.
x_fake = jnp.zeros((2, 601))
y_fake = jnp.zeros((8))

# AdamW optimizer over the model's array-valued leaves only.
optim = optax.adamw(learning_rate=1e-3)
model = MultiLayerPerceptron(x_fake.shape, y_fake.shape, jax.random.PRNGKey(0))
opt_state = optim.init(eqx.filter(model, eqx.is_array))

@eqx.filter_jit
def make_step(
    model: MultiLayerPerceptron,
    opt_state: optax.OptState,
):
    """Run one optimization step on the fixed (x_fake, y_fake) batch.

    Returns the updated model, the new optimizer state, and the loss value.
    """
    loss_value, grads = eqx.filter_value_and_grad(loss)(model, x_fake, y_fake)
    trainable = eqx.filter(model, eqx.is_array)
    updates, opt_state = optim.update(grads, opt_state, trainable)
    stepped_model = eqx.apply_updates(model, updates)
    return stepped_model, opt_state, loss_value

# Single training step; this is the call where the cuDNN error is raised on GPU.
model, opt_state, train_loss = make_step(model, opt_state)

I thought you might want to know about this error.

Best,
Raymond

Metadata

Metadata

Assignees

No one assigned

    Labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions