Skip to content

Commit dca354d

Browse files
Merge pull request #251 from menloresearch/update-dev-from-master-2025-09-16-00-32
Sync master with upstream release b6482
2 parents f453acf + 3d4053f commit dca354d

File tree

9 files changed

+125
-72
lines changed

9 files changed

+125
-72
lines changed

.devops/rocm.Dockerfile

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,11 @@ FROM ${BASE_ROCM_DEV_CONTAINER} AS build
1717
# gfx906 is deprecated
1818
#check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.4.1/reference/system-requirements.html
1919

20-
ARG ROCM_DOCKER_ARCH='gfx803,gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1010,gfx1030,gfx1032,gfx1100,gfx1101,gfx1102,gfx1200,gfx1201'
21-
#ARG ROCM_DOCKER_ARCH=gfx1100
20+
ARG ROCM_DOCKER_ARCH='gfx803;gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1010;gfx1030;gfx1032;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx1151'
21+
#ARG ROCM_DOCKER_ARCH='gfx1151'
2222

23-
# Set ROCm architectured
23+
# Set ROCm architectures
2424
ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH}
25-
# Enable ROCm
26-
# ENV CC=/opt/rocm/llvm/bin/clang
27-
# ENV CXX=/opt/rocm/llvm/bin/clang++
2825

2926
RUN apt-get update \
3027
&& apt-get install -y \
@@ -39,8 +36,16 @@ WORKDIR /app
3936

4037
COPY . .
4138

39+
RUN git clone https://github.com/rocm/rocwmma --branch develop --depth 1
40+
4241
RUN HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \
43-
cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=$ROCM_DOCKER_ARCH -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DCMAKE_BUILD_TYPE=Release -DLLAMA_BUILD_TESTS=OFF \
42+
cmake -S . -B build \
43+
-DGGML_HIP=ON \
44+
-DGGML_HIP_ROCWMMA_FATTN=ON \
45+
-DCMAKE_HIP_FLAGS="-I$(pwd)/rocwmma/library/include/" \
46+
-DAMDGPU_TARGETS="$ROCM_DOCKER_ARCH" \
47+
-DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON \
48+
-DCMAKE_BUILD_TYPE=Release -DLLAMA_BUILD_TESTS=OFF \
4449
&& cmake --build build --config Release -j$(nproc)
4550

4651
RUN mkdir -p /app/lib \

.github/workflows/release.yml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -530,15 +530,13 @@ jobs:
530530
runs-on: windows-2022
531531

532532
env:
533-
# The ROCm version must correspond to the version used in the HIP SDK.
534-
ROCM_VERSION: "6.4.2"
535533
HIPSDK_INSTALLER_VERSION: "25.Q3"
536534

537535
strategy:
538536
matrix:
539537
include:
540538
- name: "radeon"
541-
gpu_targets: "gfx1200;gfx1201;gfx1100;gfx1101;gfx1102;gfx1030;gfx1031;gfx1032"
539+
gpu_targets: "gfx1151;gfx1200;gfx1201;gfx1100;gfx1101;gfx1102;gfx1030;gfx1031;gfx1032"
542540

543541
steps:
544542
- name: Clone
@@ -548,7 +546,7 @@ jobs:
548546
- name: Clone rocWMMA repository
549547
id: clone_rocwmma
550548
run: |
551-
git clone https://github.com/rocm/rocwmma --branch rocm-${{ env.ROCM_VERSION }} --depth 1
549+
git clone https://github.com/rocm/rocwmma --branch develop --depth 1
552550
553551
- name: Cache ROCm Installation
554552
id: cache-rocm

ggml/src/ggml-cuda/im2col.cu

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,14 @@ static __global__ void im2col_3d_kernel(
122122
int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW,
123123
int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW,
124124
int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH,
125+
int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
125126
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) {
126127
const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
127128
if (i >= IC_KD_KH_KW) {
128129
return;
129130
}
131+
GGML_UNUSED(N); GGML_UNUSED(OC); GGML_UNUSED(OH_OW); GGML_UNUSED(OD); GGML_UNUSED(OW); GGML_UNUSED(KD); GGML_UNUSED(KH);
132+
GGML_UNUSED(ID_IH_IW); GGML_UNUSED(IH_IW); GGML_UNUSED(IC_ID_IH_IW); GGML_UNUSED(OW_KD_KH_KW);
130133

131134
const int64_t iic = i / KD_KH_KW;
132135
const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW;
@@ -148,7 +151,7 @@ static __global__ void im2col_3d_kernel(
148151
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
149152
dst[offset_dst] = 0.0f;
150153
} else {
151-
const int64_t offset_src = in*IC_ID_IH_IW + iic*ID_IH_IW + iid*IH_IW + iih*IW + iiw;
154+
const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x);
152155
dst[offset_dst] = src[offset_src];
153156
}
154157
}
@@ -159,6 +162,7 @@ template <typename T>
159162
static void im2col_3d_cuda(const float * src, T* dst,
160163
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
161164
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
165+
int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
162166
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
163167
const int64_t OH_OW = OH*OW;
164168
const int64_t KD_KH_KW = KD*KH*KW;
@@ -179,23 +183,30 @@ static void im2col_3d_cuda(const float * src, T* dst,
179183
OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW,
180184
IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW,
181185
OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH,
186+
stride_q, stride_z, stride_y, stride_x,
182187
s0, s1, s2, p0, p1, p2, d0, d1, d2);
183188
}
184189

185190
static void im2col_3d_cuda_f16(const float * src, half * dst,
186191
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
187192
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
193+
int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
188194
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
189195

190-
im2col_3d_cuda<half>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
196+
im2col_3d_cuda<half>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
197+
stride_q, stride_z, stride_y, stride_x,
198+
s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
191199
}
192200

193201
static void im2col_3d_cuda_f32(const float * src, float * dst,
194202
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
195203
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
204+
int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
196205
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
197206

198-
im2col_3d_cuda<float>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
207+
im2col_3d_cuda<float>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
208+
stride_q, stride_z, stride_y, stride_x,
209+
s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
199210
}
200211

201212
void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -235,9 +246,19 @@ void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
235246
const int64_t OH = ne2;
236247
const int64_t OW = ne1;
237248

249+
const size_t es = ggml_element_size(src1);
250+
const int64_t stride_x = src1->nb[0] / es;
251+
const int64_t stride_y = src1->nb[1] / es;
252+
const int64_t stride_z = src1->nb[2] / es;
253+
const int64_t stride_q = src1->nb[3] / es;
254+
238255
if(dst->type == GGML_TYPE_F16) {
239-
im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
256+
im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
257+
stride_q, stride_z, stride_y, stride_x,
258+
s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
240259
} else {
241-
im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
260+
im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
261+
stride_q, stride_z, stride_y, stride_x,
262+
s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
242263
}
243264
}

ggml/src/ggml-cuda/mmf.cuh

Lines changed: 23 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -57,31 +57,33 @@ static __global__ void mul_mat_f(
5757
T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
5858

5959
if constexpr (has_ids) {
60-
__shared__ int has_any;
61-
if (threadIdx.y == 0) {
62-
int local_has_any = 0;
63-
for (int j = threadIdx.x; j < cols_per_block; j += warp_size) {
64-
int slot = -1;
65-
for (int k = 0; k < nchannels_dst; ++k) {
66-
const int idv = ids[j*stride_row_id + k*stride_col_id];
67-
if (idv == expert_idx) {
68-
slot = k;
69-
break;
70-
}
71-
}
72-
if (j < cols_per_block) {
73-
local_has_any |= (slot >= 0);
74-
slot_map[j] = slot;
60+
int found = 0;
61+
62+
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
63+
const int j = j0 + threadIdx.y;
64+
const int32_t * __restrict__ id_row = ids + j*stride_row_id;
65+
66+
if (threadIdx.x == 0) {
67+
slot_map[j] = -1;
68+
}
69+
70+
for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) {
71+
int match = id_row[k*stride_col_id] == expert_idx;
72+
73+
if (match) {
74+
slot_map[j] = k;
75+
found = 1;
76+
break;
7577
}
7678
}
77-
has_any = warp_reduce_any(local_has_any);
7879
}
79-
__syncthreads();
80-
if (has_any == 0) {
80+
81+
if (!__syncthreads_or(found)) {
8182
return;
8283
}
8384
}
8485

86+
8587
for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
8688
tile_A A[ntA][warp_size / tile_A::J];
8789
#pragma unroll
@@ -106,14 +108,7 @@ static __global__ void mul_mat_f(
106108
if constexpr (!has_ids) {
107109
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
108110
} else {
109-
float val = 0.0f;
110-
if (j < cols_per_block) {
111-
const int slot = slot_map[j];
112-
if (slot >= 0) {
113-
val = y[slot*stride_channel_y + j*stride_col_y + col];
114-
}
115-
}
116-
tile_xy[j0*tile_k_padded + threadIdx.x] = val;
111+
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
117112
}
118113
}
119114
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
@@ -125,14 +120,7 @@ static __global__ void mul_mat_f(
125120
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
126121
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
127122
} else {
128-
float2 tmp = make_float2(0.0f, 0.0f);
129-
if (j < cols_per_block) {
130-
const int slot = slot_map[j];
131-
if (slot >= 0) {
132-
const float2 * y2_slot = (const float2 *)(y + slot*stride_channel_y);
133-
tmp = y2_slot[j*stride_col_y + col];
134-
}
135-
}
123+
float2 tmp = j < cols_per_block && slot_map[j] >= 0 ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
136124
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
137125
}
138126
}
@@ -221,7 +209,7 @@ static inline void mul_mat_f_switch_ids(
221209
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
222210
if (ids) {
223211
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
224-
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
212+
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
225213
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
226214
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
227215
} else {

ggml/src/ggml-sycl/binbcast.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,10 @@ inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
303303
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst);
304304
}
305305

306+
inline void ggml_sycl_op_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
307+
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_count_equal>>(ctx, dst->src[0], dst->src[1], dst);
308+
}
309+
306310
inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
307311

308312
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst);
@@ -328,6 +332,11 @@ void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
328332
ggml_sycl_op_sub(ctx, dst);
329333
}
330334

335+
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
336+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
337+
ggml_sycl_op_count_equal(ctx, dst);
338+
}
339+
331340
void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
332341
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
333342
ggml_sycl_op_mul(ctx, dst);

ggml/src/ggml-sycl/binbcast.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ static __dpct_inline__ float op_sub(const float a, const float b) {
1616
return a - b;
1717
}
1818

19+
static __dpct_inline__ float op_count_equal(const float a, const float b) {
20+
return (a == b) ? 1.0f : 0.0f;
21+
}
22+
23+
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
24+
1925
static __dpct_inline__ float op_mul(const float a, const float b) {
2026
return a * b;
2127
}

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3577,6 +3577,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
35773577
case GGML_OP_SUB:
35783578
ggml_sycl_sub(ctx, dst);
35793579
break;
3580+
case GGML_OP_COUNT_EQUAL:
3581+
ggml_sycl_count_equal(ctx, dst);
3582+
break;
35803583
case GGML_OP_ACC:
35813584
ggml_sycl_acc(ctx, dst);
35823585
break;
@@ -4356,6 +4359,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
43564359
case GGML_OP_ADD:
43574360
case GGML_OP_ADD1:
43584361
case GGML_OP_SUB:
4362+
case GGML_OP_COUNT_EQUAL:
43594363
case GGML_OP_MUL:
43604364
case GGML_OP_DIV:
43614365
case GGML_OP_REPEAT:

tools/perplexity/perplexity.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1931,7 +1931,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
19311931
LOG("Maximum KLD: %10.6f\n", kld_values.back());
19321932
LOG("99.9%% KLD: %10.6f\n", percentile(kld_values, 0.999f));
19331933
LOG("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f));
1934-
LOG("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f));
1934+
LOG("90.0%% KLD: %10.6f\n", percentile(kld_values, 0.900f));
19351935
LOG("Median KLD: %10.6f\n", kld_median);
19361936
LOG("10.0%% KLD: %10.6f\n", percentile(kld_values, 0.100f));
19371937
LOG(" 5.0%% KLD: %10.6f\n", percentile(kld_values, 0.050f));

0 commit comments

Comments
 (0)