4 changes: 4 additions & 0 deletions include/infinicore/ops.hpp
@@ -4,12 +4,16 @@
#include "ops/add_rms_norm.hpp"
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/embedding.hpp"
#include "ops/flash_attention.hpp"
#include "ops/kv_caching.hpp"
#include "ops/matmul.hpp"
#include "ops/ones.hpp"
#include "ops/paged_attention.hpp"
#include "ops/paged_attention_prefill.hpp"
#include "ops/paged_caching.hpp"
#include "ops/random_sample.hpp"
#include "ops/random_sample_batched.hpp"
#include "ops/rearrange.hpp"
#include "ops/rms_norm.hpp"
#include "ops/rope.hpp"
7 changes: 7 additions & 0 deletions include/infinicore/ops/embedding.hpp
@@ -4,6 +4,13 @@

namespace infinicore::op {

class Embedding {
public:
    using schema = void (*)(Tensor, Tensor, Tensor);
    static void execute(Tensor out, Tensor input, Tensor weight);
    static common::OpDispatcher<schema> &dispatcher();
};

Tensor embedding(Tensor input, Tensor weight);
void embedding_(Tensor out, Tensor input, Tensor weight);
} // namespace infinicore::op
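For orientation, a minimal caller-side sketch of the two new entry points. The tensor names and shapes are illustrative assumptions, and the include path is inferred from this tree layout:

#include <infinicore/ops.hpp>

using infinicore::Tensor;

// Sketch: gather rows of `weight` (vocab x hidden, assumed) for each index in `token_ids`.
Tensor lookup(Tensor token_ids, Tensor weight) {
    return infinicore::op::embedding(token_ids, weight); // out-of-place: allocates the result
}

void lookup_into(Tensor out, Tensor token_ids, Tensor weight) {
    infinicore::op::embedding_(out, token_ids, weight); // in-place: fills a preallocated `out`
}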
16 changes: 16 additions & 0 deletions include/infinicore/ops/flash_attention.hpp
@@ -0,0 +1,16 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class FlashAttention {
public:
    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, float, bool);
    static void execute(Tensor out, Tensor q, Tensor k, Tensor v, float scale, bool is_causal);
    static common::OpDispatcher<schema> &dispatcher();
};

Tensor flash_attention(Tensor q, Tensor k, Tensor v, float scale, bool is_causal);
void flash_attention_(Tensor out, Tensor q, Tensor k, Tensor v, float scale, bool is_causal);
} // namespace infinicore::op
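A matching caller-side sketch. The 1/sqrt(head_dim) scale mirrors the default the new Python wrapper further down computes; `head_dim` and the tensor construction are caller-supplied assumptions:

#include <cmath>
#include <cstddef>

#include <infinicore/ops.hpp>

using infinicore::Tensor;

// Sketch: causal attention over q/k/v with the conventional 1/sqrt(head_dim) scaling.
Tensor causal_attention(Tensor q, Tensor k, Tensor v, std::size_t head_dim) {
    float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
    return infinicore::op::flash_attention(q, k, v, scale, /*is_causal=*/true);
}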
28 changes: 28 additions & 0 deletions include/infinicore/ops/kv_caching.hpp
@@ -0,0 +1,28 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class KVCaching {
public:
    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor);
    static void execute(Tensor k_cache,
                        Tensor v_cache,
                        Tensor k,
                        Tensor v,
                        Tensor past_kv_lengths);
    static common::OpDispatcher<schema> &dispatcher();
};

Tensor kv_caching(Tensor k_cache,
                  Tensor v_cache,
                  Tensor k,
                  Tensor v,
                  Tensor past_kv_lengths);
void kv_caching_(Tensor k_cache,
                 Tensor v_cache,
                 Tensor k,
                 Tensor v,
                 Tensor past_kv_lengths);
} // namespace infinicore::op
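A caller-side sketch of the in-place variant. The reading of `past_kv_lengths` as per-sequence append offsets is an assumption here; the header does not spell out the layout:

#include <infinicore/ops.hpp>

using infinicore::Tensor;

// Sketch: append this step's keys/values into the persistent caches at the
// positions recorded in `past_kv_lengths` (one entry per sequence, assumed).
void append_step(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor past_kv_lengths) {
    infinicore::op::kv_caching_(k_cache, v_cache, k, v, past_kv_lengths);
}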
20 changes: 20 additions & 0 deletions include/infinicore/ops/random_sample_batched.hpp
@@ -0,0 +1,20 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {

class RandomSampleBatched {
public:
    using schema = void (*)(Tensor, Tensor, const float *, const float *, const int *, const float *, int);
    static void execute(Tensor result, Tensor probs, const float *random_val, const float *topp, const int *topk, const float *temperature, int batch_size);
    static common::OpDispatcher<schema> &dispatcher();
};

// Out-of-place API
Tensor random_sample_batched(Tensor logits, const float *random_val, const float *topp, const int *topk, const float *temperature, int batch_size);
// In-place API
void random_sample_batched_(Tensor indices, Tensor logits, const float *random_val, const float *topp, const int *topk, const float *temperature, int batch_size);

} // namespace infinicore::op
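Unlike the tensor-only ops above, the sampling controls are plain host arrays with one entry per batch row. A minimal sketch with purely illustrative values:

#include <infinicore/ops.hpp>

using infinicore::Tensor;

// Sketch: draw one token index per row of `logits`; all four control arrays
// are host-side and of length batch_size.
Tensor sample_batch(Tensor logits) {
    constexpr int batch_size = 2;
    const float random_val[batch_size] = {0.42f, 0.17f}; // uniform [0, 1) draws
    const float topp[batch_size] = {0.9f, 0.9f};         // nucleus (top-p) threshold
    const int topk[batch_size] = {50, 50};               // top-k cutoff
    const float temperature[batch_size] = {1.0f, 0.8f};
    return infinicore::op::random_sample_batched(
        logits, random_val, topp, topk, temperature, batch_size);
}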
4 changes: 4 additions & 0 deletions include/infiniop.h
@@ -9,8 +9,11 @@
#include "infiniop/ops/clip.h"
#include "infiniop/ops/conv.h"
#include "infiniop/ops/dequantize_awq.h"
#include "infiniop/ops/embedding.h"
#include "infiniop/ops/flash_attention.h"
#include "infiniop/ops/gelu.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/kv_caching.h"
#include "infiniop/ops/layer_norm.h"
#include "infiniop/ops/logsoftmax.h"
#include "infiniop/ops/lp_norm.h"
@@ -20,6 +23,7 @@
#include "infiniop/ops/paged_attention_prefill.h"
#include "infiniop/ops/paged_caching.h"
#include "infiniop/ops/random_sample.h"
#include "infiniop/ops/random_sample_batched.h"
#include "infiniop/ops/rearrange.h"
#include "infiniop/ops/relu.h"
#include "infiniop/ops/rms_norm.h"
25 changes: 25 additions & 0 deletions include/infiniop/ops/embedding.h
@@ -0,0 +1,25 @@
#ifndef __INFINIOP_EMBEDDING_API_H__
#define __INFINIOP_EMBEDDING_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopEmbeddingDescriptor_t;

__C __export infiniStatus_t infiniopCreateEmbeddingDescriptor(
    infiniopHandle_t handle,
    infiniopEmbeddingDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t output_desc,
    infiniopTensorDescriptor_t input_desc,
    infiniopTensorDescriptor_t weight_desc);

__C __export infiniStatus_t infiniopEmbedding(
    infiniopEmbeddingDescriptor_t desc,
    void *output,
    const void *input,
    const void *weight,
    void *stream);

__C __export infiniStatus_t infiniopDestroyEmbeddingDescriptor(
    infiniopEmbeddingDescriptor_t desc);

#endif
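The C binding follows infiniop's usual descriptor lifecycle: create, run, destroy (no workspace query is declared for this op). A sketch, assuming the handle, tensor descriptors, and device buffers were created elsewhere, with error handling reduced to early returns:

#include <infiniop.h>

// Sketch of the lifecycle; all inputs are assumed valid and device-resident.
infiniStatus_t run_embedding(infiniopHandle_t handle,
                             infiniopTensorDescriptor_t output_desc,
                             infiniopTensorDescriptor_t input_desc,
                             infiniopTensorDescriptor_t weight_desc,
                             void *output, const void *input, const void *weight,
                             void *stream) {
    infiniopEmbeddingDescriptor_t desc;
    infiniStatus_t status = infiniopCreateEmbeddingDescriptor(
        handle, &desc, output_desc, input_desc, weight_desc);
    if (status != INFINI_STATUS_SUCCESS) {
        return status;
    }
    status = infiniopEmbedding(desc, output, input, weight, stream);
    infiniopDestroyEmbeddingDescriptor(desc);
    return status;
}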
34 changes: 34 additions & 0 deletions include/infiniop/ops/flash_attention.h
@@ -0,0 +1,34 @@
#ifndef __INFINIOP_FLASH_ATTENTION_API_H__
#define __INFINIOP_FLASH_ATTENTION_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopFlashAttentionDescriptor_t;

__C __export infiniStatus_t infiniopCreateFlashAttentionDescriptor(
    infiniopHandle_t handle,
    infiniopFlashAttentionDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_desc,
    infiniopTensorDescriptor_t v_desc,
    float scale,
    char is_causal);

__C __export infiniStatus_t infiniopGetFlashAttentionWorkspaceSize(
    infiniopFlashAttentionDescriptor_t desc,
    size_t *size);

__C __export infiniStatus_t infiniopFlashAttention(
    infiniopFlashAttentionDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k,
    const void *v,
    void *stream);

__C __export infiniStatus_t infiniopDestroyFlashAttentionDescriptor(
    infiniopFlashAttentionDescriptor_t desc);
#endif
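Same lifecycle, plus the workspace query. A sketch where the caller supplies a preallocated device workspace; allocation itself is outside this header, and the caller is assumed to guarantee the buffer is at least the queried size:

#include <stddef.h>

#include <infiniop.h>

// Sketch: create -> query workspace -> run -> destroy.
infiniStatus_t run_flash_attention(infiniopHandle_t handle,
                                   infiniopTensorDescriptor_t out_desc,
                                   infiniopTensorDescriptor_t q_desc,
                                   infiniopTensorDescriptor_t k_desc,
                                   infiniopTensorDescriptor_t v_desc,
                                   void *out, const void *q, const void *k, const void *v,
                                   float scale, char is_causal,
                                   void *workspace, void *stream) {
    infiniopFlashAttentionDescriptor_t desc;
    infiniStatus_t status = infiniopCreateFlashAttentionDescriptor(
        handle, &desc, out_desc, q_desc, k_desc, v_desc, scale, is_causal);
    if (status != INFINI_STATUS_SUCCESS) {
        return status;
    }
    size_t workspace_size = 0;
    status = infiniopGetFlashAttentionWorkspaceSize(desc, &workspace_size);
    if (status == INFINI_STATUS_SUCCESS) {
        // `workspace` is assumed to hold at least `workspace_size` bytes.
        status = infiniopFlashAttention(desc, workspace, workspace_size,
                                        out, q, k, v, stream);
    }
    infiniopDestroyFlashAttentionDescriptor(desc);
    return status;
}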
31 changes: 31 additions & 0 deletions include/infiniop/ops/kv_caching.h
@@ -0,0 +1,31 @@
#ifndef __INFINIOP_KV_CACHING_API_H__
#define __INFINIOP_KV_CACHING_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopKVCachingDescriptor_t;

__C __export infiniStatus_t infiniopCreateKVCachingDescriptor(
    infiniopHandle_t handle,
    infiniopKVCachingDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t k_cache,
    infiniopTensorDescriptor_t v_cache,
    infiniopTensorDescriptor_t k,
    infiniopTensorDescriptor_t v,
    infiniopTensorDescriptor_t past_kv_lengths);

__C __export infiniStatus_t infiniopGetKVCachingWorkspaceSize(infiniopKVCachingDescriptor_t desc, size_t *size);

__C __export infiniStatus_t infiniopKVCaching(infiniopKVCachingDescriptor_t desc,
                                              void *workspace,
                                              size_t workspace_size,
                                              void *k_cache,
                                              void *v_cache,
                                              const void *k,
                                              const void *v,
                                              const void *past_kv_lengths,
                                              void *stream);

__C __export infiniStatus_t infiniopDestroyKVCachingDescriptor(infiniopKVCachingDescriptor_t desc);

#endif
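The same create/run/destroy lifecycle applies here; the caches are mutated in place. A sketch, assuming the handle, descriptors, device buffers, and a workspace of at least the size reported by infiniopGetKVCachingWorkspaceSize already exist:

#include <stddef.h>

#include <infiniop.h>

// Sketch: append-step cache update; everything passed in is assumed valid.
infiniStatus_t run_kv_caching(infiniopHandle_t handle,
                              infiniopTensorDescriptor_t k_cache_desc,
                              infiniopTensorDescriptor_t v_cache_desc,
                              infiniopTensorDescriptor_t k_desc,
                              infiniopTensorDescriptor_t v_desc,
                              infiniopTensorDescriptor_t lengths_desc,
                              void *k_cache, void *v_cache,
                              const void *k, const void *v,
                              const void *past_kv_lengths,
                              void *workspace, size_t workspace_size,
                              void *stream) {
    infiniopKVCachingDescriptor_t desc;
    infiniStatus_t status = infiniopCreateKVCachingDescriptor(
        handle, &desc, k_cache_desc, v_cache_desc, k_desc, v_desc, lengths_desc);
    if (status != INFINI_STATUS_SUCCESS) {
        return status;
    }
    // k_cache and v_cache are updated in place at the per-sequence offsets.
    status = infiniopKVCaching(desc, workspace, workspace_size,
                               k_cache, v_cache, k, v, past_kv_lengths, stream);
    infiniopDestroyKVCachingDescriptor(desc);
    return status;
}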
6 changes: 0 additions & 6 deletions include/infiniop/ops/random_sample.h
@@ -15,12 +15,6 @@ __C __export infiniStatus_t infiniopGetRandomSampleWorkspaceSize(
    infiniopRandomSampleDescriptor_t desc,
    size_t *size);

-__C __export infiniStatus_t infiniopCreateRandomSampleBatchDescriptor(
-    infiniopHandle_t handle,
-    infiniopRandomSampleDescriptor_t *desc_ptr,
-    infiniopTensorDescriptor_t result,
-    infiniopTensorDescriptor_t probs);
-
__C __export infiniStatus_t infiniopRandomSample(
infiniopRandomSampleDescriptor_t desc,
void *workspace,
34 changes: 34 additions & 0 deletions include/infiniop/ops/random_sample_batched.h
@@ -0,0 +1,34 @@
#ifndef __INFINIOP_RANDOM_SAMPLE_BATCHED_API_H__
#define __INFINIOP_RANDOM_SAMPLE_BATCHED_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopRandomSampleBatchedDescriptor_t;

__C __export infiniStatus_t infiniopCreateRandomSampleBatchedDescriptor(
    infiniopHandle_t handle,
    infiniopRandomSampleBatchedDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t result,
    infiniopTensorDescriptor_t probs);

__C __export infiniStatus_t infiniopGetRandomSampleBatchedWorkspaceSize(
    infiniopRandomSampleBatchedDescriptor_t desc,
    size_t *size);

__C __export infiniStatus_t infiniopRandomSampleBatched(
    infiniopRandomSampleBatchedDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *result,
    const void *probs,
    const float *random_val,
    const float *topp,
    const int *topk,
    const float *temperature,
    int batch_size,
    void *stream);

__C __export infiniStatus_t infiniopDestroyRandomSampleBatchedDescriptor(
    infiniopRandomSampleBatchedDescriptor_t desc);

#endif
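At the C level the per-batch controls again travel as host arrays. A lifecycle sketch with illustrative values; descriptors, device buffers, and an adequately sized workspace are assumed to exist:

#include <stddef.h>

#include <infiniop.h>

// Sketch: one sampled index per batch row of `probs`.
infiniStatus_t run_sampling(infiniopHandle_t handle,
                            infiniopTensorDescriptor_t result_desc,
                            infiniopTensorDescriptor_t probs_desc,
                            void *result, const void *probs,
                            void *workspace, size_t workspace_size,
                            void *stream) {
    const int batch_size = 2;                  // illustrative; matches the arrays below
    const float random_val[] = {0.42f, 0.17f}; // uniform [0, 1) draws, one per row
    const float topp[] = {0.9f, 0.9f};
    const int topk[] = {50, 50};
    const float temperature[] = {1.0f, 0.8f};

    infiniopRandomSampleBatchedDescriptor_t desc;
    infiniStatus_t status = infiniopCreateRandomSampleBatchedDescriptor(
        handle, &desc, result_desc, probs_desc);
    if (status != INFINI_STATUS_SUCCESS) {
        return status;
    }
    status = infiniopRandomSampleBatched(desc, workspace, workspace_size,
                                         result, probs, random_val, topp, topk,
                                         temperature, batch_size, stream);
    infiniopDestroyRandomSampleBatchedDescriptor(desc);
    return status;
}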
2 changes: 2 additions & 0 deletions python/infinicore/nn/functional/__init__.py
@@ -4,13 +4,15 @@
from .random_sample import random_sample
from .rms_norm import rms_norm
from .rope import RopeAlgo, rope
from .scaled_dot_product_attention import scaled_dot_product_attention
from .silu import silu
from .swiglu import swiglu

__all__ = [
"causal_softmax",
"random_sample",
"rms_norm",
"scaled_dot_product_attention",
"silu",
"swiglu",
"linear",
28 changes: 28 additions & 0 deletions python/infinicore/nn/functional/scaled_dot_product_attention.py
@@ -0,0 +1,28 @@
import math

from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    assert attn_mask is None and dropout_p == 0 and not enable_gqa

    emb_dim = query.shape[-1]

    if scale is None:
        scale = 1 / math.sqrt(emb_dim)

    return Tensor(
        _infinicore.flash_attention(
            query._underlying, key._underlying, value._underlying, scale, is_causal
        )
    )
28 changes: 20 additions & 8 deletions scripts/build_ntops.py
@@ -1,3 +1,4 @@
+import concurrent.futures
import importlib
import pathlib

@@ -11,16 +12,28 @@
def _find_and_build_ops():
    ops_path = SRC_DIR_PATH / "infiniop" / "ops"

-    for op_dir in ops_path.iterdir():
-        ninetoothed_path = op_dir / "ninetoothed"
+    with concurrent.futures.ProcessPoolExecutor() as executor:
+        futures = []

-        if ninetoothed_path.is_dir():
-            module_path = ninetoothed_path / "build"
-            relative_path = module_path.relative_to(SRC_DIR_PATH)
-            import_name = ".".join(relative_path.parts)
-            module = importlib.import_module(import_name)
+        for op_dir in ops_path.iterdir():
+            ninetoothed_path = op_dir / "ninetoothed"

-            module.build()
+            if not ninetoothed_path.is_dir():
+                continue
+
+            futures.append(executor.submit(_build, ninetoothed_path))
+
+        for future in concurrent.futures.as_completed(futures):
+            future.result()  # propagate exceptions from failed worker builds


+def _build(ninetoothed_path):
+    module_path = ninetoothed_path / "build"
+    relative_path = module_path.relative_to(SRC_DIR_PATH)
+    import_name = ".".join(relative_path.parts)
+    module = importlib.import_module(import_name)
+
+    module.build()
+
+
if __name__ == "__main__":