Skip to content

Commit 4dfc773

Browse files
committed
issue/889 - added interface definitions
1 parent 3b5afff commit 4dfc773

File tree

23 files changed

+1060
-69
lines changed

23 files changed

+1060
-69
lines changed

include/infinicore/ops.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,16 @@
44
#include "ops/add_rms_norm.hpp"
55
#include "ops/attention.hpp"
66
#include "ops/causal_softmax.hpp"
7+
#include "ops/embedding.hpp"
8+
#include "ops/flash_attention.hpp"
9+
#include "ops/kv_caching.hpp"
710
#include "ops/matmul.hpp"
811
#include "ops/ones.hpp"
912
#include "ops/paged_attention.hpp"
1013
#include "ops/paged_attention_prefill.hpp"
1114
#include "ops/paged_caching.hpp"
1215
#include "ops/random_sample.hpp"
16+
#include "ops/random_sample_batched.hpp"
1317
#include "ops/rearrange.hpp"
1418
#include "ops/rms_norm.hpp"
1519
#include "ops/rope.hpp"

include/infinicore/ops/embedding.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@
44

55
namespace infinicore::op {
66

7+
class Embedding {
8+
public:
9+
using schema = void (*)(Tensor, Tensor, Tensor);
10+
static void execute(Tensor out, Tensor input, Tensor weight);
11+
static common::OpDispatcher<schema> &dispatcher();
12+
};
13+
714
Tensor embedding(Tensor input, Tensor weight);
815
void embedding_(Tensor out, Tensor input, Tensor weight);
916
} // namespace infinicore::op
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#pragma once
2+
3+
#include "../device.hpp"
4+
#include "common/op.hpp"
5+
6+
namespace infinicore::op {
7+
class FlashAttention {
8+
public:
9+
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, float, bool);
10+
static void execute(Tensor out, Tensor q, Tensor k, Tensor v, float scale, bool is_causal);
11+
static common::OpDispatcher<schema> &dispatcher();
12+
};
13+
14+
Tensor flash_attention(Tensor q, Tensor k, Tensor v, float scale, bool is_causal);
15+
void flash_attention_(Tensor out, Tensor q, Tensor k, Tensor v, float scale, bool is_causal);
16+
} // namespace infinicore::op
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#pragma
2+
3+
#include "../device.hpp"
4+
#include "common/op.hpp"
5+
6+
namespace infinicore::op {
7+
class KVCaching {
8+
public:
9+
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor);
10+
static void execute(Tensor k_cache,
11+
Tensor v_cache,
12+
Tensor k,
13+
Tensor v,
14+
Tensor offsets,
15+
Tensor past_kv_lengths,
16+
Tensor cache_ids);
17+
static common::OpDispatcher<schema> &dispatcher();
18+
};
19+
20+
Tensor kv_caching(Tensor k_cache,
21+
Tensor v_cache,
22+
Tensor k,
23+
Tensor v,
24+
Tensor offsets,
25+
Tensor past_kv_lengths,
26+
Tensor cache_ids);
27+
void kv_caching_(Tensor k_cache,
28+
Tensor v_cache,
29+
Tensor k,
30+
Tensor v,
31+
Tensor offsets,
32+
Tensor past_kv_lengths,
33+
Tensor cache_ids);
34+
} // namespace infinicore::op
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#pragma once
2+
3+
#include "../device.hpp"
4+
#include "common/op.hpp"
5+
6+
namespace infinicore::op {
7+
8+
class RandomSampleBatched {
9+
public:
10+
using schema = void (*)(Tensor, Tensor, const float *, const float *, const int *, const float *, int);
11+
static void execute(Tensor result, Tensor probs, const float *random_val, const float *topp, const int *topk, const float *temperature, int batch_size);
12+
static common::OpDispatcher<schema> &dispatcher();
13+
};
14+
15+
// Out-of-place API
16+
Tensor random_sample_batched(Tensor logits, const float *random_val, const float *topp, const int *topk, const float *temperature, int batch_size);
17+
// In-place API
18+
void random_sample_batched_(Tensor indices, Tensor logits, const float *random_val, const float *topp, const int *topk, const float *temperature, int batch_size);
19+
20+
} // namespace infinicore::op

include/infiniop.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99
#include "infiniop/ops/clip.h"
1010
#include "infiniop/ops/conv.h"
1111
#include "infiniop/ops/dequantize_awq.h"
12+
#include "infiniop/ops/embedding.h"
13+
#include "infiniop/ops/flash_attention.h"
1214
#include "infiniop/ops/gelu.h"
1315
#include "infiniop/ops/gemm.h"
16+
#include "infiniop/ops/kv_caching.h"
1417
#include "infiniop/ops/layer_norm.h"
1518
#include "infiniop/ops/logsoftmax.h"
1619
#include "infiniop/ops/lp_norm.h"
@@ -20,6 +23,7 @@
2023
#include "infiniop/ops/paged_attention_prefill.h"
2124
#include "infiniop/ops/paged_caching.h"
2225
#include "infiniop/ops/random_sample.h"
26+
#include "infiniop/ops/random_sample_batched.h"
2327
#include "infiniop/ops/rearrange.h"
2428
#include "infiniop/ops/relu.h"
2529
#include "infiniop/ops/rms_norm.h"

include/infiniop/ops/embedding.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#ifndef __INFINIOP_EMBEDDING_API_H__
2+
#define __INFINIOP_EMBEDDING_API_H__
3+
4+
#include "../operator_descriptor.h"
5+
6+
typedef struct InfiniopDescriptor *infiniopEmbeddingDescriptor_t;
7+
8+
__C __export infiniStatus_t infiniopCreateEmbeddingDescriptor(
9+
infiniopHandle_t handle,
10+
infiniopEmbeddingDescriptor_t *desc_ptr,
11+
infiniopTensorDescriptor_t output_desc,
12+
infiniopTensorDescriptor_t input_desc,
13+
infiniopTensorDescriptor_t weight_desc);
14+
15+
__C __export infiniStatus_t infiniopEmbedding(
16+
infiniopEmbeddingDescriptor_t desc,
17+
void *output,
18+
const void *input,
19+
const void *weight,
20+
void *stream);
21+
22+
__C __export infiniStatus_t infiniopDestroyEmbeddingDescriptor(
23+
infiniopEmbeddingDescriptor_t desc);
24+
25+
#endif
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#ifndef __INFINIOP_FLASH_ATTENTION_API_H__
2+
#define __INFINIOP_FLASH_ATTENTION_API_H__
3+
4+
#include "../operator_descriptor.h"
5+
6+
typedef struct InfiniopDescriptor *infiniopFlashAttentionDescriptor_t;
7+
8+
__C __export infiniStatus_t infiniopCreateFlashAttentionDescriptor(
9+
infiniopHandle_t handle,
10+
infiniopFlashAttentionDescriptor_t *desc_ptr,
11+
infiniopTensorDescriptor_t out_desc,
12+
infiniopTensorDescriptor_t q_desc,
13+
infiniopTensorDescriptor_t k_desc,
14+
infiniopTensorDescriptor_t v_desc,
15+
float scale,
16+
char is_causal);
17+
18+
__C __export infiniStatus_t infiniopGetFlashAttentionWorkspaceSize(
19+
infiniopFlashAttentionDescriptor_t desc,
20+
size_t *size);
21+
22+
__C __export infiniStatus_t infiniopFlashAttention(
23+
infiniopFlashAttentionDescriptor_t desc,
24+
void *workspace,
25+
size_t workspace_size,
26+
void *out,
27+
const void *q,
28+
const void *k,
29+
const void *v,
30+
void *stream);
31+
32+
__C __export infiniStatus_t infiniopDestroyFlashAttentionDescriptor(
33+
infiniopFlashAttentionDescriptor_t desc);
34+
#endif

include/infiniop/ops/kv_caching.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#ifndef __INFINIOP_KV_CACHING_API_H__
2+
#define __INFINIOP_KV_CACHING_API_H__
3+
4+
#include "../operator_descriptor.h"
5+
6+
typedef struct InfiniopDescriptor *infiniopKVCachingDescriptor_t;
7+
8+
__C __export infiniStatus_t infiniopCreateKVCachingDescriptor(
9+
infiniopHandle_t handle,
10+
infiniopKVCachingDescriptor_t *desc_ptr,
11+
infiniopTensorDescriptor_t k_cache,
12+
infiniopTensorDescriptor_t v_cache,
13+
infiniopTensorDescriptor_t k,
14+
infiniopTensorDescriptor_t v,
15+
infiniopTensorDescriptor_t offsets,
16+
infiniopTensorDescriptor_t past_kv_lengths,
17+
infiniopTensorDescriptor_t cache_ids);
18+
19+
__C __export infiniStatus_t infiniopGetKVCachingWorkspaceSize(infiniopKVCachingDescriptor_t desc, size_t *size);
20+
21+
__C __export infiniStatus_t infiniopKVCaching(infiniopKVCachingDescriptor_t desc,
22+
void *workspace,
23+
size_t workspace_size,
24+
void *k_cache,
25+
void *v_cache,
26+
const void *k,
27+
const void *v,
28+
const void *offsets,
29+
const void *past_kv_lengths,
30+
const void *cache_ids,
31+
void *stream);
32+
33+
__C __export infiniStatus_t infiniopDestroyKVCachingDescriptor(infiniopKVCachingDescriptor_t desc);
34+
35+
#endif

include/infiniop/ops/random_sample.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,6 @@ __C __export infiniStatus_t infiniopGetRandomSampleWorkspaceSize(
1515
infiniopRandomSampleDescriptor_t desc,
1616
size_t *size);
1717

18-
__C __export infiniStatus_t infiniopCreateRandomSampleBatchDescriptor(
19-
infiniopHandle_t handle,
20-
infiniopRandomSampleDescriptor_t *desc_ptr,
21-
infiniopTensorDescriptor_t result,
22-
infiniopTensorDescriptor_t probs);
23-
2418
__C __export infiniStatus_t infiniopRandomSample(
2519
infiniopRandomSampleDescriptor_t desc,
2620
void *workspace,

0 commit comments

Comments
 (0)