Commit d0c8e24

fp8: fix failing tests
1 parent: 8e7e581

File tree

3 files changed: +32 -4 lines changed

candle-core/tests/tensor_tests.rs
candle-kernels/src/compatibility.cuh
candle-kernels/src/indexing.cu

candle-core/tests/tensor_tests.rs

Lines changed: 15 additions & 0 deletions
@@ -126,6 +126,21 @@ fn arange(device: &Device) -> Result<()> {
         Tensor::arange_step(5i64, 0i64, -1, device)?.to_vec1::<i64>()?,
         [5, 4, 3, 2, 1],
     );
+
+    assert_eq!(
+        Tensor::arange_step(
+            F8E4M3::from_f32(0.),
+            F8E4M3::from_f32(5.),
+            F8E4M3::from_f32(2.),
+            device
+        )?
+        .to_vec1::<F8E4M3>()?,
+        [
+            F8E4M3::from_f32(0.),
+            F8E4M3::from_f32(2.),
+            F8E4M3::from_f32(4.),
+        ],
+    );
     Ok(())
 }
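The added assertion extends the existing arange test to the fp8 F8E4M3 type: Tensor::arange_step from 0 to 5 with step 2 should read back as [0, 2, 4] through to_vec1::<F8E4M3>(), mirroring the integer arange_step check just above it.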

candle-kernels/src/compatibility.cuh

Lines changed: 2 additions & 3 deletions
@@ -35,12 +35,11 @@ __device__ double atomicAdd(double* address, double val) {
 }
 #endif
 
-
 #if __CUDA_ARCH__ < 700
 // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomicadd
 // The 16-bit __half floating-point version of atomicAdd() is only supported by devices of compute capability 7.x and higher.
 // Solution adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119
-__device__ __half atomicAdd(__half *address, __half val) {
+//__device__ __half atomicAdd(__half *address, __half val) {
 // unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
 // unsigned int old = *address_as_ui;
 // unsigned int assumed;
@@ -56,7 +55,7 @@ __device__ __half atomicAdd(__half *address, __half val) {
 
 // } while (assumed != old);
 // return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff));
-}
+//}
 #endif
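With this change the pre-compute-capability-7.0 atomicAdd(__half*, __half) fallback in compatibility.cuh is fully commented out: its body was already disabled, and now the signature and closing brace are commented as well, so the whole block is inert. For reference only, a CAS-based fallback of the kind the original block was adapted from (the cutorch link above) looks roughly like the sketch below; this is an illustrative reconstruction, not candle's code, and atomic_add_half_fallback is a made-up name.

// Illustrative sketch only (following the cutorch pattern linked above), not candle's code:
// emulate a 16-bit atomic add on pre-sm_70 devices with a compare-and-swap loop over the
// 32-bit word that contains the target __half.
#include <cuda_fp16.h>

__device__ __half atomic_add_half_fallback(__half *address, __half val) {
    // Round the pointer down to the enclosing 4-byte word; atomicCAS works on 32-bit words.
    unsigned int *address_as_ui =
        (unsigned int *)((char *)address - ((size_t)address & 2));
    bool high_half = ((size_t)address & 2) != 0;
    unsigned int old = *address_as_ui;
    unsigned int assumed;
    do {
        assumed = old;
        // Extract the current 16-bit payload, add in float precision, pack it back in.
        unsigned short cur = high_half ? (unsigned short)(old >> 16)
                                       : (unsigned short)(old & 0xffff);
        float sum = __half2float(__ushort_as_half(cur)) + __half2float(val);
        unsigned short upd = __half_as_ushort(__float2half(sum));
        unsigned int word = high_half ? (old & 0x0000ffffu) | ((unsigned int)upd << 16)
                                      : (old & 0xffff0000u) | upd;
        old = atomicCAS(address_as_ui, assumed, word);
    } while (assumed != old);
    return __ushort_as_half(high_half ? (unsigned short)(old >> 16)
                                      : (unsigned short)(old & 0xffff));
}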
candle-kernels/src/indexing.cu

Lines changed: 15 additions & 1 deletion
@@ -25,6 +25,18 @@ constexpr uint8_t max_value<uint8_t>() {
     return 0xFFu;
 }
 
+template <>
+__host__ __device__
+constexpr int32_t max_value<int32_t>() {
+    return 0x7FFFFFFF;
+}
+
+template <>
+__host__ __device__
+constexpr int16_t max_value<int16_t>() {
+    return 0x7FFF;
+}
+
 template<typename T, typename I>
 __device__ void index_select(
     const size_t numel,
@@ -134,7 +146,7 @@ __device__ void index_add(
     }
 }
 
-#if __CUDA_ARCH__ >= 800
+#if __CUDA_ARCH__ >= 890
 #define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3))
 
 template<typename I>
@@ -311,7 +323,9 @@ SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16)
 S_OP(__nv_bfloat16, int64_t, s_i64_bf16)
 S_OP(__nv_bfloat16, uint32_t, s_u32_bf16)
 S_OP(__nv_bfloat16, uint8_t, s_u8_bf16)
+#endif
 
+#if __CUDA_ARCH__ >= 890
 IS_OP(__nv_fp8_e4m3, int16_t, is_i16_f8_e4m3)
 IS_OP(__nv_fp8_e4m3, int32_t, is_i32_f8_e4m3)
 IS_OP(__nv_fp8_e4m3, int64_t, is_i64_f8_e4m3)
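Two things change in indexing.cu: the guard in front of the fp8 kernels moves from __CUDA_ARCH__ >= 800 to >= 890 (presumably because the e4m3 path targets compute capability 8.9, i.e. Ada, and newer), and the fp8 instantiations get their own #if block so the bf16 S_OP/SA_OP instantiations above them no longer sit behind the stricter guard. The new max_value<int32_t> / max_value<int16_t> specializations fill in the i32/i16 index types alongside the existing ones. For orientation only, the sketch below expands what the existing F8E4M3_TO_FLOAT macro does; the function name is made up here and this is not the kernel source.

// Illustrative sketch, not candle's kernel: what the F8E4M3_TO_FLOAT macro expands to,
// written out as a function using the conversion helpers from cuda_fp8.h / cuda_fp16.h.
#include <cuda_fp8.h>
#include <cuda_fp16.h>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 890  // mirrors the guard used above
__device__ float f8_e4m3_to_float(__nv_fp8_e4m3 v) {
    // v.__x holds the raw 8-bit storage; __nv_cvt_fp8_to_halfraw decodes it as e4m3
    // into a __half_raw, and __half2float widens that to a regular float.
    __half_raw h = __nv_cvt_fp8_to_halfraw(v.__x, __NV_E4M3);
    return __half2float(h);
}
#endif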
