
fp8 support #2989

Merged
greenrazer merged 13 commits into huggingface:main from zackangelo:fp8
Aug 4, 2025

Conversation

@zackangelo
Contributor

@zackangelo commented Jun 11, 2025

Plucked from @EricLBuehler's work in #2745.

This implements fp8 operations where they are straightforward. Many fp8 ops can't be implemented because they require a scale tensor alongside the main tensor to compensate for fp8's limited dynamic range (e.g. matmul).
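To illustrate the scale-tensor point, here is a hedged sketch (not code from this PR; the helper name is hypothetical): E4M3 saturates at ±448, so a tensor whose values exceed that must be divided by a per-tensor scale before quantizing, and the scale has to travel alongside the data so ops like matmul can undo it.

```rust
// Hedged sketch (not from this PR): FP8 E4M3's largest finite magnitude
// is 448.0, so out-of-range values are divided by a per-tensor scale
// before quantizing; the scale is carried with the tensor.
const E4M3_MAX: f32 = 448.0;

fn compute_scale(values: &[f32]) -> f32 {
    let max_abs = values.iter().fold(0.0_f32, |m, v| m.max(v.abs()));
    if max_abs == 0.0 { 1.0 } else { max_abs / E4M3_MAX }
}

fn main() {
    let activations = [0.5_f32, -1200.0, 37.0];
    let scale = compute_scale(&activations);
    // After dividing by the scale, the largest magnitude sits right at
    // E4M3's saturation point instead of overflowing it.
    let max_scaled = activations
        .iter()
        .map(|v| (v / scale).abs())
        .fold(0.0_f32, f32::max);
    assert!((max_scaled - E4M3_MAX).abs() < 1e-2);
    println!("scale = {scale:.4}");
}
```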

@zackangelo
Contributor Author

Current test failures:

failures:

---- gather_gpu stdout ----
Error: DriverError(CUDA_ERROR_INVALID_PTX, "a PTX JIT compilation failed")

---- embeddings_gpu stdout ----
Error: DriverError(CUDA_ERROR_INVALID_PTX, "a PTX JIT compilation failed")

---- asort_gpu stdout ----
Error: DriverError(CUDA_ERROR_INVALID_PTX, "a PTX JIT compilation failed")

---- scatter_gpu stdout ----
Error: DriverError(CUDA_ERROR_INVALID_PTX, "a PTX JIT compilation failed")

---- index_select_gpu stdout ----
Error: DriverError(CUDA_ERROR_INVALID_PTX, "a PTX JIT compilation failed")

---- index_add_gpu stdout ----
Error: DriverError(CUDA_ERROR_INVALID_PTX, "a PTX JIT compilation failed")


failures:
    asort_gpu
    embeddings_gpu
    gather_gpu
    index_add_gpu
    index_select_gpu
    scatter_gpu

// The 16-bit __half floating-point version of atomicAdd() is only supported by devices of compute capability 7.x and higher.
// Solution adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119
__device__ __half atomicAdd(__half *address, __half val) {
//__device__ __half atomicAdd(__half *address, __half val) {
Contributor Author

unsure why just this signature was present
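The cutorch workaround linked in the diff emulates an atomic add for a type without native hardware support by packing it into a wider word and retrying with compare-and-swap. A hedged host-side sketch of the same technique in Rust (f32 in an `AtomicU32` stands in for `__half` inside a 32-bit word on the GPU):

```rust
use std::sync::atomic::{AtomicU32, Ordering};

// Hedged sketch of the CAS-loop technique: load the current bits,
// compute the sum in the wider type, and retry until the
// compare-and-swap observes no concurrent modification.
fn atomic_add_f32(cell: &AtomicU32, val: f32) {
    let mut old = cell.load(Ordering::Relaxed);
    loop {
        let new = (f32::from_bits(old) + val).to_bits();
        match cell.compare_exchange_weak(old, new, Ordering::Relaxed, Ordering::Relaxed) {
            Ok(_) => break,
            // Another thread (or a spurious failure) changed the cell;
            // retry with the freshly observed value.
            Err(current) => old = current,
        }
    }
}

fn main() {
    let cell = AtomicU32::new(1.5_f32.to_bits());
    atomic_add_f32(&cell, 2.25);
    assert_eq!(f32::from_bits(cell.load(Ordering::Relaxed)), 3.75);
    println!("ok");
}
```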


template <>
__host__ __device__
constexpr int32_t max_value<int32_t>() {
Contributor Author

these symbols were missing when cuda bindgen ran for some reason

WHERE_OP(__nv_bfloat16, uint32_t, where_u32_bf16)
WHERE_OP(__nv_bfloat16, uint8_t, where_u8_bf16)

WHERE_OP(__nv_fp8_e4m3, int16_t, where_i16_fp8_e4m3)
Contributor Author

__CUDA_ARCH__ guard should be 890 here
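For context on the 890 value, a hedged host-side sketch (hypothetical helper, not candle's actual build code): `__CUDA_ARCH__` encodes compute capability as major × 100 + minor × 10, so sm_89 (Ada), the first target where nvcc accepts the `__nv_fp8_e4m3` kernels, becomes 890; older targets must skip those instantiations to keep the PTX loadable.

```rust
// Hedged sketch (hypothetical function): mirror the
// `__CUDA_ARCH__ >= 890` guard used to gate fp8 kernel instantiations.
fn fp8_kernels_supported(compute_cap: u32) -> bool {
    // sm_89 (Ada) is the first architecture with native E4M3 support.
    compute_cap >= 890
}

fn main() {
    assert!(fp8_kernels_supported(890)); // Ada, e.g. RTX 4090
    assert!(fp8_kernels_supported(900)); // Hopper, e.g. H100
    assert!(!fp8_kernels_supported(610)); // Pascal, e.g. GTX 1060
    println!("ok");
}
```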

@zackangelo
Contributor Author

zackangelo commented Jun 12, 2025

fp8 seems somewhat slower than I would expect in the candle benchmark harness (this is on a GH100):

cuda_affine_f32/iter    time:   [5.5286 µs 5.5322 µs 5.5349 µs]
                        thrpt:  [705.74 GiB/s 706.10 GiB/s 706.55 GiB/s]
                 change:
                        time:   [-0.0349% +0.0724% +0.1626%] (p = 0.17 > 0.05)
                        thrpt:  [-0.1623% -0.0723% +0.0349%]
                        No change in performance detected.
Found 11 outliers among 100 measurements (11.00%)
  6 (6.00%) low severe
  4 (4.00%) low mild
  1 (1.00%) high severe

cuda_affine_f16/iter    time:   [5.4081 µs 5.4135 µs 5.4180 µs]
                        thrpt:  [360.49 GiB/s 360.79 GiB/s 361.15 GiB/s]
                 change:
                        time:   [-0.8211% -0.7134% -0.6133%] (p = 0.00 < 0.05)
                        thrpt:  [+0.6170% +0.7185% +0.8279%]
                        Change within noise threshold.
Found 11 outliers among 100 measurements (11.00%)
  4 (4.00%) low severe
  6 (6.00%) low mild
  1 (1.00%) high severe

cuda_affine_bf16/iter   time:   [5.4118 µs 5.4154 µs 5.4185 µs]
                        thrpt:  [360.45 GiB/s 360.66 GiB/s 360.90 GiB/s]
                 change:
                        time:   [-0.7404% -0.6370% -0.5323%] (p = 0.00 < 0.05)
                        thrpt:  [+0.5351% +0.6411% +0.7459%]
                        Change within noise threshold.
Found 8 outliers among 100 measurements (8.00%)
  3 (3.00%) low severe
  5 (5.00%) low mild

cuda_affine_fp8/iter    time:   [5.6150 µs 5.6186 µs 5.6216 µs]
                        thrpt:  [173.72 GiB/s 173.81 GiB/s 173.92 GiB/s]
                 change:
                        time:   [+3.6965% +3.7735% +3.8562%] (p = 0.00 < 0.05)
                        thrpt:  [-3.7130% -3.6363% -3.5648%]
                        Performance has regressed.
Found 7 outliers among 100 measurements (7.00%)
  4 (4.00%) low severe
  1 (1.00%) low mild
  1 (1.00%) high mild
  1 (1.00%) high severe

cpu_affine_f32/iter     time:   [60.175 µs 60.214 µs 60.256 µs]
                        thrpt:  [64.827 GiB/s 64.873 GiB/s 64.915 GiB/s]
                 change:
                        time:   [-0.3960% -0.2927% -0.1974%] (p = 0.00 < 0.05)
                        thrpt:  [+0.1978% +0.2936% +0.3976%]
                        Change within noise threshold.
Found 3 outliers among 100 measurements (3.00%)
  3 (3.00%) high mild

cpu_affine_f16/iter     time:   [313.67 µs 314.25 µs 314.84 µs]
                        thrpt:  [6.2035 GiB/s 6.2151 GiB/s 6.2267 GiB/s]
                 change:
                        time:   [+0.2332% +0.3714% +0.5040%] (p = 0.00 < 0.05)
                        thrpt:  [-0.5015% -0.3700% -0.2326%]
                        Change within noise threshold.
Found 15 outliers among 100 measurements (15.00%)
  7 (7.00%) high mild
  8 (8.00%) high severe

cpu_affine_bf16/iter    time:   [3.5991 ms 3.5996 ms 3.6001 ms]
                        thrpt:  [555.54 MiB/s 555.61 MiB/s 555.70 MiB/s]
                 change:
                        time:   [-0.0397% -0.0205% +0.0000%] (p = 0.05 > 0.05)
                        thrpt:  [-0.0000% +0.0205% +0.0397%]
                        No change in performance detected.
Found 7 outliers among 100 measurements (7.00%)
  3 (3.00%) low severe
  1 (1.00%) low mild
  2 (2.00%) high mild
  1 (1.00%) high severe

cpu_affine_fp8/iter     time:   [10.386 ms 10.387 ms 10.389 ms]
                        thrpt:  [96.259 MiB/s 96.271 MiB/s 96.283 MiB/s]
                 change:
                        time:   [-0.4090% -0.3742% -0.3366%] (p = 0.00 < 0.05)
                        thrpt:  [+0.3377% +0.3756% +0.4107%]
                        Change within noise threshold.
Found 2 outliers among 100 measurements (2.00%)
  2 (2.00%) high mild

Probably because we're double-casting from fp8->half->f32?

#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3))
AFFINE_OP(__nv_fp8_e4m3, affine_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) * F8E4M3_TO_FLOAT(mul) + F8E4M3_TO_FLOAT(add)))
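A hedged reading of the benchmark numbers above: the per-iteration times are all roughly 5.4 to 5.6 µs, so most of the GiB/s gap is simply element width (1 byte for fp8 vs 4 for f32), with the extra conversions showing up as the ~3.8% time regression. The arithmetic:

```rust
fn main() {
    // Figures taken from the cuda_affine benchmark output above.
    let (f32_time_us, f32_gibs) = (5.53_f64, 705.74_f64);
    let (fp8_time_us, fp8_gibs) = (5.62_f64, 173.72_f64);

    // Per-iteration time is nearly identical across the two dtypes...
    assert!((fp8_time_us / f32_time_us - 1.0).abs() < 0.05);

    // ...so the GiB/s ratio lands close to the 4x element-size ratio.
    let ratio = f32_gibs / fp8_gibs;
    assert!((ratio - 4.0).abs() < 0.2);
    println!("throughput ratio = {ratio:.2}");
}
```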

@zackangelo zackangelo marked this pull request as ready for review June 21, 2025 00:40
@zackangelo
Contributor Author

@LaurentMazare let me know if this directionally looks good, happy to make any changes to the approach if needed.

@greenrazer
Contributor

I fixed a couple of things to get the CIs to pass, but besides that it looks good to me.

@zackangelo
Contributor Author

Thanks for taking a look @greenrazer!

If there's any hesitance to merge as-is, would putting it behind a feature help? We'd probably have to leave the kernel additions but could gate all of the Rust code.

@EricLBuehler
Member

@zackangelo can you confirm that the CUDA build works on CC > 8 and CC < 8 (i.e. maintaining compatibility)?

@zackangelo
Contributor Author

@EricLBuehler I tested an earlier build, but it would probably be worth getting some time on an A100 and verifying again. I'll see if I can get around to that today or tomorrow.

@fchengjin

Is there still a plan to merge this MR?

@greenrazer
Contributor

Is there still a plan to merge this MR?

Yes, we just need to get some time on a CUDA machine to test that it works with CC > 8 and CC < 8.

@zackangelo
Contributor Author

Some testing on an A100:

ubuntu@104-171-202-244:~/candle/candle-core$ nvidia-smi
Sun Aug  3 23:16:38 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.148.08             Driver Version: 570.148.08     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100-PCIE-40GB          On  |   00000000:06:00.0 Off |                    0 |
| N/A   40C    P0             39W /  250W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
ubuntu@104-171-202-244:~/candle/candle-core$ cargo test -F cuda
    Finished `test` profile [unoptimized + debuginfo] target(s) in 0.17s
     Running unittests src/lib.rs (/home/ubuntu/candle/target/debug/deps/candle_core-6091b884a3915faa)

running 9 tests
test npy::tests::parse ... ok
test shape::tests::test_from_tuple ... ok
test safetensors::tests::save_single_tensor ... ok
test safetensors::tests::save_load_multiple_tensors ... ok
test shape::tests::stride ... ok
test quantized::cuda::test::cuda_mmv_q8_1 ... ok
test quantized::cuda::test::cuda_mm_q8_1_pad ... ok
test quantized::cuda::test::cuda_mm_q8_1 ... ok
test quantized::cuda::test::cuda_quantize_q8_1 ... ok

test result: ok. 9 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.39s

     Running tests/conv_tests.rs (/home/ubuntu/candle/target/debug/deps/conv_tests-5a39be9c008f8b55)

running 14 tests
test conv1d_small_cpu ... ok
test conv2d_non_square_cpu ... ok
test conv2d_smaller_cpu ... ok
test conv2d_small_cpu ... ok
test conv1d_cpu ... ok
test conv2d_cpu ... ok
test conv2d_grad_cpu ... ok
test conv1d_small_gpu ... ok
test conv2d_small_gpu ... ok
test conv2d_non_square_gpu ... ok
test conv2d_smaller_gpu ... ok
test conv2d_gpu ... ok
test conv1d_gpu ... ok
test conv2d_grad_gpu ... ok

test result: ok. 14 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.42s

     Running tests/custom_op_tests.rs (/home/ubuntu/candle/target/debug/deps/custom_op_tests-838a352a4d30c5d0)

running 4 tests
test custom_op1_no_backward ... ok
test inplace_op1 ... ok
test custom_op1_with_backward ... ok
test ug_op ... ok

test result: ok. 4 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.30s

     Running tests/display_tests.rs (/home/ubuntu/candle/target/debug/deps/display_tests-8068dcaf01c57fd4)

running 3 tests
test display_scalar ... ok
test display_vector ... ok
test display_multi_dim ... ok

test result: ok. 3 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

     Running tests/grad_tests.rs (/home/ubuntu/candle/target/debug/deps/grad_tests-1f870595326eb9ad)

running 13 tests
test simple_grad_cpu ... ok
test binary_grad_cpu ... ok
test sum_grad_cpu ... ok
test test_flip_backprop ... ok
test unary_grad_cpu ... ok
test grad_descent_cpu ... ok
test matmul_grad_cpu ... ok
test simple_grad_gpu ... ok
test matmul_grad_gpu ... ok
test sum_grad_gpu ... ok
test binary_grad_gpu ... ok
test unary_grad_gpu ... ok
test grad_descent_gpu ... ok

test result: ok. 13 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.50s

     Running tests/indexing_tests.rs (/home/ubuntu/candle/target/debug/deps/indexing_tests-32cf783a8ef8fd5a)

running 4 tests
test index_3d ... ok
test integer_index ... ok
test range_index ... ok
test slice_assign ... ok

test result: ok. 4 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

     Running tests/layout_tests.rs (/home/ubuntu/candle/target/debug/deps/layout_tests-1d7c191115df5104)

running 3 tests
test contiguous_cpu ... ok
test strided_blocks ... ok
test contiguous_gpu ... ok

test result: ok. 3 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.28s

     Running tests/matmul_tests.rs (/home/ubuntu/candle/target/debug/deps/matmul_tests-205405f0287ae4fb)

running 12 tests
test matmul_bf16_cpu ... ok
test squeeze_mm_cpu ... ok
test tensor_dot ... ok
test tensor_mv ... ok
test mm_layout_cpu ... ok
test matmul_cpu ... ok
test broadcast_matmul_cpu ... ok
test squeeze_mm_gpu ... ok
test mm_layout_gpu ... ok
test matmul_gpu ... ok
test broadcast_matmul_gpu ... ok
test matmul_bf16_gpu ... ok

test result: ok. 12 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.53s

     Running tests/pool_tests.rs (/home/ubuntu/candle/target/debug/deps/pool_tests-88b96749237816b6)

running 8 tests
test avg_pool2d_cpu ... ok
test avg_pool2d_pytorch_cpu ... ok
test max_pool2d_cpu ... ok
test upsample_nearest2d_cpu ... ok
test avg_pool2d_gpu ... ok
test upsample_nearest2d_gpu ... ok
test avg_pool2d_pytorch_gpu ... ok
test max_pool2d_gpu ... ok

test result: ok. 8 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.28s

     Running tests/pth_tests.rs (/home/ubuntu/candle/target/debug/deps/pth_tests-7f3d11adb4c7657c)

running 3 tests
test test_pth ... ok
test test_pth_fortran_congiguous ... ok
test test_pth_with_key ... ok

test result: ok. 3 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

     Running tests/quantized_tests.rs (/home/ubuntu/candle/target/debug/deps/quantized_tests-41d286d4e0aad6fe)

running 53 tests
test quantize_q4_0_cpu ... ok
test quantize_q4_1_cpu ... ok
test quantize_q5_0_cpu ... ok
test quantize_q3k_cpu ... ok
test quantize_q5_1_cpu ... ok
test quantize_q2k_cpu ... ok
test quantize_q8k_cpu ... ok
test quantize_q6k_cpu ... ok
test quantize_q4k_cpu ... ok
test quantize_q5k_cpu ... ok
test quantized_matmul_q5_0_cpu ... ok
test quantized_matmul_q4_0_cpu ... ok
test quantized_matmul_q4_1_cpu ... ok
test qmm_n_cpu ... ok
test quantized_matmul_q3k_cpu ... ok
test quantized_matmul_q4k_cpu ... ok
test qmm_b_cpu ... ok
test qmm_cpu ... ok
test quantized_matmul_q2k_cpu ... ok
test quantized_matmul_q5_1_cpu ... ok
test quantized_matmul_q6k_cpu ... ok
test quantized_matmul_q5k_cpu ... ok
test quantized_matmul_q8_0_cpu ... ok
test quantized_matmul_q8k ... ok
test quantized_matmul_q3k ... ok
test quantized_matmul_q2k ... ok
test quantized_matmul_q4k ... ok
test quantized_mm ... ok
test quantized_matmul_q6k ... ok
test quantized_matmul_q5k ... ok
test quantized_matmul_q2k_cuda ... ok
test quantized_matmul_q5_1_cuda ... ok
test quantize_q5_0_cuda ... ok
test quantized_matmul_q5_0_cuda ... ok
test quantized_matmul_q3k_cuda ... ok
test quantized_matmul_q8_0_cuda ... ok
test quantized_matmul_q6k_cuda ... ok
test quantized_matmul_q4_1_cuda ... ok
test quantized_matmul_q4k_cuda ... ok
test quantized_matmul_q5k_cuda ... ok
test quantized_matmul_q4_0_cuda ... ok
test qmm_n_cuda ... ok
test qmm_cuda ... ok
test quantize_q4_1_cuda ... ok
test quantize_q5_1_cuda ... ok
test quantize_q6k_cuda ... ok
test quantize_q3k_cuda ... ok
test quantize_q5k_cuda ... ok
test quantize_q4_0_cuda ... ok
test quantize_q2k_cuda ... ok
test qmm_b_cuda ... ok
test quantize_q4k_cuda ... ok
test quantize_q8k_cuda ... ok

test result: ok. 53 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 1.68s

     Running tests/serialization_tests.rs (/home/ubuntu/candle/target/debug/deps/serialization_tests-588935abfd328fa2)

running 3 tests
test npy ... ok
test npz ... ok
test safetensors ... ok

test result: ok. 3 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

     Running tests/tensor_tests.rs (/home/ubuntu/candle/target/debug/deps/tensor_tests-01ea15803db3b4a3)

running 77 tests
test arange_cpu ... ok
test add_mul_cpu ... ok
test binary_op_cpu ... ok
test broadcast_cpu ... ok
test broadcasting_cpu ... ok
test cat_cpu ... ok
test clamp_cpu ... ok
test cmp_cpu ... ok
test cs_cpu ... ok
test embeddings_cpu ... ok
test full_cpu ... ok
test gather_cpu ... ok
test i64_abs ... ok
test index_add_cpu ... ok
test argmax_cpu ... ok
test index_select_fail ... ok
test index_select_cpu ... ok
test argmin_cpu ... ok
test log_sum_exp ... ok
test narrow_cpu ... ok
test ones_cpu ... ok
test asort_cpu ... ok
test pow ... ok
test pad_with_same ... ok
test randn_hasneg ... ok
test slice_scatter_cpu ... ok
test scatter_cpu ... ok
test randn_cpu ... ok
test ss_cpu ... ok
test max_cpu ... ok
test min_cpu ... ok
test tensor_2d_cpu ... ok
test tensor_new ... ok
test tensor_norm ... ok
test test_flip_1d ... ok
test test_flip_2d ... ok
test test_flip_3d_channels ... ok
test transpose_cpu ... ok
test sum_cpu ... ok
test tril_triu_eye ... ok
test cumsum ... ok
test unary_op_cpu ... ok
test var_cpu ... ok
test tensor_2d_gpu ... ok
test zero_dim_cpu ... ok
test full_gpu ... ok
test zeros_cpu ... ok
test broadcast_gpu ... ok
test add_mul_gpu ... ok
test arange_gpu ... ok
test ones_gpu ... ok
test index_add_gpu ... ok
test gather_gpu ... ok
test zeros_gpu ... ok
test embeddings_gpu ... ok
test transpose_gpu ... ok
test asort_gpu ... ok
test cs_gpu ... ok
test narrow_gpu ... ok
test clamp_gpu ... ok
test var_gpu ... ok
test zero_dim_gpu ... ok
test cmp_gpu ... ok
test binary_op_gpu ... ok
test slice_scatter_gpu ... ok
test scatter_gpu ... ok
test ss_gpu ... ok
test broadcasting_gpu ... ok
test unary_op_gpu ... ok
test cat_gpu ... ok
test index_select_gpu ... ok
test max_gpu ... ok
test argmin_gpu ... ok
test min_gpu ... ok
test argmax_gpu ... ok
test sum_gpu ... ok
test randn_gpu ... ok

test result: ok. 77 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.76s

   Doc-tests candle_core

running 41 tests
test candle-core/src/tensor.rs - tensor::Tensor::arange (line 427) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::flip (line 2739) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::reshape (line 2337) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::min_all (line 1936) ... ok
test candle-core/src/indexer.rs - indexer::Tensor::index (line 9) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::ones (line 196) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::arange_step (line 440) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::embedding (line 1418) ... ok
test candle-core/src/indexer.rs - indexer::Tensor::i (line 217) ... ok
test candle-core/src/indexer.rs - indexer::Tensor::i (line 144) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::dot (line 1245) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::get_on_dim (line 2052) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::from_vec (line 488) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::ones_like (line 221) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::get (line 2032) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::norm (line 1270) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::affine (line 718) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::from_iter (line 409) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::narrow (line 795) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::from_slice (line 508) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::full (line 386) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::flatten_all (line 2019) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::max_all (line 1918) ... ok
test candle-core/src/tensor.rs - tensor::Tensor (line 57) ... ok
test candle-core/src/indexer.rs - indexer::Tensor::i (line 180) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::meshgrid (line 659) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::mv (line 1291) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::mean_keepdim (line 980) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::permute (line 2115) ... ok
test candle-core/src/lib.rs - (line 3) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::roll (line 920) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::zeros_like (line 262) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::squeeze (line 2384) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::t (line 2071) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::sum_all (line 1954) ... ok
test candle-core/src/tensor_cat.rs - tensor_cat::Tensor::cat (line 9) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::stack (line 2459) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::to_dtype (line 2274) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::unsqueeze (line 2422) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::sum_keepdim (line 952) ... ok
test candle-core/src/tensor.rs - tensor::Tensor::zeros (line 248) ... ok

test result: ok. 41 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 5.56s

ubuntu@104-171-202-244:~/candle/candle-core$ cd ..
ubuntu@104-171-202-244:~/candle$ cd candle-examples/
ubuntu@104-171-202-244:~/candle/candle-examples$ cargo run -F cuda --example llama
   Compiling candle-core v0.9.1 (/home/ubuntu/candle/candle-core)
   Compiling candle-nn v0.9.1 (/home/ubuntu/candle/candle-nn)
   Compiling candle-transformers v0.9.1 (/home/ubuntu/candle/candle-transformers)
   Compiling candle-examples v0.9.1 (/home/ubuntu/candle/candle-examples)
    Finished `dev` profile [unoptimized + debuginfo] target(s) in 33.33s
     Running `/home/ubuntu/candle/target/debug/examples/llama`
loading the model weights from meta-llama/Meta-Llama-3-8B
starting the inference loop
My favorite theorem is 1+2+3… = -1/12

In May I gave a talk about the Riemann zeta function at my local science cafe. The talk generated quite a lot of interest, and so I made an attempt to write up the material in article form for publication in my university’s magazine. The resulting piece appeared recently, and can be found here:

http://www.folkes.utexas.edu/~zetafunction.pdf

For those who wish to delve further into this topic, I highly recommend Terry Tao’s excellent blog post on the topic: http://terrytao.wordpress.com/2010/04/10/the-euler-maclaurin-formula-bernoulli-numbers-the-zeta-function-and-real-variable-analytic-continuation/

159 tokens generated (40.56765183425539 token/s)

@greenrazer
Contributor

Thanks @zackangelo for making this happen and @EricLBuehler for starting it off!

@greenrazer greenrazer merged commit af5a69e into huggingface:main Aug 4, 2025
20 checks passed
@metalmatze

During development I'm running on an NVIDIA 1060 6GB, and I just git bisected these failures to this change:

Processing batch cdb4bab9-f4b6-4031-97d0-d4b1afd3c896 with 1 items
✗ Failed to extract features for item 6722241: DriverError(CUDA_ERROR_INVALID_PTX, "a PTX JIT compilation failed")
✗ Failed to extract features for item 6722240: DriverError(CUDA_ERROR_INVALID_PTX, "a PTX JIT compilation failed")
✗ Failed to extract features for item 6722239: DriverError(CUDA_ERROR_INVALID_PTX, "a PTX JIT compilation failed")

Once I knew the commit, I told Claude Code to check the Cargo.toml for the dependencies, and it said to revert the update. Then I found this PR.

Is this a regression or is my 1060 not supported going forward? Do I maybe need to disable float8 somehow?

Wed Aug 13 20:51:20 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.64.05              Driver Version: 575.64.05      CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce GTX 1060 6GB    Off |   00000000:25:00.0 Off |                  N/A |
|  0%   39C    P8              7W /  200W |       5MiB /   6144MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A            1836      G   Hyprland                                  1MiB |
+-----------------------------------------------------------------------------------------+

@zackangelo
Contributor Author

@metalmatze thanks for reporting this, do you have the ability to test a branch if I give you one?

@zackangelo
Contributor Author

@metalmatze Do you happen to know the compute capability version for that GPU? It's not listed on Nvidia's website.

You can get nvidia-smi to give it to you:

$ nvidia-smi --query-gpu=name,compute_cap --format=csv
name, compute_cap
NVIDIA GeForce RTX 5070 Ti, 12.0

@metalmatze

$ nvidia-smi --query-gpu=name,compute_cap --format=csv

name, compute_cap
NVIDIA GeForce GTX 1060 6GB, 6.1

@metalmatze

@metalmatze thanks for reporting this, do you have the ability to test a branch if I give you one?

Let me know what to run and I'll happily try to!

@zackangelo
Contributor Author

@metalmatze can you try building your code against this branch?

https://github.com/zackangelo/candle/tree/fix_1090_cuda_err
