File tree Expand file tree Collapse file tree 3 files changed +9
-3
lines changed Expand file tree Collapse file tree 3 files changed +9
-3
lines changed Original file line number Diff line number Diff line change @@ -72,7 +72,11 @@ static __global__ void mul_mat_f(
7272 const int j = j0 + itB*tile_B::I;
7373
7474 const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2 (0 .0f , 0 .0f );
75- tile_xy[j0*tile_k_padded + threadIdx .x ] = {tmp.x , tmp.y };
75+ if constexpr (std::is_same_v<T, half2>) {
76+ tile_xy[j0*tile_k_padded + threadIdx .x ] = make_half2 (tmp.x , tmp.y );
77+ } else {
78+ tile_xy[j0*tile_k_padded + threadIdx .x ] = make_bfloat162 (tmp.x , tmp.y );
79+ }
7680 }
7781 } else {
7882 static_assert (std::is_same_v<T, void >, " unsupported type" );
Original file line number Diff line number Diff line change 199199 #define __has_builtin (x ) 0
200200#endif
201201
202- typedef hip_bfloat16 nv_bfloat16;
202+ typedef __hip_bfloat16 nv_bfloat16;
203+ typedef __hip_bfloat162 nv_bfloat162;
203204
204205typedef int8_t int8x4_t __attribute__ ((ext_vector_type(4 )));
205206typedef uint8_t uint8x4_t __attribute__ ((ext_vector_type(4 )));
Original file line number Diff line number Diff line change 137137#define cudaStreamEndCapture musaStreamEndCapture
138138#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
139139
140- typedef mt_bfloat16 nv_bfloat16 ;
140+ typedef __mt_bfloat16 nv_bfloat16 ;
141+ typedef __mt_bfloat162 nv_bfloat162 ;
You can’t perform that action at this time.
0 commit comments