18 changes: 15 additions & 3 deletions ggml/include/ggml-alloc.h
@@ -16,9 +16,10 @@ struct ggml_tallocr {
void * base;
size_t alignment;
size_t offset;
size_t page_size_for_allocs; // if > 0, allocations by this linear allocator are rounded up to a multiple of this page size
};

GGML_API struct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer);
GGML_API struct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer); // page_size_for_allocs is not set here; callers that need paged allocation must set it after creation (or a dedicated constructor variant could be added)
GGML_API enum ggml_status ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor);

// Graph allocator
@@ -57,8 +58,19 @@ GGML_API bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph * g
GGML_API bool ggml_gallocr_reserve_n(
ggml_gallocr_t galloc,
struct ggml_cgraph * graph,
const int * node_buffer_ids,
const int * leaf_buffer_ids);
const int * node_buffer_ids, // maps graph node index to buffer_id
const int * leaf_buffer_ids); // maps graph leaf index to buffer_id

// the dynamic tensor allocator used internally for graph planning is defined in ggml-alloc.c and not exported directly
struct ggml_dyn_tallocr;

// Creates a new dynamic tensor allocator that manages allocations in pages.
// alignment: base alignment for the start of any allocated block (typically a page multiple itself or derived from page_size).
// page_size: the size of pages to manage. If 0, a default page size will be used.
GGML_API struct ggml_dyn_tallocr * ggml_dyn_tallocr_new_paged(size_t alignment, size_t page_size);
// Note: ggml_dyn_tallocr_new (the byte-based constructor) is only used internally by ggml_gallocr
// and stays defined in ggml-alloc.c; it could be declared here as well if it were made public.
// For now only the new paged constructor is exported.

// automatic reallocation if the topology changes when using a single buffer
// returns false if using multiple buffers and a re-allocation is needed (call ggml_gallocr_reserve_n first to set the node buffers)
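A minimal usage sketch of the new field and query (illustration only, not part of the diff; the round_up_to_page helper and the buffer/tensor variables are assumed):

#include "ggml-alloc.h"
#include "ggml-backend.h"

// local sketch of the expected rounding behaviour when page_size_for_allocs > 0
static size_t round_up_to_page(size_t size, size_t page_size) {
    return page_size > 0 ? ((size + page_size - 1) / page_size) * page_size : size;
}

ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();           // any backend buffer type
ggml_backend_buffer_t     buf  = ggml_backend_buft_alloc_buffer(buft, buf_size);

struct ggml_tallocr talloc = ggml_tallocr_new(buf);
talloc.page_size_for_allocs = ggml_backend_buft_get_page_size(buft);        // 0 keeps the byte-exact behaviour

// e.g. with a 4096-byte page size, a 10000-byte tensor would occupy 3 pages (12288 bytes)
ggml_tallocr_alloc(&talloc, tensor);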
7 changes: 5 additions & 2 deletions ggml/include/ggml-backend.h
@@ -38,8 +38,11 @@ extern "C" {
GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size);
GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
GGML_API size_t ggml_backend_buft_get_max_size (ggml_backend_buffer_type_t buft);
GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor);
GGML_API bool ggml_backend_buft_is_host (ggml_backend_buffer_type_t buft);
GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor); // size of tensor data in bytes, including padding
GGML_API bool ggml_backend_buft_is_host (ggml_backend_buffer_type_t buft); // true if the buffer is allocated in host memory
// Returns the page size if this buffer type should use a paged allocator, 0 otherwise.
// If the backend does not implement the optional callback, 0 is returned (not paged).
GGML_API size_t ggml_backend_buft_get_page_size (ggml_backend_buffer_type_t buft); // NEW
GGML_API ggml_backend_dev_t ggml_backend_buft_get_device (ggml_backend_buffer_type_t buft);

//
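Sketch of the intended selection between the byte-based and paged dynamic allocators (assumed call site, simplified; the exact signature of the internal ggml_dyn_tallocr_new is an assumption):

size_t alignment = ggml_backend_buft_get_alignment(buft);
size_t page_size = ggml_backend_buft_get_page_size(buft);  // 0 if the buffer type is not paged

struct ggml_dyn_tallocr * talloc = page_size > 0
    ? ggml_dyn_tallocr_new_paged(alignment, page_size)     // new paged constructor from this PR
    : ggml_dyn_tallocr_new(alignment);                      // existing byte-based allocator (internal)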
253 changes: 183 additions & 70 deletions ggml/src/ggml-alloc.c

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions ggml/src/ggml-backend-impl.h
@@ -26,6 +26,9 @@ extern "C" {
size_t (*get_alloc_size)(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor);
// (optional) check if tensor data is in host memory and uses standard ggml tensor layout (defaults to false)
bool (*is_host) (ggml_backend_buffer_type_t buft);
// (optional) Returns page size if this buffer type should use a paged allocator, 0 otherwise.
// If NULL, it's assumed not paged (returns 0).
size_t (*get_page_size) (ggml_backend_buffer_type_t buft);
};

struct ggml_backend_buffer_type {
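The public wrapper is expected to follow the documented "NULL means not paged" convention; a sketch of that dispatch (assumed implementation):

size_t ggml_backend_buft_get_page_size(ggml_backend_buffer_type_t buft) {
    if (buft->iface.get_page_size == NULL) {
        return 0; // optional callback not implemented: treat the buffer type as not paged
    }
    return buft->iface.get_page_size(buft);
}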
183 changes: 183 additions & 0 deletions ggml/src/ggml-cuda/fattn-common.cuh
@@ -47,6 +47,27 @@ typedef void (* fattn_kernel_t)(
const int ne2,
const int ne3);

// New type for paged attention kernels
typedef void (* fattn_paged_kernel_t)(
const char * __restrict__ Q_data,
const paged_kv_sequence_view_gpu k_view_params,
const paged_kv_sequence_view_gpu v_view_params,
const char * __restrict__ mask_data,
float * __restrict__ dst_data,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0, const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int q_ne0, const int q_ne1, const int q_ne2, const int q_ne3, // Q dims
const int k_ne2, // num_kv_heads from K_meta_tensor->ne[2]
const int q_nb1, const int q_nb2, const int q_nb3, // Q byte strides
const int mask_k_seq_len, // mask_tensor ? mask_tensor->ne[1] : 0
const int mask_k_stride_bytes, // mask_tensor ? mask_tensor->nb[1] : 0
const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3 // Dst dims (usually from Q)
);

typedef half (*vec_dot_KQ_f16_t)(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
typedef float (*vec_dot_KQ_f32_t)(
@@ -879,3 +900,165 @@ void launch_fattn(
}
CUDA_CHECK(cudaGetLastError());
}

// Launcher for PAGED kernels
template <int DV, int ncols1, int ncols2>
void launch_fattn_paged(
ggml_backend_cuda_context & ctx,
ggml_tensor * dst_tensor, // The output tensor, also contains Q, K_meta, V_meta, mask, op_params
const paged_kv_sequence_view_gpu & k_view, // Paged K view
const paged_kv_sequence_view_gpu & v_view, // Paged V view
fattn_paged_kernel_t paged_kernel, // Pointer to the __global__ paged kernel
const int nwarps, const size_t nbytes_shared,
const int KQ_row_granularity, // may not be used directly if k_view.num_tokens_in_logical_sequence drives iter_k,
// but still useful for block calculation and assertions
const bool stream_k // for the fixup logic; may need rethinking for paged attention
// const int warp_size = WARP_SIZE // Already available globally
) {
constexpr int ncols = ncols1 * ncols2;
// const bool is_mla = DV == 512; // TODO: Pass DKQ if needed for this or get from Q->ne[0]

const ggml_tensor * Q_tensor = dst_tensor->src[0];
const ggml_tensor * K_meta_tensor = dst_tensor->src[1]; // For metadata like n_head_k
// const ggml_tensor * V_meta_tensor = dst_tensor->src[2]; // For metadata like n_head_v
const ggml_tensor * mask_tensor = dst_tensor->src[3];
ggml_tensor * KQV_tensor = dst_tensor; // Output tensor, also source of op_params

GGML_ASSERT(Q_tensor->type == GGML_TYPE_F32); // Kernels expect Q as F32 (then convert to F16 if needed)
GGML_ASSERT(KQV_tensor->type == GGML_TYPE_F32);
GGML_ASSERT(Q_tensor->nb[0] == ggml_element_size(Q_tensor));
// K and V data are paged, so their ggml_tensor structs are for metadata.
GGML_ASSERT(k_view.dtype == GGML_TYPE_F16 || ggml_is_quantized(k_view.dtype)); // Example assertion
GGML_ASSERT(v_view.dtype == GGML_TYPE_F16 || ggml_is_quantized(v_view.dtype));


GGML_ASSERT(!mask_tensor || mask_tensor->type == GGML_TYPE_F16);
GGML_ASSERT(!mask_tensor || mask_tensor->ne[1] >= GGML_PAD(Q_tensor->ne[1], 16) &&
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");

// k_view.num_tokens_in_logical_sequence replaces K_meta_tensor->ne[1] for sequence length
GGML_ASSERT(k_view.num_tokens_in_logical_sequence % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding for paged view.");
GGML_ASSERT(Q_tensor->ne[3] == 1); // Assuming batch_size = 1 for Q for now in this launcher context

cudaStream_t main_stream = ctx.stream();
const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
const int nsm = ggml_cuda_info().devices[id].nsm;

ggml_cuda_pool_alloc<float2> dst_tmp_meta(ctx.pool()); // For fixup data if needed

// Note: K_f16, V_f16, dst_tmp are not needed here as K/V data comes from pages,
// and output is directly to dst_tensor->data (or its paged equivalent if dst were also paged).

const char * Q_data_ptr = (const char *) Q_tensor->data;
const char * mask_data_ptr = mask_tensor ? (const char *)mask_tensor->data : nullptr;
float * dst_data_ptr = (float *) KQV_tensor->data;

int parallel_blocks = 1; // As in original launch_fattn
const int ntiles_x = ((Q_tensor->ne[1] + ncols1 - 1) / ncols1);
const int ntiles_total = ntiles_x * (Q_tensor->ne[2] / ncols2) * Q_tensor->ne[3];

const dim3 block_dim(WARP_SIZE, nwarps, 1);
int max_blocks_per_sm = 1;
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, (const void*)paged_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));

dim3 blocks_num;
if (stream_k) { // stream_k logic might need re-evaluation for paged attention
const int max_blocks = max_blocks_per_sm*nsm;
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
const int nblocks_stream_k = max_blocks;
const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;
blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
blocks_num.y = 1;
blocks_num.z = 1;
// This allocation for dst_tmp_meta might be needed if fixup path is taken
dst_tmp_meta.alloc(blocks_num.x * ncols * (2*2 + DV) * sizeof(float)); // DV is available here as a template parameter
} else {
GGML_ASSERT(k_view.num_tokens_in_logical_sequence % KQ_row_granularity == 0);
const int ntiles_KQ = k_view.num_tokens_in_logical_sequence / KQ_row_granularity;
parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
// ... (efficiency logic for parallel_blocks from original launch_fattn) ...
blocks_num.x = ntiles_x;
blocks_num.y = parallel_blocks; // splits work over the K sequence length; may not apply directly to the paged case
blocks_num.z = Q_tensor->ne[2] * Q_tensor->ne[3]; // n_q_heads * batch_size_q
if (parallel_blocks > 1) {
// dst_tmp for combining partial results if parallel_blocks > 1
// This needs careful thought with paged KV, as dst is usually contiguous.
// dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV_tensor));
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV_tensor)); // nrows is q_seq_len * n_q_heads * batch_size
}
}

float scale_param = 1.0f, max_bias_param = 0.0f, logit_softcap_param = 0.0f;
memcpy(&scale_param, (const float *) KQV_tensor->op_params + 0, sizeof(float));
memcpy(&max_bias_param, (const float *) KQV_tensor->op_params + 1, sizeof(float));
memcpy(&logit_softcap_param, (const float *) KQV_tensor->op_params + 2, sizeof(float));
if (logit_softcap_param != 0.0f) {
scale_param /= logit_softcap_param;
}

const uint32_t n_head_q_val = Q_tensor->ne[2];
const uint32_t n_head_log2_val = (n_head_q_val > 0) ? (1u << uint32_t(floorf(log2f(float(n_head_q_val))))) : 0;
const float m0_val = powf(2.0f, -(max_bias_param) / (n_head_log2_val ? n_head_log2_val : 1.0f) ); // Avoid div by zero
const float m1_val = powf(2.0f, -(max_bias_param / 2.0f) / (n_head_log2_val ? n_head_log2_val : 1.0f));
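// Note: m0_val/m1_val feed the per-head ALiBi slope as in the existing get_alibi_slope helper:
// heads h < n_head_log2 use m0^(h+1), the remaining heads use m1^(2*(h - n_head_log2) + 1).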

// Q strides in bytes, as expected by the kernel for pointer arithmetic
const int q_nb1 = Q_tensor->nb[1];
const int q_nb2 = Q_tensor->nb[2];
const int q_nb3 = Q_tensor->nb[3];

// Dst tensor parameters for the kernel (ne0, ne1, ne2, ne3 for Dst)
const int dst_ne0_param = Q_tensor->ne[0]; // D
const int dst_ne1_param = Q_tensor->ne[1]; // n_q
const int dst_ne2_param = Q_tensor->ne[2]; // n_heads_q
const int dst_ne3_param = Q_tensor->ne[3]; // batch_size_q

// Mask parameters
const int mask_k_seq_len_param = mask_tensor ? mask_tensor->ne[1] : 0;
const int mask_k_stride_bytes_param = mask_tensor ? mask_tensor->nb[1] : 0;


// Simplified parallel_blocks logic for now (assume 1)
// If parallel_blocks > 1, dst_tmp_meta pointer needs careful offsetting.
// For parallel_blocks = 1, kernel's dst_meta usage for multi-pass reduction is skipped.
float2* final_dst_meta_ptr = dst_tmp_meta.ptr; // No offsetting if parallel_blocks = 1 and gridDim.y=1 in kernel
if (parallel_blocks > 1) {
// TODO: Implement proper dst_meta pointer offsetting for the kernel if parallel_blocks > 1.
// This would involve calculating an offset based on blockIdx.x and blockIdx.z from the grid
// and the total number of Q heads and Q sequence length, to point to the correct
// segment of dst_tmp_meta.ptr for the current Q-block and Q-head.
// For now, this path will likely be incorrect if parallel_blocks > 1.
// GGML_LOG_WARN("Paged attention with parallel_blocks > 1 for K/V sequence might lead to incorrect dst_meta handling.\n");
}


GGML_ASSERT(block_dim.x % WARP_SIZE == 0);
paged_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
Q_data_ptr,
k_view, v_view,
mask_data_ptr, // Base pointer for mask. Kernel uses ic0 and its own logic for indexing.
dst_data_ptr,
final_dst_meta_ptr, // Potentially offset pointer if parallel_blocks > 1 (currently base ptr)
scale_param, max_bias_param, m0_val, m1_val, n_head_log2_val, logit_softcap_param,
Q_tensor->ne[0], Q_tensor->ne[1], Q_tensor->ne[2], Q_tensor->ne[3], // Q dims
K_meta_tensor ? K_meta_tensor->ne[2] : Q_tensor->ne[2], // num_kv_heads
q_nb1, q_nb2, q_nb3, // Q byte strides
mask_k_seq_len_param, mask_k_stride_bytes_param, // Mask K-dim and K-stride
dst_ne0_param, dst_ne1_param, dst_ne2_param, dst_ne3_param // Dst dims
);
CUDA_CHECK(cudaGetLastError());

// Post-launch fixup logic (e.g., flash_attn_stream_k_fixup or flash_attn_combine_results)
// would need to be adapted if used with paged attention, especially if parallel_blocks > 1.
// This simplified launcher does not include that for now.
if (stream_k && (ntiles_total % blocks_num.x != 0)) {
// ... call paged version of flash_attn_stream_k_fixup ...
GGML_LOG_WARN("Stream K fixup for paged attention not implemented yet.\n");
} else if (!stream_k && parallel_blocks > 1) {
GGML_LOG_WARN("Parallel blocks combine for paged attention not implemented yet.\n");
}
}

#endif // FATTN_COMMON_CUH
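For context, a sketch of how a per-head-size translation unit could drive the new launcher, mirroring the existing launch_fattn call pattern; the kernel name flash_attn_ext_f16_paged and the tuning constants are hypothetical:

template <int DV, int ncols1, int ncols2>
static void ggml_cuda_flash_attn_ext_paged_case(
        ggml_backend_cuda_context & ctx, ggml_tensor * dst,
        const paged_kv_sequence_view_gpu & k_view,
        const paged_kv_sequence_view_gpu & v_view) {
    constexpr int    nwarps        = 4;   // assumed tuning constant
    constexpr size_t nbytes_shared = 0;   // assumed; kernel-specific in practice

    // hypothetical paged kernel instantiation matching fattn_paged_kernel_t
    fattn_paged_kernel_t kernel = flash_attn_ext_f16_paged<DV, ncols1, ncols2>;

    launch_fattn_paged<DV, ncols1, ncols2>(
        ctx, dst, k_view, v_view, kernel,
        nwarps, nbytes_shared,
        /*KQ_row_granularity =*/ FATTN_KQ_STRIDE,
        /*stream_k           =*/ false);
}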