Conversation

Current test failures:
```cuda
// The 16-bit __half floating-point version of atomicAdd() is only supported by devices of compute capability 7.x and higher.
// Solution adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119
__device__ __half atomicAdd(__half *address, __half val) {
//__device__ __half atomicAdd(__half *address, __half val) {
```
unsure why just this signature was present
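For context, the linked cutorch code implements the pre-7.x fallback as a compare-and-swap loop on the aligned 32-bit word that contains the half. A rough sketch of that technique (the function name and details below are illustrative, paraphrased from the linked file rather than copied from this PR's kernels):

```cuda
#include <cuda_fp16.h>

// CAS-based __half atomicAdd fallback for devices below compute capability 7.0.
// 16-bit atomics are unavailable there, so update the 32-bit word that contains
// the target half and splice the new 16 bits back in.
__device__ __half atomic_add_half_fallback(__half *address, __half val) {
  unsigned int *base = (unsigned int *)((char *)address - ((size_t)address & 2));
  bool upper = ((size_t)address & 2) != 0;  // which half of the word we own
  unsigned int old = *base;
  unsigned int assumed;

  do {
    assumed = old;
    unsigned short cur_bits = upper ? (unsigned short)(assumed >> 16)
                                    : (unsigned short)(assumed & 0xffffu);
    // Do the addition in float so the fallback does not rely on __hadd support.
    float sum = __half2float(__ushort_as_half(cur_bits)) + __half2float(val);
    unsigned short new_bits = __half_as_ushort(__float2half(sum));
    unsigned int updated =
        upper ? ((assumed & 0x0000ffffu) | ((unsigned int)new_bits << 16))
              : ((assumed & 0xffff0000u) | (unsigned int)new_bits);
    old = atomicCAS(base, assumed, updated);
  } while (assumed != old);

  // Match atomicAdd semantics: return the value stored before the update.
  unsigned short prev_bits = upper ? (unsigned short)(old >> 16)
                                   : (unsigned short)(old & 0xffffu);
  return __ushort_as_half(prev_bits);
}
```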
```cuda
template <>
__host__ __device__
constexpr int32_t max_value<int32_t>() {
```
these symbols were missing when cuda bindgen ran for some reason
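For reference, this is the usual primary-template-plus-specializations pattern; a minimal sketch of what the added symbols look like (the primary template and the function body here are assumptions, since the diff only shows the specialization header):

```cuda
#include <cstdint>

// Primary template: each supported dtype provides its own specialization.
template <typename T>
__host__ __device__ constexpr T max_value();

// The specialization the diff adds, with the obvious body filled in.
template <>
__host__ __device__ constexpr int32_t max_value<int32_t>() {
  return 2147483647;  // INT32_MAX
}
```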
```cuda
WHERE_OP(__nv_bfloat16, uint32_t, where_u32_bf16)
WHERE_OP(__nv_bfloat16, uint8_t, where_u8_bf16)
```

```cuda
WHERE_OP(__nv_fp8_e4m3, int16_t, where_i16_fp8_e4m3)
```
__CUDA_ARCH__ guard should be 890 here
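To make that concrete, a sketch of the guarded instantiation; the WHERE_OP macro below is a simplified stand-in (the real candle macro handles strided layouts), included only so the snippet is self-contained:

```cuda
#include <cstdint>
#include <cuda_fp8.h>

// Simplified stand-in for candle's WHERE_OP macro: generates a kernel that
// selects elementwise between two tensors based on an id/mask tensor.
#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME)                              \
  extern "C" __global__ void FN_NAME(const ID_TYPENAME *ids,                  \
                                     const TYPENAME *t, const TYPENAME *f,    \
                                     TYPENAME *out, const size_t n) {         \
    size_t i = blockIdx.x * blockDim.x + threadIdx.x;                         \
    if (i < n) out[i] = ids[i] ? t[i] : f[i];                                 \
  }

// Per the review comment: the fp8 instantiation should be guarded at 890
// (compute capability 8.9) rather than a lower arch value.
#if __CUDA_ARCH__ >= 890
WHERE_OP(__nv_fp8_e4m3, int16_t, where_i16_fp8_e4m3)
#endif
```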
fp8 seems somewhat slower than I would expect in the candle benchmark harness (this is on a GH100):

Probably because we're double-casting from fp8->half->f32?

```cuda
#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3))

AFFINE_OP(__nv_fp8_e4m3, affine_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) * F8E4M3_TO_FLOAT(mul) + F8E4M3_TO_FLOAT(add)))
```
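Writing out what that macro pair does per element makes the conversion cost visible; this is just the AFFINE_OP body expanded by hand (the helper name is illustrative, not from the PR):

```cuda
#include <cuda_fp16.h>
#include <cuda_fp8.h>

#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3))

// Hand-expanded AFFINE_OP body for __nv_fp8_e4m3: each operand is widened
// fp8 -> half -> float and the result is narrowed back to fp8, i.e. three
// two-step up-conversions plus one down-conversion per element.
__device__ __nv_fp8_e4m3 affine_one(__nv_fp8_e4m3 x, __nv_fp8_e4m3 mul,
                                    __nv_fp8_e4m3 add) {
  float xf = F8E4M3_TO_FLOAT(x);
  float mf = F8E4M3_TO_FLOAT(mul);
  float af = F8E4M3_TO_FLOAT(add);
  return __nv_fp8_e4m3(xf * mf + af);
}
```

If the casts are indeed the bottleneck, keeping the arithmetic in __half would drop one conversion step per operand, at some cost in precision.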
@LaurentMazare let me know if this directionally looks good, happy to make any changes to the approach if needed.

I fixed a couple things for the CIs to pass, but besides that it looks good to me.

Thanks for taking a look @greenrazer! If there's any hesitance to merge as-is, would putting it behind a feature help? We'd probably have to leave the kernel additions but could gate all of the Rust code.

@zackangelo can you confirm that the CUDA build works on CC > 8 and CC < 8 (i.e. maintaining compatibility)?

@EricLBuehler I tested an earlier build, but it would probably be worth getting some time on an A100 and verifying again. I'll see if I can get around to that today or tomorrow.

Is there still a plan to merge this MR?

Yes, just need to get some time on a CUDA machine to test that it works with CC > 8 and CC < 8.

Some testing on an A100:

Thanks @zackangelo for making this happen and @EricLBuehler for starting it off!

During development I'm running on an NVIDIA 1060 6GB. Once I knew the commit, I told Claude Code to check the Cargo.toml for the dependencies and it said to revert the update. Then I found this PR. Is this a regression, or is my 1060 not supported going forward? Do I maybe need to disable float8 somehow?

@metalmatze thanks for reporting this, do you have the ability to test a branch if I give you one?

@metalmatze Do you happen to know the compute capability version for that GPU? It's not listed on Nvidia's website. You can get nvidia-smi to give it to you:

Let me know what to run and I'll happily try to!

@metalmatze can you try building your code against this branch?

Plucked from @EricLBuehler's work in #2745.
This implements fp8 operations where they are straightforward. Many fp8 ops can't be implemented because they require a scale tensor alongside the main tensor to compensate for fp8's limited dynamic range (e.g. matmul).
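A hedged illustration of that range limitation (generic per-tensor scaling, not candle's API): e4m3 saturates around 448 and carries only a few mantissa bits, so values of ordinary matmul magnitude can only be stored as value/scale in fp8, with the scale kept alongside in higher precision.

```cuda
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_fp8.h>

// Read an fp8 e4m3 value back as float (same double-cast used in the kernels).
static float fp8_to_float(__nv_fp8_e4m3 v) {
  return __half2float(__nv_cvt_fp8_to_halfraw(v.__x, __NV_E4M3));
}

int main() {
  float value = 30000.0f;               // a magnitude a matmul output can easily reach
  __nv_fp8_e4m3 naive(value);           // saturates: reads back as 448
  float scale = 256.0f;                 // per-tensor scale, kept in f32
  __nv_fp8_e4m3 scaled(value / scale);  // 117.1875 is in range, though coarsely quantized
  printf("naive   : %g\n", fp8_to_float(naive));
  printf("rescaled: %g\n", fp8_to_float(scaled) * scale);
  return 0;
}
```

That scale tensor is exactly what ops like matmul would need carried alongside the data, which is why they are left out of this PR.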