Skip to content

Commit 94d7361

Browse files
try CI fix
1 parent e70fa55 commit 94d7361

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

ggml/src/ggml-cuda/mmf.cu

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,11 @@ static __global__ void mul_mat_f(
7272
const int j = j0 + itB*tile_B::I;
7373

7474
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
75-
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
75+
if constexpr (std::is_same_v<T, half2>) {
76+
tile_xy[j0*tile_k_padded + threadIdx.x] = make_half2(tmp.x, tmp.y);
77+
} else {
78+
tile_xy[j0*tile_k_padded + threadIdx.x] = make_bfloat162(tmp.x, tmp.y);
79+
}
7680
}
7781
} else {
7882
static_assert(std::is_same_v<T, void>, "unsupported type");

ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,8 @@
199199
#define __has_builtin(x) 0
200200
#endif
201201

202-
typedef hip_bfloat16 nv_bfloat16;
202+
typedef __hip_bfloat16 nv_bfloat16;
203+
typedef __hip_bfloat162 nv_bfloat162;
203204

204205
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
205206
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));

ggml/src/ggml-cuda/vendors/musa.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,4 +137,5 @@
137137
#define cudaStreamEndCapture musaStreamEndCapture
138138
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
139139

140-
typedef mt_bfloat16 nv_bfloat16;
140+
typedef __mt_bfloat16 nv_bfloat16;
141+
typedef __mt_bfloat162 nv_bfloat162;

0 commit comments

Comments
 (0)