-
Notifications
You must be signed in to change notification settings - Fork 108
Expand file tree
/
Copy pathpaged_attention_prefill.h
More file actions
96 lines (89 loc) · 3.83 KB
/
paged_attention_prefill.h
File metadata and controls
96 lines (89 loc) · 3.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#ifndef __INFINIOP_PAGED_ATTENTION_PREFILL_API_H__
#define __INFINIOP_PAGED_ATTENTION_PREFILL_API_H__
#include "../operator_descriptor.h"
// Opaque handle for a Paged Attention Prefill descriptor.
// Obtained from infiniopCreatePagedAttentionPrefillDescriptor() and released
// with infiniopDestroyPagedAttentionPrefillDescriptor().
typedef struct InfiniopDescriptor *infiniopPagedAttentionPrefillDescriptor_t;
/**
 * @brief Creates a descriptor for the Paged Attention Prefill operation.
 *
 * This function initializes a descriptor that holds all the metadata needed
 * for the paged attention prefill computation. The created descriptor is
 * written to @p desc_ptr and should be released with
 * infiniopDestroyPagedAttentionPrefillDescriptor().
 *
 * NOTE(review): the expected shapes and data types of the tensors are not
 * stated in this header -- confirm against the implementation.
 *
 * @param handle The handle to the InfiniOP library context.
 * @param desc_ptr A pointer to store the created descriptor.
 * @param out_desc Descriptor for the output tensor.
 * @param q_desc Descriptor for the query tensor.
 * @param k_cache_desc Descriptor for the key cache tensor.
 * @param v_cache_desc Descriptor for the value cache tensor.
 * @param block_tables_desc Descriptor for the block tables tensor.
 * @param seq_lens_desc Descriptor for the sequence lengths tensor.
 * @param seq_offsets_desc Descriptor for the sequence offsets tensor.
 * @param cache_lens_desc Descriptor for the cache lengths tensor.
 * @param alibi_slopes_desc Optional descriptor for the ALiBi slopes tensor. Can be NULL.
 * @param scale The attention scaling factor.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
infiniopHandle_t handle,
infiniopPagedAttentionPrefillDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t seq_offsets_desc,
infiniopTensorDescriptor_t cache_lens_desc,
infiniopTensorDescriptor_t alibi_slopes_desc,
float scale);
/**
 * @brief Retrieves the workspace size required for the Paged Attention Prefill operation.
 *
 * The value written to @p size is presumably the minimum workspace_size to
 * pass to infiniopPagedAttentionPrefill() for the same descriptor -- confirm
 * against the implementation.
 *
 * @param desc The Paged Attention Prefill descriptor.
 * @param size A pointer to store the required workspace size in bytes.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
infiniopPagedAttentionPrefillDescriptor_t desc, size_t *size);
/**
 * @brief Executes the Paged Attention Prefill operation.
 *
 * Only @p workspace and @p out are declared as writable pointers; every
 * tensor input is taken as a pointer to const data.
 *
 * @param desc The Paged Attention Prefill descriptor.
 * @param workspace Pointer to the workspace memory.
 * @param workspace_size The size of the workspace (presumably in bytes, matching
 *        infiniopGetPagedAttentionPrefillWorkspaceSize() -- confirm).
 * @param out Pointer to the output tensor data.
 * @param q Pointer to the query tensor data.
 * @param k_cache Pointer to the key cache data.
 * @param v_cache Pointer to the value cache data.
 * @param block_tables Pointer to the block tables data.
 * @param seq_lens Pointer to the sequence lengths data.
 * @param seq_offsets Pointer to the sequence offsets data.
 * @param cache_lens Pointer to the cache lengths data.
 * @param alibi_slopes Pointer to the ALiBi slopes data. Can be NULL.
 * @param stream The CUDA stream for the operation. Can be NULL.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopPagedAttentionPrefill(
infiniopPagedAttentionPrefillDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
const void *block_tables,
const void *seq_lens,
const void *seq_offsets,
const void *cache_lens,
const void *alibi_slopes,
void *stream);
/**
 * @brief Destroys a Paged Attention Prefill descriptor.
 *
 * @param desc The descriptor to be destroyed, as obtained from
 *        infiniopCreatePagedAttentionPrefillDescriptor().
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
infiniopPagedAttentionPrefillDescriptor_t desc);
#endif // __INFINIOP_PAGED_ATTENTION_PREFILL_API_H__