
Commit ab7c007

-p512: 8.4k->9.5k - Account for DataPadding when writing tile_y
1 parent b55d44a commit ab7c007

2 files changed: 52 additions & 44 deletions

.devops/rocm.Dockerfile

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 ARG UBUNTU_VERSION=24.04
 
 # This needs to generally match the container host's environment.
-ARG ROCM_VERSION=6.3
-ARG AMDGPU_VERSION=6.3
+ARG ROCM_VERSION=6.4
+ARG AMDGPU_VERSION=6.4
 
 # Target the CUDA build image
 ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 50 additions & 42 deletions
@@ -90,7 +90,7 @@ struct tile_x_sizes {
 };
 
 static int get_mmq_x_max_host(const int cc) {
-    return amd_mma_available(cc) ? 64 : new_mma_available(cc) ? 128 :
+    return (amd_mma_available(cc) || new_mma_available(cc)) ? 128 :
         GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
 #ifdef GGML_CUDA_FORCE_MMQ
             128 : 64;
@@ -100,12 +100,9 @@ static int get_mmq_x_max_host(const int cc) {
 }
 
 static constexpr __device__ int get_mmq_x_max_device() {
-#if defined(AMD_MMA_AVAILABLE)
-    return 64;
-#else
-#if defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     return 128;
-#else // defined(NEW_MMA_AVAILABLE)
+#else // defined(AMD_MMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
 
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
     return 64;
@@ -122,8 +119,7 @@ static constexpr __device__ int get_mmq_x_max_device() {
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#endif // defined(NEW_MMA_AVAILABLE)
-#endif // defined(AMD_MMA_AVAILABLE)
+#endif // defined(AMD_MMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
 }
 
 static int get_mmq_y_host(const int cc) {
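Net effect of the hunks above: the AMD MMA path moves from a 64-wide to the same 128-wide mmq_x ceiling the NVIDIA MMA path already used, on both the host and device side of the check. A standalone before/after sketch, not ggml code; the capability checks are stubbed with made-up device codes and the non-MMA tail of the real expression is elided:

#include <cstdio>

// Stand-ins for ggml's capability checks, stubbed with made-up codes for this sketch only.
static bool amd_mma_available_stub(int dev) { return dev == 1; } // pretend 1 = MFMA-capable AMD GPU
static bool new_mma_available_stub(int dev) { return dev == 2; } // pretend 2 = mma-capable NVIDIA GPU

static int mmq_x_max_before(int dev) {
    // Old selection: the AMD MMA path was capped at 64 (non-MMA tail elided).
    return amd_mma_available_stub(dev) ? 64 : new_mma_available_stub(dev) ? 128 : 64;
}

static int mmq_x_max_after(int dev) {
    // New selection: both MMA paths share the 128-wide upper bound.
    return (amd_mma_available_stub(dev) || new_mma_available_stub(dev)) ? 128 : 64;
}

int main() {
    printf("AMD MMA:    before=%d after=%d\n", mmq_x_max_before(1), mmq_x_max_after(1));
    printf("NVIDIA MMA: before=%d after=%d\n", mmq_x_max_before(2), mmq_x_max_after(2));
    return 0;
}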
@@ -1666,37 +1662,35 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
         int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/4;
 
-        if (i < mmq_y) {
-            if (need_check) {
-                i = min(i, i_max);
-            }
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
-            const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
+        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
 
-            const int ksc = threadIdx.x % 4;
+        const int ksc = threadIdx.x % 4;
 
-            const int ksc_low = ksc % (QI3_K/8);
-            const int shift_low = 4 * (ksc / (QI3_K/8));
-            const int sc_low = (get_int_b2(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
+        const int ksc_low = ksc % (QI3_K/8);
+        const int shift_low = 4 * (ksc / (QI3_K/8));
+        const int sc_low = (get_int_b2(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
 
-            const int ksc_high = QI3_K/8;
-            const int shift_high = 2 * ksc;
-            const int sc_high = ((get_int_b2(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
+        const int ksc_high = QI3_K/8;
+        const int shift_high = 2 * ksc;
+        const int sc_high = ((get_int_b2(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
 
-            const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
+        const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
 
 #if defined(AMD_MMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
-            const int8_t * sc8 = (const int8_t *) &sc;
-            const float d = bxi->d;
+        const int8_t * sc8 = (const int8_t *) &sc;
+        const float d = bxi->d;
 
 #pragma unroll
-            for (int l = 0; l < int(sizeof(int)); ++l) {
-                x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*ksc + l] = d*sc8[l];
-            }
+        for (int l = 0; l < int(sizeof(int)); ++l) {
+            x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*ksc + l] = d*sc8[l];
+        }
 #else
-            x_sc[i*4 + i/8 + ksc] = sc;
+        x_sc[i*4 + i/8 + ksc] = sc;
 #endif // NEW_MMA_AVAILABLE
-        }
     }
 
 #if !(defined(AMD_MMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE))
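The scale math itself is untouched by this hunk; only the if (i < mmq_y) nesting around it is removed. Each thread still rebuilds four 6-bit Q3_K scales per packed int by combining a 4-bit low field with a 2-bit high field shifted into bits 4..5 and then subtracting 32 from every byte via __vsubss4. A host-only sketch of that byte-wise arithmetic for the ksc == 0 case; the packed words below are made-up stand-ins for get_int_b2(bxi->scales, ...), and little-endian byte order is assumed:

#include <cstdint>
#include <cstdio>

// Host emulation of the device intrinsic __vsubss4 (per-byte signed saturating subtract),
// for illustration only.
static uint32_t vsubss4_host(uint32_t a, uint32_t b) {
    uint32_t r = 0;
    for (int i = 0; i < 4; ++i) {
        int v = (int8_t)(a >> (8*i)) - (int8_t)(b >> (8*i));
        v = v < -128 ? -128 : (v > 127 ? 127 : v); // saturate like the hardware op
        r |= (uint32_t)(uint8_t)v << (8*i);
    }
    return r;
}

int main() {
    // Made-up packed words: the four target 6-bit scales in this example are 13, 44, 51 and 7,
    // split into 4-bit low parts (one per byte) and 2-bit high parts (bits 0..1 of each byte).
    const uint32_t low_word  = 0x07030C0D;
    const uint32_t high_word = 0x00030200;

    const uint32_t sc_low  = low_word & 0x0F0F0F0F;                      // keep the low nibbles
    const uint32_t sc_high = (high_word << 4) & 0x30303030;              // move high bits to bits 4..5
    const uint32_t sc      = vsubss4_host(sc_low | sc_high, 0x20202020); // subtract 32 from each byte

    const int8_t * sc8 = (const int8_t *) &sc;
    for (int l = 0; l < 4; ++l) {
        printf("scale[%d] = %d\n", l, sc8[l]); // prints -19, 12, 19, -25
    }
    return 0;
}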
@@ -1802,9 +1796,15 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int rows_per_warp = warp_size / 2;
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+#if defined(AMD_MMA_AVAILABLE)
+        // Need if on AMD instead of % because warp_size == 64
+        // This causes double work and throughput loss (MI300X)
+        // H100 loses about 100 t/s with 'if' condition over '%'
         int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
-
         if (i < mmq_y) {
+#else
+        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
+#endif // defined(AMD_MMA_AVAILABLE)
             if (need_check) {
                 i = min(i, i_max);
             }
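The trade-off stated in the new comments can be reproduced on the host: with a 64-wide wavefront the index expression spans more candidate rows than mmq_y, so the '%' form makes several threads redo the same row while the 'if' form lets the surplus lanes idle. A small sketch with warp_size, nwarps and mmq_y chosen purely for illustration (the real template parameters vary per kernel configuration):

#include <cstdio>

int main() {
    // Illustrative values only; not taken from any specific kernel instantiation.
    const int warp_size     = 64;             // AMD wavefront width
    const int nwarps        = 8;
    const int mmq_y         = 64;
    const int rows_per_warp = warp_size / 2;  // as in the loader above

    int writers_per_row[mmq_y] = {0};

    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
        for (int ty = 0; ty < nwarps; ++ty) {          // threadIdx.y
            for (int tx = 0; tx < warp_size; ++tx) {   // threadIdx.x
                const int i = (i0 + ty*rows_per_warp + tx/2) % mmq_y; // the '%' variant
                ++writers_per_row[i];
            }
        }
    }

    // With these numbers every row is written by 8 threads instead of the 2 that carry distinct
    // work; the 'if (i < mmq_y)' variant would let the surplus lanes do nothing instead.
    printf("row 0 is handled by %d threads\n", writers_per_row[0]);
    return 0;
}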
@@ -1826,7 +1826,9 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
             for (int l = 0; l < sizeof(int); ++l) {
                 x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
             }
+#if defined(AMD_MMA_AVAILABLE)
         }
+#endif // defined(AMD_MMA_AVAILABLE)
     }
 #else
 #pragma unroll
@@ -1951,9 +1953,15 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int rows_per_warp = warp_size / 2;
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+#if defined(AMD_MMA_AVAILABLE)
+        // Need if on AMD instead of % because warp_size == 64
+        // This causes double work and throughput loss (MI300X)
+        // H100 loses about 100 t/s with 'if' condition over '%'
         int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
-
         if (i < mmq_y) {
+#else
+        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
+#endif // defined(AMD_MMA_AVAILABLE)
             if (need_check) {
                 i = min(i, i_max);
             }
@@ -1975,7 +1983,9 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
             for (int l = 0; l < int(sizeof(int)); ++l) {
                 x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
             }
+#if defined(AMD_MMA_AVAILABLE)
         }
+#endif // defined(AMD_MMA_AVAILABLE)
     }
 #else
 #pragma unroll
@@ -2117,21 +2127,19 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int rows_per_warp = warp_size / 4;
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
-        int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/4;
+        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/4) % mmq_y;
 
-        if (i < mmq_y) {
-            if (need_check) {
-                i = min(i, i_max);
-            }
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
-            const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % 4) / 4;
+        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % 4) / 4;
 
 #if defined(AMD_MMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
-            x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x%4);
+        x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x%4);
 #else
-            x_sc[i*4 + i/8 + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x%4);
+        x_sc[i*4 + i/8 + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x%4);
 #endif // NEW_MMA_AVAILABLE
-        }
     }
 }
 
@@ -3096,7 +3104,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
         const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
 #pragma unroll
         for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
-            int l = (l0 + threadIdx.y*warp_size + threadIdx.x) % (mmq_x*MMQ_TILE_Y_K);
+            int l = l0 + threadIdx.y*warp_size + threadIdx.x;
 
             tile_y[l] = by0[l];
         }
@@ -3112,7 +3120,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
         const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
 #pragma unroll
         for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
-            int l = (l0 + threadIdx.y*warp_size + threadIdx.x) % (mmq_x*MMQ_TILE_Y_K);
+            int l = l0 + threadIdx.y*warp_size + threadIdx.x;
 
             tile_y[l] = by0[l];
         }
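These two tile_y hunks are what the commit title refers to: without the '% (mmq_x*MMQ_TILE_Y_K)' wrap, the last loop iteration can index past mmq_x*MMQ_TILE_Y_K whenever that product is not a multiple of nwarps*warp_size, which the title suggests is absorbed by padding of the tile_y allocation ('DataPadding'). A host-side sketch of the worst-case overshoot, with placeholder sizes:

#include <cstdio>

// How far past mmq_x*MMQ_TILE_Y_K the unguarded index can reach in the last iteration.
static int max_overshoot(int tile_y_elems, int threads_per_step) {
    const int rounded_up = ((tile_y_elems + threads_per_step - 1) / threads_per_step) * threads_per_step;
    return rounded_up - tile_y_elems; // 0 when tile_y_elems divides evenly
}

int main() {
    // Placeholder launch shape and tile sizes, chosen only to make the arithmetic concrete.
    const int warp_size = 32, nwarps = 8;
    const int MMQ_TILE_Y_K = 36, mmq_x = 56;
    const int elems = mmq_x * MMQ_TILE_Y_K;   // 2016 tile_y elements actually needed
    const int step  = nwarps * warp_size;     // 256 threads advance l0 per iteration

    // With these numbers the last iteration (l0 = 1792) produces l up to 2047,
    // i.e. 32 elements past the end, so tile_y must be padded by at least that much.
    printf("overshoot = %d elements\n", max_overshoot(elems, step));
    return 0;
}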
@@ -3186,7 +3194,7 @@ static __global__ void mul_mat_q(
     __syncthreads();
 
     // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
-#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(AMD_MMA_AVAILABLE)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
+#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
     {
         const int wt = blockIdx.z / nchannels_y;
         const int zt = blockIdx.z - wt*nchannels_y;
