5
5
#include < cuda_fp16.h>
6
6
#endif
7
7
8
+ #include < ATen/ATen.h>
8
9
#include < stdint.h>
9
10
10
11
#include " common/data_type.h"
@@ -27,6 +28,7 @@ struct FloatVecTypeTrait {};
27
28
VEC_TYPE_TRAITS_SPECIALIZATION (T, 1 , T, typename T)
28
29
29
30
#if defined(COLOSSAL_WITH_CUDA)
31
+
30
32
VEC_TYPE_TRAITS_SPECIALIZATION (at::BFloat16, 1 , __nv_bfloat16)
31
33
VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 2 , __nv_bfloat162)
32
34
VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 4 , float2)
@@ -35,18 +37,19 @@ VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 1, half)
35
37
VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 2 , half2)
36
38
VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 4 , float2)
37
39
VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 8 , float4)
38
- VEC_TYPE_TRAITS_SPECIALIZATION(float , 2 , float2)
39
- VEC_TYPE_TRAITS_SPECIALIZATION(float , 4 , float4)
40
- VEC_TYPE_TRAITS_SPECIALIZATION(float , 8 , dtype::float8_)
41
- VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t , 2 , half)
42
- VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t , 4 , half2)
43
- VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t , 8 , float2)
40
+
41
+ VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t , 2 , uint16_t )
42
+ VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t , 4 , uint32_t )
43
+ VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t , 8 , uint2)
44
44
VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 2 , __nv_bfloat162);
45
45
VEC_TYPE_TRAITS_SPECIALIZATION (__nv_bfloat16, 4 , dtype::bfloat164);
46
46
VEC_TYPE_TRAITS_SPECIALIZATION (__nv_bfloat16, 8 , dtype::bfloat168);
47
47
VEC_TYPE_TRAITS_SPECIALIZATION (half, 2 , half2);
48
48
VEC_TYPE_TRAITS_SPECIALIZATION (half, 4 , dtype::half4);
49
49
VEC_TYPE_TRAITS_SPECIALIZATION (half, 8 , dtype::half8);
50
+ VEC_TYPE_TRAITS_SPECIALIZATION (float , 2 , float2)
51
+ VEC_TYPE_TRAITS_SPECIALIZATION(float , 4 , float4)
52
+ VEC_TYPE_TRAITS_SPECIALIZATION(float , 8 , dtype::float8_)
50
53
#endif /* defined(COLOSSAL_WITH_CUDA) */
51
54
52
55
#undef VEC_TYPE_TRAITS_SPECIALIZATION
0 commit comments