diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp
index a7249ec9d..3fb47d383 100644
--- a/include/infinicore/ops.hpp
+++ b/include/infinicore/ops.hpp
@@ -4,12 +4,16 @@
 #include "ops/add_rms_norm.hpp"
 #include "ops/attention.hpp"
 #include "ops/causal_softmax.hpp"
+#include "ops/embedding.hpp"
+#include "ops/flash_attention.hpp"
+#include "ops/kv_caching.hpp"
 #include "ops/matmul.hpp"
 #include "ops/ones.hpp"
 #include "ops/paged_attention.hpp"
 #include "ops/paged_attention_prefill.hpp"
 #include "ops/paged_caching.hpp"
 #include "ops/random_sample.hpp"
+#include "ops/random_sample_batched.hpp"
 #include "ops/rearrange.hpp"
 #include "ops/rms_norm.hpp"
 #include "ops/rope.hpp"
diff --git a/include/infinicore/ops/embedding.hpp b/include/infinicore/ops/embedding.hpp
index 4fd9991c4..6be997134 100644
--- a/include/infinicore/ops/embedding.hpp
+++ b/include/infinicore/ops/embedding.hpp
@@ -4,6 +4,13 @@
 namespace infinicore::op {
 
+class Embedding {
+public:
+    using schema = void (*)(Tensor, Tensor, Tensor);
+    static void execute(Tensor out, Tensor input, Tensor weight);
+    static common::OpDispatcher<schema> &dispatcher();
+};
+
 Tensor embedding(Tensor input, Tensor weight);
 void embedding_(Tensor out, Tensor input, Tensor weight);
 } // namespace infinicore::op
diff --git a/include/infinicore/ops/flash_attention.hpp b/include/infinicore/ops/flash_attention.hpp
new file mode 100644
index 000000000..957255192
--- /dev/null
+++ b/include/infinicore/ops/flash_attention.hpp
@@ -0,0 +1,16 @@
+#pragma once
+
+#include "../device.hpp"
+#include "common/op.hpp"
+
+namespace infinicore::op {
+class FlashAttention {
+public:
+    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, float, bool);
+    static void execute(Tensor out, Tensor q, Tensor k, Tensor v, float scale, bool is_causal);
+    static common::OpDispatcher<schema> &dispatcher();
+};
+
+Tensor flash_attention(Tensor q, Tensor k, Tensor v, float scale, bool is_causal);
+void flash_attention_(Tensor out, Tensor q, Tensor k, Tensor v, float scale, bool is_causal);
+} // namespace infinicore::op
diff --git a/include/infinicore/ops/kv_caching.hpp b/include/infinicore/ops/kv_caching.hpp
new file mode 100644
index 000000000..e4b6f514c
--- /dev/null
+++ b/include/infinicore/ops/kv_caching.hpp
@@ -0,0 +1,28 @@
+#pragma once
+
+#include "../device.hpp"
+#include "common/op.hpp"
+
+namespace infinicore::op {
+class KVCaching {
+public:
+    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor);
+    static void execute(Tensor k_cache,
+                        Tensor v_cache,
+                        Tensor k,
+                        Tensor v,
+                        Tensor past_kv_lengths);
+    static common::OpDispatcher<schema> &dispatcher();
+};
+
+Tensor kv_caching(Tensor k_cache,
+                  Tensor v_cache,
+                  Tensor k,
+                  Tensor v,
+                  Tensor past_kv_lengths);
+void kv_caching_(Tensor k_cache,
+                 Tensor v_cache,
+                 Tensor k,
+                 Tensor v,
+                 Tensor past_kv_lengths);
+} // namespace infinicore::op
diff --git a/include/infinicore/ops/random_sample_batched.hpp b/include/infinicore/ops/random_sample_batched.hpp
new file mode 100644
index 000000000..8906bc12b
--- /dev/null
+++ b/include/infinicore/ops/random_sample_batched.hpp
@@ -0,0 +1,20 @@
+#pragma once
+
+#include "../device.hpp"
+#include "common/op.hpp"
+
+namespace infinicore::op {
+
+class RandomSampleBatched {
+public:
+    using schema = void (*)(Tensor, Tensor, const float *, const float *, const int *, const float *, int);
+    static void execute(Tensor result, Tensor probs, const float *random_val, const float *topp, const int *topk, const float *temperature, int batch_size);
+    static common::OpDispatcher<schema> &dispatcher();
+};
+
+// Out-of-place API
+Tensor random_sample_batched(Tensor logits, const float *random_val, const float *topp, const int *topk, const float *temperature, int batch_size);
+// In-place API
+void random_sample_batched_(Tensor indices, Tensor logits, const float *random_val, const float *topp, const int *topk, const float *temperature, int batch_size);
+
+} // namespace infinicore::op
diff --git a/include/infiniop.h b/include/infiniop.h
index c0a09fcb4..ca42e1509 100644
--- a/include/infiniop.h
+++ b/include/infiniop.h
@@ -9,8 +9,11 @@
 #include "infiniop/ops/clip.h"
 #include "infiniop/ops/conv.h"
 #include "infiniop/ops/dequantize_awq.h"
+#include "infiniop/ops/embedding.h"
+#include "infiniop/ops/flash_attention.h"
 #include "infiniop/ops/gelu.h"
 #include "infiniop/ops/gemm.h"
+#include "infiniop/ops/kv_caching.h"
 #include "infiniop/ops/layer_norm.h"
 #include "infiniop/ops/logsoftmax.h"
 #include "infiniop/ops/lp_norm.h"
@@ -20,6 +23,7 @@
 #include "infiniop/ops/paged_attention_prefill.h"
 #include "infiniop/ops/paged_caching.h"
 #include "infiniop/ops/random_sample.h"
+#include "infiniop/ops/random_sample_batched.h"
 #include "infiniop/ops/rearrange.h"
 #include "infiniop/ops/relu.h"
 #include "infiniop/ops/rms_norm.h"
diff --git a/include/infiniop/ops/embedding.h b/include/infiniop/ops/embedding.h
new file mode 100644
index 000000000..cd1df3a73
--- /dev/null
+++ b/include/infiniop/ops/embedding.h
@@ -0,0 +1,25 @@
+#ifndef __INFINIOP_EMBEDDING_API_H__
+#define __INFINIOP_EMBEDDING_API_H__
+
+#include "../operator_descriptor.h"
+
+typedef struct InfiniopDescriptor *infiniopEmbeddingDescriptor_t;
+
+__C __export infiniStatus_t infiniopCreateEmbeddingDescriptor(
+    infiniopHandle_t handle,
+    infiniopEmbeddingDescriptor_t *desc_ptr,
+    infiniopTensorDescriptor_t output_desc,
+    infiniopTensorDescriptor_t input_desc,
+    infiniopTensorDescriptor_t weight_desc);
+
+__C __export infiniStatus_t infiniopEmbedding(
+    infiniopEmbeddingDescriptor_t desc,
+    void *output,
+    const void *input,
+    const void *weight,
+    void *stream);
+
+__C __export infiniStatus_t infiniopDestroyEmbeddingDescriptor(
+    infiniopEmbeddingDescriptor_t desc);
+
+#endif
diff --git a/include/infiniop/ops/flash_attention.h b/include/infiniop/ops/flash_attention.h
new file mode 100644
index 000000000..06c3ff47c
--- /dev/null
+++ b/include/infiniop/ops/flash_attention.h
@@ -0,0 +1,34 @@
+#ifndef __INFINIOP_FLASH_ATTENTION_API_H__
+#define __INFINIOP_FLASH_ATTENTION_API_H__
+
+#include "../operator_descriptor.h"
+
+typedef struct InfiniopDescriptor *infiniopFlashAttentionDescriptor_t;
+
+__C __export infiniStatus_t infiniopCreateFlashAttentionDescriptor(
+    infiniopHandle_t handle,
+    infiniopFlashAttentionDescriptor_t *desc_ptr,
+    infiniopTensorDescriptor_t out_desc,
+    infiniopTensorDescriptor_t q_desc,
+    infiniopTensorDescriptor_t k_desc,
+    infiniopTensorDescriptor_t v_desc,
+    float scale,
+    char is_causal);
+
+__C __export infiniStatus_t infiniopGetFlashAttentionWorkspaceSize(
+    infiniopFlashAttentionDescriptor_t desc,
+    size_t *size);
+
+__C __export infiniStatus_t infiniopFlashAttention(
+    infiniopFlashAttentionDescriptor_t desc,
+    void *workspace,
+    size_t workspace_size,
+    void *out,
+    const void *q,
+    const void *k,
+    const void *v,
+    void *stream);
+
+__C __export infiniStatus_t infiniopDestroyFlashAttentionDescriptor(
+    infiniopFlashAttentionDescriptor_t desc);
+#endif
diff --git a/include/infiniop/ops/kv_caching.h b/include/infiniop/ops/kv_caching.h
new file mode 100644
index 000000000..e6efa48b3
--- /dev/null
+++ b/include/infiniop/ops/kv_caching.h
@@ -0,0 +1,31 @@
+#ifndef __INFINIOP_KV_CACHING_API_H__
+#define __INFINIOP_KV_CACHING_API_H__
+
+#include "../operator_descriptor.h"
+
+typedef struct InfiniopDescriptor *infiniopKVCachingDescriptor_t;
+
+__C __export infiniStatus_t infiniopCreateKVCachingDescriptor(
+    infiniopHandle_t handle,
+    infiniopKVCachingDescriptor_t *desc_ptr,
+    infiniopTensorDescriptor_t k_cache,
+    infiniopTensorDescriptor_t v_cache,
+    infiniopTensorDescriptor_t k,
+    infiniopTensorDescriptor_t v,
+    infiniopTensorDescriptor_t past_kv_lengths);
+
+__C __export infiniStatus_t infiniopGetKVCachingWorkspaceSize(infiniopKVCachingDescriptor_t desc, size_t *size);
+
+__C __export infiniStatus_t infiniopKVCaching(infiniopKVCachingDescriptor_t desc,
+                                              void *workspace,
+                                              size_t workspace_size,
+                                              void *k_cache,
+                                              void *v_cache,
+                                              const void *k,
+                                              const void *v,
+                                              const void *past_kv_lengths,
+                                              void *stream);
+
+__C __export infiniStatus_t infiniopDestroyKVCachingDescriptor(infiniopKVCachingDescriptor_t desc);
+
+#endif
diff --git a/include/infiniop/ops/random_sample.h b/include/infiniop/ops/random_sample.h
index 1c242d7ba..bb2b15959 100644
--- a/include/infiniop/ops/random_sample.h
+++ b/include/infiniop/ops/random_sample.h
@@ -15,12 +15,6 @@ __C __export infiniStatus_t infiniopGetRandomSampleWorkspaceSize(
     infiniopRandomSampleDescriptor_t desc,
     size_t *size);
 
-__C __export infiniStatus_t infiniopCreateRandomSampleBatchDescriptor(
-    infiniopHandle_t handle,
-    infiniopRandomSampleDescriptor_t *desc_ptr,
-    infiniopTensorDescriptor_t result,
-    infiniopTensorDescriptor_t probs);
-
 __C __export infiniStatus_t infiniopRandomSample(
     infiniopRandomSampleDescriptor_t desc,
     void *workspace,
diff --git a/include/infiniop/ops/random_sample_batched.h b/include/infiniop/ops/random_sample_batched.h
new file mode 100644
index 000000000..4512e7dcb
--- /dev/null
+++ b/include/infiniop/ops/random_sample_batched.h
@@ -0,0 +1,34 @@
+#ifndef __INFINIOP_RANDOM_SAMPLE_BATCHED_API_H__
+#define __INFINIOP_RANDOM_SAMPLE_BATCHED_API_H__
+
+#include "../operator_descriptor.h"
+
+typedef struct InfiniopDescriptor *infiniopRandomSampleBatchedDescriptor_t;
+
+__C __export infiniStatus_t infiniopCreateRandomSampleBatchedDescriptor(
+    infiniopHandle_t handle,
+    infiniopRandomSampleBatchedDescriptor_t *desc_ptr,
+    infiniopTensorDescriptor_t result,
+    infiniopTensorDescriptor_t probs);
+
+__C __export infiniStatus_t infiniopGetRandomSampleBatchedWorkspaceSize(
+    infiniopRandomSampleBatchedDescriptor_t desc,
+    size_t *size);
+
+__C __export infiniStatus_t infiniopRandomSampleBatched(
+    infiniopRandomSampleBatchedDescriptor_t desc,
+    void *workspace,
+    size_t workspace_size,
+    void *result,
+    const void *probs,
+    const float *random_val,
+    const float *topp,
+    const int *topk,
+    const float *temperature,
+    int batch_size,
+    void *stream);
+
+__C __export infiniStatus_t infiniopDestroyRandomSampleBatchedDescriptor(
+    infiniopRandomSampleBatchedDescriptor_t desc);
+
+#endif
diff --git a/python/infinicore/nn/functional/__init__.py b/python/infinicore/nn/functional/__init__.py
index 255079790..f8c7d6ef0 100644
--- a/python/infinicore/nn/functional/__init__.py
+++ b/python/infinicore/nn/functional/__init__.py
@@ -4,6 +4,7 @@
 from .random_sample import random_sample
 from .rms_norm import rms_norm
 from .rope import RopeAlgo, rope
+from .scaled_dot_product_attention import scaled_dot_product_attention
 from .silu import silu
 from .swiglu import swiglu
 
@@ -11,6 +12,7 @@
     "causal_softmax",
     "random_sample",
     "rms_norm",
+    "scaled_dot_product_attention",
     "silu",
     "swiglu",
     "linear",
diff --git a/python/infinicore/nn/functional/scaled_dot_product_attention.py b/python/infinicore/nn/functional/scaled_dot_product_attention.py
new file mode 100644
index 000000000..d89f484fe
--- /dev/null
+++ b/python/infinicore/nn/functional/scaled_dot_product_attention.py
@@ -0,0 +1,28 @@
+import math
+
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+
+def scaled_dot_product_attention(
+    query,
+    key,
+    value,
+    attn_mask=None,
+    dropout_p=0,
+    is_causal=False,
+    scale=None,
+    enable_gqa=False,
+):
+    assert attn_mask is None and dropout_p == 0 and not enable_gqa
+
+    emb_dim = query.shape[-1]
+
+    if scale is None:
+        scale = 1 / math.sqrt(emb_dim)
+
+    return Tensor(
+        _infinicore.flash_attention(
+            query._underlying, key._underlying, value._underlying, scale, is_causal
+        )
+    )
diff --git a/scripts/build_ntops.py b/scripts/build_ntops.py
index 1499b6bf8..e1397e56d 100644
--- a/scripts/build_ntops.py
+++ b/scripts/build_ntops.py
@@ -1,3 +1,4 @@
+import concurrent.futures
 import importlib
 import pathlib
 
@@ -11,16 +12,28 @@ def _find_and_build_ops():
     ops_path = SRC_DIR_PATH / "infiniop" / "ops"
 
-    for op_dir in ops_path.iterdir():
-        ninetoothed_path = op_dir / "ninetoothed"
+    with concurrent.futures.ProcessPoolExecutor() as executor:
+        futures = []
 
-        if ninetoothed_path.is_dir():
-            module_path = ninetoothed_path / "build"
-            relative_path = module_path.relative_to(SRC_DIR_PATH)
-            import_name = ".".join(relative_path.parts)
-            module = importlib.import_module(import_name)
+        for op_dir in ops_path.iterdir():
+            ninetoothed_path = op_dir / "ninetoothed"
 
-            module.build()
+            if not ninetoothed_path.is_dir():
+                continue
+
+            futures.append(executor.submit(_build, ninetoothed_path))
+
+        for future in concurrent.futures.as_completed(futures):
+            future.result()
+
+
+def _build(ninetoothed_path):
+    module_path = ninetoothed_path / "build"
+    relative_path = module_path.relative_to(SRC_DIR_PATH)
+    import_name = ".".join(relative_path.parts)
+    module = importlib.import_module(import_name)
+
+    module.build()
 
 
 if __name__ == "__main__":
diff --git a/src/infinicore/ops/embedding/embedding.cc b/src/infinicore/ops/embedding/embedding.cc
index f1add0c97..96f19803c 100644
--- a/src/infinicore/ops/embedding/embedding.cc
+++ b/src/infinicore/ops/embedding/embedding.cc
@@ -1,15 +1,34 @@
 #include "infinicore/ops/embedding.hpp"
+#include "../../utils.hpp"
 #include "infinicore/context/context.hpp"
 
 #include <cassert>
+#include <cstring>
 
 namespace infinicore::op {
 
+common::OpDispatcher<Embedding::schema> &Embedding::dispatcher() {
+    static common::OpDispatcher<schema> dispatcher_;
+    return dispatcher_;
+}
+
+void Embedding::execute(Tensor out, Tensor input, Tensor weight) {
+    // Check that all tensors are on the same device
+    // This is critical: if input is on CPU while out/weight are on GPU,
+    // passing CPU pointer to CUDA kernel will cause memory access errors
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, input, weight);
+
+    // Set device context
+    infinicore::context::setDevice(out->device());
+
+    // Use dispatcher to lookup kernel (infiniop implementation)
+    dispatcher().lookup(out->device().getType())(out, input, weight);
+}
+
 Tensor embedding(Tensor input,  // LongTensor of arbitrary shape containing the indices to extract
                  Tensor weight  // Weight: Embedding matrix of floating point type with shape (V, embedding_dim), where V = maximum index + 1
 ) {
     auto input_shape = input->shape();
     auto weight_shape = weight->shape();
-    // auto vocab_size = weight_shape[0];
     auto embedding_dim = weight_shape[1];
 
     // Assign memory to out variables
@@ -22,68 +41,7 @@ Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the i
 }
 
 void embedding_(Tensor out, Tensor input, Tensor weight) {
-    assert(infinicore::DataType::I64 == input->dtype() || (infinicore::DataType::I32 == input->dtype()));
-    assert(infinicore::Device::Type::CPU == input->device().getType());
-
-    auto input_shape = input->shape();
-    auto weight_shape = weight->shape();
-    auto embedding_dim = weight_shape[1];
-
-    // Calculate the number of token
-    Size counts = 1;
-    for (auto &v : input_shape) {
-        counts *= v;
-    }
-
-    // the bytes of one token
-    const Size bytes = dsize(weight->dtype()) * embedding_dim;
-    auto *weight_ptr = weight->data();
-    auto *out_ptr = out->data();
-
-    // copies
-    if (weight->device().getType() == Device::Type::CPU) {
-        if (infinicore::DataType::I64 == input->dtype()) {
-            const int64_t *input_arr = reinterpret_cast<const int64_t *>(input->data());
-            for (Size i = 0; i < counts; ++i) {
-                int64_t idx = input_arr[i];
-                assert((idx >= 0) && (idx < weight_shape[0]));
-                std::memcpy(out_ptr + i * bytes,
-                            weight_ptr + idx * bytes,
-                            bytes);
-            }
-        } else if (infinicore::DataType::I32 == input->dtype()) {
-            const int32_t *input_arr = reinterpret_cast<const int32_t *>(input->data());
-
-            for (Size i = 0; i < counts; ++i) {
-                int32_t idx = input_arr[i];
-                assert((idx >= 0) && (idx < weight_shape[0]));
-                std::memcpy(out_ptr + i * bytes,
-                            weight_ptr + idx * bytes,
-                            bytes);
-            }
-        }
-
-    } else {
-        if (infinicore::DataType::I64 == input->dtype()) {
-            const int64_t *input_arr = reinterpret_cast<const int64_t *>(input->data());
-            for (Size i = 0; i < counts; ++i) {
-                int64_t idx = input_arr[i];
-                assert((idx >= 0) && (idx < weight_shape[0]));
-                context::memcpyD2D(out_ptr + i * bytes,
-                                   weight_ptr + idx * bytes,
-                                   bytes);
-            }
-        } else if (infinicore::DataType::I32 == input->dtype()) {
-            const int32_t *input_arr = reinterpret_cast<const int32_t *>(input->data());
-            for (Size i = 0; i < counts; ++i) {
-                int32_t idx = input_arr[i];
-                assert((idx >= 0) && (idx < weight_shape[0]));
-                context::memcpyD2D(out_ptr + i * bytes,
-                                   weight_ptr + idx * bytes,
-                                   bytes);
-            }
-        }
-    }
+    Embedding::execute(out, input, weight);
 }
 
 } // namespace infinicore::op
diff --git a/src/infinicore/ops/embedding/embedding_infiniop.cc b/src/infinicore/ops/embedding/embedding_infiniop.cc
new file mode 100644
index 000000000..dfbbb2f71
--- /dev/null
+++ b/src/infinicore/ops/embedding/embedding_infiniop.cc
@@ -0,0 +1,49 @@
+#include "../../utils.hpp"
+#include "infinicore/common/hash.hpp"
+#include "infinicore/ops/common/cache.hpp"
+#include "infinicore/ops/embedding.hpp"
+#include <infiniop.h>
+
+namespace infinicore::op::embedding_impl::infiniop {
+
+thread_local common::OpCache<size_t, infiniopEmbeddingDescriptor_t> caches(
+    100, // capacity
+    [](infiniopEmbeddingDescriptor_t &desc) {
+        if (desc != nullptr) {
+            INFINICORE_CHECK_ERROR(infiniopDestroyEmbeddingDescriptor(desc));
+            desc = nullptr;
+        }
+    });
+
+void calculate(Tensor out, Tensor input, Tensor weight) {
+    size_t seed = hash_combine(out, input, weight);
+
+    auto device = context::getDevice();
+    auto &cache = caches.getCache(device);
+
+    auto desc_opt = cache.get(seed);
+    infiniopEmbeddingDescriptor_t desc = nullptr;
+
+    if (!desc_opt) {
+        INFINICORE_CHECK_ERROR(infiniopCreateEmbeddingDescriptor(
+            context::getInfiniopHandle(device), &desc,
+            out->desc(), input->desc(), weight->desc()));
+        cache.put(seed, desc);
+    } else {
+        desc = *desc_opt;
+    }
+
+    INFINICORE_CHECK_ERROR(infiniopEmbedding(
+        desc,
+        out->data(),
+        input->data(),
+        weight->data(),
+        context::getStream()));
+}
+
+static bool registered = []() {
+    Embedding::dispatcher().registerAll(&calculate, false);
+    return true;
+}();
+
+} // namespace infinicore::op::embedding_impl::infiniop
diff --git a/src/infinicore/ops/flash_attention/flash_attention.cc b/src/infinicore/ops/flash_attention/flash_attention.cc
new file mode 100644
index 000000000..97db6de79
--- /dev/null
+++ b/src/infinicore/ops/flash_attention/flash_attention.cc
@@ -0,0 +1,29 @@
+#include "infinicore/ops/flash_attention.hpp"
+
+#include "../../utils.hpp"
+
+namespace infinicore::op {
+
+common::OpDispatcher<FlashAttention::schema> &FlashAttention::dispatcher() {
+    static common::OpDispatcher<schema> dispatcher_;
+    return dispatcher_;
+};
+
+void FlashAttention::execute(Tensor out, Tensor q, Tensor k, Tensor v, float scale, bool is_causal) {
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v);
+    infinicore::context::setDevice(out->device());
+    dispatcher().lookup(out->device().getType())(
+        out, q, k, v, scale, is_causal);
+}
+
+Tensor flash_attention(Tensor q, Tensor k, Tensor v, float scale, bool is_causal) {
+    Shape shape = q->shape();
+    auto out = Tensor::empty(shape, q->dtype(), q->device());
+    flash_attention_(out, q, k, v, scale, is_causal);
+    return out;
+}
+
+void flash_attention_(Tensor out, Tensor q, Tensor k, Tensor v, float scale, bool is_causal) {
+    FlashAttention::execute(out, q, k, v, scale, is_causal);
+}
+} // namespace infinicore::op
diff --git a/src/infinicore/ops/flash_attention/flash_attention_infiniop.cc b/src/infinicore/ops/flash_attention/flash_attention_infiniop.cc
new file mode 100644
index 000000000..e0a91e681
--- /dev/null
+++ b/src/infinicore/ops/flash_attention/flash_attention_infiniop.cc
@@ -0,0 +1,51 @@
+#include "../../utils.hpp"
+#include "infinicore/common/hash.hpp"
+#include "infinicore/ops/common/cache.hpp"
+#include "infinicore/ops/flash_attention.hpp"
+#include <infiniop.h>
+
+namespace infinicore::op::flash_attention_impl::infiniop {
+
+thread_local common::OpCache<size_t, infiniopFlashAttentionDescriptor_t> caches(
+    100, // capacity
+    [](infiniopFlashAttentionDescriptor_t &desc) {
+        if (desc != nullptr) {
+            INFINICORE_CHECK_ERROR(infiniopDestroyFlashAttentionDescriptor(desc));
+            desc = nullptr;
+        }
+    });
+
+void calculate(Tensor out, Tensor q, Tensor k, Tensor v, float scale, bool is_causal) {
+    size_t seed = hash_combine(out, q, k, v, scale, is_causal);
+
+    auto device = context::getDevice();
+    auto &cache = caches.getCache(device);
+
+    auto desc_opt = cache.get(seed);
+    infiniopFlashAttentionDescriptor_t desc = nullptr;
+
+    if (!desc_opt) {
+        INFINICORE_CHECK_ERROR(infiniopCreateFlashAttentionDescriptor(
+            context::getInfiniopHandle(device), &desc,
+            out->desc(), q->desc(), k->desc(), v->desc(),
+            scale, static_cast<char>(is_causal)));
+        cache.put(seed, desc);
+    } else {
+        desc = *desc_opt;
+    }
+
+    size_t workspace_size = 0;
+    INFINICORE_CHECK_ERROR(infiniopGetFlashAttentionWorkspaceSize(desc, &workspace_size));
+    std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
+
+    INFINICORE_CHECK_ERROR(infiniopFlashAttention(
+        desc, workspace->data(), workspace_size,
+        out->data(), q->data(), k->data(), v->data(), context::getStream()));
+}
+
+static bool registered = []() {
+    FlashAttention::dispatcher().registerAll(&calculate, false);
+    return true;
+}();
+
+} // namespace infinicore::op::flash_attention_impl::infiniop
diff --git a/src/infinicore/ops/kv_caching/kv_caching.cc b/src/infinicore/ops/kv_caching/kv_caching.cc
new file mode 100644
index 000000000..bed3a4566
--- /dev/null
+++ b/src/infinicore/ops/kv_caching/kv_caching.cc
@@ -0,0 +1,47 @@
+#include "infinicore/ops/kv_caching.hpp"
+
+#include "../../utils.hpp"
+
+#include <stdexcept>
+
+namespace infinicore::op {
+
+common::OpDispatcher<KVCaching::schema> &KVCaching::dispatcher() {
+    static common::OpDispatcher<schema> dispatcher_;
+    return dispatcher_;
+};
+
+void KVCaching::execute(Tensor k_cache,
+                        Tensor v_cache,
+                        Tensor k,
+                        Tensor v,
+                        Tensor past_kv_lengths) {
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k_cache, v_cache, k, v, past_kv_lengths);
+    infinicore::context::setDevice(k_cache->device());
+    auto device_type = k_cache->device().getType();
+    auto func = dispatcher().lookup(device_type);
+
+    if (func == nullptr) {
+        throw std::runtime_error("No KVCaching implementation found for device type: " + std::to_string(static_cast<int>(device_type)));
+    }
+
+    func(k_cache, v_cache, k, v, past_kv_lengths);
+}
+
+Tensor kv_caching(Tensor k_cache,
+                  Tensor v_cache,
+                  Tensor k,
+                  Tensor v,
+                  Tensor past_kv_lengths) {
+    KVCaching::execute(k_cache, v_cache, k, v, past_kv_lengths);
+    return k_cache; // both caches are updated in place; k_cache is returned for convenience
+}
+
+void kv_caching_(Tensor k_cache,
+                 Tensor v_cache,
+                 Tensor k,
+                 Tensor v,
+                 Tensor past_kv_lengths) {
+    KVCaching::execute(k_cache, v_cache, k, v, past_kv_lengths);
+}
+} // namespace infinicore::op
diff --git a/src/infinicore/ops/kv_caching/kv_caching_infiniop.cc b/src/infinicore/ops/kv_caching/kv_caching_infiniop.cc
new file mode 100644
index 000000000..37d5e1fa3
--- /dev/null
+++ b/src/infinicore/ops/kv_caching/kv_caching_infiniop.cc
@@ -0,0 +1,59 @@
+#include "../../utils.hpp"
+#include "infinicore/common/hash.hpp"
+#include "infinicore/ops/common/cache.hpp"
+#include "infinicore/ops/kv_caching.hpp"
+#include <infiniop.h>
+
+namespace infinicore::op::kv_caching_impl::infiniop {
+
+thread_local common::OpCache<size_t, infiniopKVCachingDescriptor_t> caches(
+    100, // capacity
+    [](infiniopKVCachingDescriptor_t &desc) {
+        if (desc != nullptr) {
+            INFINICORE_CHECK_ERROR(infiniopDestroyKVCachingDescriptor(desc));
+            desc = nullptr;
+        }
+    });
+
+void calculate(Tensor k_cache,
+               Tensor v_cache,
+               Tensor k,
+               Tensor v,
+               Tensor past_kv_lengths) {
+    size_t seed = hash_combine(k_cache, v_cache, k, v, past_kv_lengths);
+
+    auto device = context::getDevice();
+    auto &cache = caches.getCache(device);
+
+    auto desc_opt = cache.get(seed);
+    infiniopKVCachingDescriptor_t desc = nullptr;
+
+    if (!desc_opt) {
+        INFINICORE_CHECK_ERROR(infiniopCreateKVCachingDescriptor(
+            context::getInfiniopHandle(device), &desc,
+            k_cache->desc(), v_cache->desc(),
+            k->desc(), v->desc(),
+            past_kv_lengths->desc()));
+        cache.put(seed, desc);
+    } else {
+        desc = *desc_opt;
+    }
+
+    size_t workspace_size = 0;
+    INFINICORE_CHECK_ERROR(infiniopGetKVCachingWorkspaceSize(desc, &workspace_size));
+    std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
+
+    INFINICORE_CHECK_ERROR(infiniopKVCaching(
+        desc, workspace->data(), workspace_size,
+        k_cache->data(), v_cache->data(),
+        k->data(), v->data(),
+        past_kv_lengths->data(),
+        context::getStream()));
+}
+
+static bool registered = []() {
+    KVCaching::dispatcher().registerAll(&calculate, false);
+    return true;
+}();
+
+} // namespace infinicore::op::kv_caching_impl::infiniop
diff --git a/src/infinicore/ops/random_sample_batched/random_sample_batched.cc b/src/infinicore/ops/random_sample_batched/random_sample_batched.cc
new file mode 100644
index 000000000..a02635f66
--- /dev/null
+++ b/src/infinicore/ops/random_sample_batched/random_sample_batched.cc
@@ -0,0 +1,54 @@
+#include "infinicore/ops/random_sample_batched.hpp"
+
+#include "../../utils.hpp"
+
+namespace infinicore::op {
+
+common::OpDispatcher<RandomSampleBatched::schema> &RandomSampleBatched::dispatcher() {
+    static common::OpDispatcher<schema> dispatcher_;
+    return dispatcher_;
+};
+
+void RandomSampleBatched::execute(
+    Tensor result,
+    Tensor probs,
+    const float *random_val,
+    const float *topp,
+    const int *topk,
+    const float *temperature,
+    int batch_size) {
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(result, probs);
+    infinicore::context::setDevice(result->device());
+    auto device_type = result->device().getType();
+    auto func = dispatcher().lookup(device_type);
+
+    if (func == nullptr) {
+        throw std::runtime_error("No RandomSampleBatched implementation found for device type: " + std::to_string(static_cast<int>(device_type)));
+    }
+
+    func(result, probs, random_val, topp, topk, temperature, batch_size);
+}
+
+Tensor random_sample_batched(
+    Tensor logits,
+    const float *random_val,
+    const float *topp,
+    const int *topk,
+    const float *temperature,
+    int batch_size) {
+    Shape shape = logits->shape();
+    auto result = Tensor::empty(shape, DataType::I32, logits->device());
+    random_sample_batched_(result, logits, random_val, topp, topk, temperature, batch_size);
+    return result;
+}
+void random_sample_batched_(
+    Tensor result,
+    Tensor logits,
+    const float *random_val,
+    const float *topp,
+    const int *topk,
+    const float *temperature,
+    int batch_size) {
+    RandomSampleBatched::execute(result, logits, random_val, topp, topk, temperature, batch_size);
+}
+} // namespace infinicore::op
diff --git a/src/infinicore/ops/random_sample_batched/random_sample_batched_infiniop.cc b/src/infinicore/ops/random_sample_batched/random_sample_batched_infiniop.cc
new file mode 100644
index 000000000..2916c0b2a
--- /dev/null
+++ b/src/infinicore/ops/random_sample_batched/random_sample_batched_infiniop.cc
@@ -0,0 +1,63 @@
+#include "../../utils.hpp"
+#include "infinicore/common/hash.hpp"
+#include "infinicore/ops/common/cache.hpp"
+#include "infinicore/ops/random_sample_batched.hpp"
+#include <infiniop.h>
+
+namespace infinicore::op::random_sample_batched_impl::infiniop_backend {
+
+thread_local common::OpCache<size_t, infiniopRandomSampleBatchedDescriptor_t> caches(
+    100, // capacity
+    [](infiniopRandomSampleBatchedDescriptor_t &desc) {
+        if (desc != nullptr) {
+            INFINICORE_CHECK_ERROR(infiniopDestroyRandomSampleBatchedDescriptor(desc));
+            desc = nullptr;
+        }
+    });
+
+void calculate(
+    Tensor result,
+    Tensor probs,
+    const float *random_val,
+    const float *topp,
+    const int *topk,
+    const float *temperature,
+    int batch_size) {
+    size_t seed = hash_combine(result, probs, batch_size);
+
+    auto device = context::getDevice();
+    auto &cache = caches.getCache(device);
+
+    auto desc_opt = cache.get(seed);
+    infiniopRandomSampleBatchedDescriptor_t desc = nullptr;
+
+    if (!desc_opt) {
+        INFINICORE_CHECK_ERROR(infiniopCreateRandomSampleBatchedDescriptor(
+            context::getInfiniopHandle(device), &desc,
+            result->desc(), probs->desc()));
+        cache.put(seed, desc);
+    } else {
+        desc = *desc_opt;
+    }
+
+    size_t workspace_size = 0;
+    INFINICORE_CHECK_ERROR(infiniopGetRandomSampleBatchedWorkspaceSize(desc, &workspace_size));
+    std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
+
+    INFINICORE_CHECK_ERROR(infiniopRandomSampleBatched(
+        desc,
+        workspace->data(), workspace_size,
+        result->data(), probs->data(),
+        random_val, topp, topk, temperature,
+        batch_size,
+        context::getStream()));
+}
+
+} // namespace infinicore::op::random_sample_batched_impl::infiniop_backend
+
+namespace infinicore::op {
+static bool registered = []() {
+    RandomSampleBatched::dispatcher().registerAll(&random_sample_batched_impl::infiniop_backend::calculate, false);
+    return true;
+}();
+} // namespace infinicore::op
diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp
index 3d6ebe79a..e2d5aa00b 100644
--- a/src/infinicore/pybind11/ops.hpp
+++ b/src/infinicore/pybind11/ops.hpp
@@ -7,6 +7,7 @@
 #include "ops/attention.hpp"
 #include "ops/causal_softmax.hpp"
 #include "ops/embedding.hpp"
+#include "ops/flash_attention.hpp"
 #include "ops/linear.hpp"
 #include "ops/matmul.hpp"
 #include "ops/mul.hpp"
@@ -29,6 +30,7 @@ inline void bind(py::module &m) {
     bind_add_rms_norm(m);
     bind_attention(m);
     bind_causal_softmax(m);
+    bind_flash_attention(m);
     bind_random_sample(m);
     bind_linear(m);
     bind_matmul(m);
diff --git a/src/infinicore/pybind11/ops/flash_attention.hpp b/src/infinicore/pybind11/ops/flash_attention.hpp
new file mode 100644
index 000000000..09ec91980
--- /dev/null
+++ b/src/infinicore/pybind11/ops/flash_attention.hpp
@@ -0,0 +1,21 @@
+#pragma once
+
+#include <pybind11/pybind11.h>
+
+#include "infinicore/ops/flash_attention.hpp"
+
+namespace py = pybind11;
+
+namespace infinicore::ops {
+
+inline void bind_flash_attention(py::module &m) {
+    m.def("flash_attention",
+          &op::flash_attention,
+          py::arg("q"),
+          py::arg("k"),
+          py::arg("v"),
+          py::arg("scale"),
+          py::arg("is_causal"));
+}
+
+} // namespace infinicore::ops
diff --git a/src/infiniop/ninetoothed/build.py b/src/infiniop/ninetoothed/build.py
index aea421b7f..153e6b9f5 100644
--- a/src/infiniop/ninetoothed/build.py
+++ b/src/infiniop/ninetoothed/build.py
@@ -1,3 +1,4 @@
+import concurrent.futures
 import functools
 import inspect
 import itertools
@@ -16,40 +17,28 @@ def build(premake, constexpr_param_grid, caller, op_name, output_dir):
     headers = []
     all_param_names = []
+    combinations = []
     launches = []
 
-    for combination in _generate_param_value_combinations(constexpr_param_grid):
-        arrangement, application, tensors = premake(**combination)
+    with concurrent.futures.ProcessPoolExecutor() as executor:
+        futures = []
 
-        for param_name, param_value in combination.items():
-            if isinstance(param_value, str):
-                combination[param_name] = (
-                    f"INFINI_DTYPE_{combination[param_name].replace('fp', 'F').upper()}"
-                )
+        for combination in tuple(
+            _generate_param_value_combinations(constexpr_param_grid)
+        ):
+            future = executor.submit(
+                _make, premake, combination, caller, op_name, output_dir
+            )
 
-        combination = {f"{name}_": value for name, value in combination.items()}
+            futures.append(future)
 
-        kernel_name = f"{op_name}_{_generate_suffix(combination.values())}"
+        for future in concurrent.futures.as_completed(futures):
+            header, param_names, combination, launch = future.result()
 
-        ninetoothed.make(
-            arrangement,
-            application,
-            tensors,
-            caller=caller,
-            kernel_name=kernel_name,
-            output_dir=output_dir,
-        )
-
-        header = output_dir / f"{kernel_name}.h"
-        param_names = ("stream",) + tuple(
-            inspect.signature(application).parameters.keys()
-        )
-        launch = f"""    if ({_generate_condition(combination)})
-        return launch_{kernel_name}({", ".join(param_names)});"""
-
-        headers.append(header)
-        all_param_names.append(param_names)
-        launches.append(launch)
+            headers.append(header)
+            all_param_names.append(param_names)
+            combinations.append(combination)
+            launches.append(launch)
 
     includes = "\n".join(f'#include "{header}"' for header in headers)
 
@@ -64,7 +53,7 @@
         "NineToothedStream",
     ] + ["NineToothedTensor" for _ in range(len(param_names) - 1)]
 
-    for param_name in combination:
+    for param_name in functools.reduce(lambda x, y: x | y, combinations, {}):
         param_names.append(param_name)
         param_types.append("int")
@@ -97,6 +86,36 @@
     (BUILD_DIRECTORY_PATH / header_file_name).write_text(header_content)
 
 
+def _make(premake, combination, caller, op_name, output_dir):
+    arrangement, application, tensors = premake(**combination)
+
+    for param_name, param_value in combination.items():
+        if isinstance(param_value, str):
+            combination[param_name] = (
+                f"INFINI_DTYPE_{combination[param_name].replace('fp', 'F').upper()}"
+            )
+
+    combination = {f"{name}_": value for name, value in combination.items()}
+
+    kernel_name = f"{op_name}_{_generate_suffix(combination.values())}"
+
+    ninetoothed.make(
+        arrangement,
+        application,
+        tensors,
+        caller=caller,
+        kernel_name=kernel_name,
+        output_dir=output_dir,
+    )
+
+    header = output_dir / f"{kernel_name}.h"
+    param_names = ("stream",) + tuple(inspect.signature(application).parameters.keys())
+    launch = f"""    if ({_generate_condition(combination)})
+        return launch_{kernel_name}({", ".join(param_names)});"""
+
+    return header, param_names, combination, launch
+
+
 def _generate_condition(combination):
     return " && ".join(f"{param} == {value}" for param, value in combination.items())
diff --git a/src/infiniop/ninetoothed/utils.h b/src/infiniop/ninetoothed/utils.h
new file mode 100644
index 000000000..1b7d1fe3a
--- /dev/null
+++ b/src/infiniop/ninetoothed/utils.h
@@ -0,0 +1,75 @@
+#ifndef __NINETOOTHED_UTILS__
+#define __NINETOOTHED_UTILS__
+
+#include <initializer_list>
+#include <limits>
+#include <type_traits>
+#include <vector>
+
+namespace ninetoothed {
+
+template <typename T = double>
+class Tensor {
+public:
+    using Data = decltype(NineToothedTensor::data);
+
+    using Size = std::remove_pointer_t<decltype(NineToothedTensor::shape)>;
+
+    using Stride = std::remove_pointer_t<decltype(NineToothedTensor::strides)>;
+
+    template <typename Shape, typename Strides>
+    Tensor(const void *data, Shape shape, Strides strides) : data_{data}, shape_{shape}, strides_{strides}, ndim_{shape_.size()} {}
+
+    Tensor(const void *data, std::initializer_list<Size> shape, std::initializer_list<Stride> strides) : Tensor{data, decltype(shape_){shape}, decltype(strides_){strides}} {}
+
+    Tensor(const void *data, const Size *shape, const Stride *strides, Size ndim) : data_{data}, shape_{shape, shape + ndim}, strides_{strides, strides + ndim}, ndim_{shape_.size()} {}
+
+    Tensor(const T value) : value_{value}, data_{&value_}, ndim_{0} {}
+
+    operator NineToothedTensor() { return {const_cast<void *>(data_), shape_.data(), strides_.data()}; }
+
+    template <typename Shape>
+    Tensor expand(const Shape &sizes) const {
+        auto new_ndim{sizes.size()};
+
+        decltype(shape_) shape(new_ndim, 1);
+        decltype(strides_) strides(new_ndim, 0);
+
+        auto num_new_dims{new_ndim - ndim_};
+
+        for (auto dim{decltype(ndim_){0}}; dim < ndim_; ++dim) {
+            shape[dim + num_new_dims] = shape_[dim];
+            strides[dim + num_new_dims] = strides_[dim];
+        }
+
+        for (auto dim{decltype(new_ndim){0}}; dim < new_ndim; ++dim) {
+            if (sizes[dim] == std::numeric_limits<std::remove_reference_t<decltype(sizes[dim])>>::max() || shape[dim] != 1) {
+                continue;
+            }
+
+            shape[dim] = sizes[dim];
+            strides[dim] = 0;
+        }
+
+        return {data_, shape, strides};
+    }
+
+    Tensor expand_as(const Tensor &other) const {
+        return expand(other.shape_);
+    }
+
+private:
+    const void *data_{nullptr};
+
+    std::vector<Size> shape_;
+
+    std::vector<Stride> strides_;
+
+    Size ndim_{0};
+
+    T value_{0};
+};
+
+} // namespace ninetoothed
+
+#endif
diff --git a/src/infiniop/ops/embedding/operator.cc b/src/infiniop/ops/embedding/operator.cc
new file mode 100644
index 000000000..0bf7864c9
--- /dev/null
+++ b/src/infiniop/ops/embedding/operator.cc
@@ -0,0 +1,89 @@
+#include "../../operator.h"
+#include "../../handle.h"
+#include "infiniop/ops/embedding.h"
+
+#ifdef ENABLE_CPU_API
+// #include "cpu/embedding_cpu.h"
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+// #include "nvidia/embedding_nvidia.cuh"
+#endif
+
+__C infiniStatus_t infiniopCreateEmbeddingDescriptor(
+    infiniopHandle_t handle,
+    infiniopEmbeddingDescriptor_t *desc_ptr,
+    infiniopTensorDescriptor_t output_desc,
+    infiniopTensorDescriptor_t input_desc,
+    infiniopTensorDescriptor_t weight_desc) {
+
+#define CREATE(CASE, NAMESPACE) \
+    case CASE: \
+        return op::embedding::NAMESPACE::Descriptor::create( \
+            handle, \
+            reinterpret_cast<op::embedding::NAMESPACE::Descriptor **>(desc_ptr), \
+            output_desc, \
+            input_desc, \
+            weight_desc)
+
+    switch (handle->device) {
+
+#ifdef ENABLE_CPU_API
+        // CREATE(INFINI_DEVICE_CPU, cpu);
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+        // CREATE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CREATE
+}
+
+__C infiniStatus_t infiniopEmbedding(
+    infiniopEmbeddingDescriptor_t desc,
+    void *output,
+    const void *input,
+    const void *weight,
+    void *stream) {
+
+#define CALCULATE(CASE, NAMESPACE) \
+    case CASE: \
+        return reinterpret_cast<const op::embedding::NAMESPACE::Descriptor *>(desc) \
+            ->calculate(output, input, weight, stream)
+
+    switch (desc->device_type) {
+
+#ifdef ENABLE_CPU_API
+        // CALCULATE(INFINI_DEVICE_CPU, cpu);
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+        // CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CALCULATE
+}
+
+__C infiniStatus_t infiniopDestroyEmbeddingDescriptor(infiniopEmbeddingDescriptor_t desc) {
+
+#define DELETE(CASE, NAMESPACE) \
+    case CASE: \
+        delete reinterpret_cast<op::embedding::NAMESPACE::Descriptor *>(desc); \
+        return INFINI_STATUS_SUCCESS;
+
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        // DELETE(INFINI_DEVICE_CPU, cpu);
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+        // DELETE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+    }
+
+#undef DELETE
+
+    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+}
diff --git a/src/infiniop/ops/flash_attention/ninetoothed/build.py b/src/infiniop/ops/flash_attention/ninetoothed/build.py
new file mode 100644
index 000000000..dfcce6910
--- /dev/null
+++ b/src/infiniop/ops/flash_attention/ninetoothed/build.py
@@ -0,0 +1,35 @@
+import ninetoothed
+from ntops.kernels import scaled_dot_product_attention
+from ntops.kernels.scaled_dot_product_attention import CausalVariant
+
+import infiniop.ninetoothed.build
+
+
+def build():
+    with_kv_cache_values = (0,)
+    emb_dim_values = (16, 32, 64, 128, 256)
+    is_causal_values = (0, 1)
+    with_attn_mask_values = (0,)
+    causal_variant_values = (CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT)
+    dtype_values = (ninetoothed.float16, ninetoothed.float32)
+    block_size_m_values = (64,)
+    block_size_n_values = (64,)
+
+    constexpr_param_grid = {
+        "with_kv_cache": with_kv_cache_values,
+        "emb_dim": emb_dim_values,
+        "is_causal": is_causal_values,
+        "with_attn_mask": with_attn_mask_values,
+        "causal_variant": causal_variant_values,
+        "dtype": dtype_values,
+        "block_size_m": block_size_m_values,
+        "block_size_n": block_size_n_values,
+    }
+
+    infiniop.ninetoothed.build.build(
+        scaled_dot_product_attention.premake,
+        constexpr_param_grid,
+        caller="cuda",
+        op_name="flash_attention",
+        output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH,
+    )
diff --git a/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h b/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h
new file mode 100644
index 000000000..697891d3d
--- /dev/null
+++ b/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h
@@ -0,0 +1,133 @@
+#ifndef __FLASH_ATTENTION_DESCRIPTOR_H__
+#define __FLASH_ATTENTION_DESCRIPTOR_H__
+
+#include "../../../handle.h"
+#include "../../../operator.h"
+#include "../../../tensor.h"
+
+#include "../../../../../build/ninetoothed/flash_attention.h"
+#include "../../../ninetoothed/utils.h"
+
+namespace op::flash_attention::ninetoothed {
+
+class Descriptor final : public InfiniopDescriptor {
+public:
+    Descriptor(infiniopHandle_t handle,
+               infiniopTensorDescriptor_t out_desc,
+               infiniopTensorDescriptor_t q_desc,
+               infiniopTensorDescriptor_t k_desc,
+               infiniopTensorDescriptor_t v_desc,
+               double scale,
+               char is_causal) : InfiniopDescriptor{handle->device, handle->device_id},
+                                 _query_shape{q_desc->shape()},
+                                 _query_strides{q_desc->strides()},
+                                 _key_shape{k_desc->shape()},
+                                 _key_strides{k_desc->strides()},
+                                 _value_shape{v_desc->shape()},
+                                 _value_strides{v_desc->strides()},
+                                 _output_strides{out_desc->strides()},
+                                 _dtype{q_desc->dtype()},
+                                 _scale{scale},
+                                 _is_causal{is_causal} {}
+
+    ~Descriptor() = default;
+
+    size_t get_workspace_size() const {
+        return 0;
+    }
+
+    infiniStatus_t calculate(void *workspace,
+                             size_t workspace_size,
+                             void *out,
+                             const void *q,
+                             const void *k,
+                             const void *v,
+                             void *stream) const {
+        uint64_t empty_shape[4];
+        int64_t empty_strides[4];
+
+        auto query{::ninetoothed::Tensor{q, _query_shape, _query_strides}};
+        auto key{::ninetoothed::Tensor{k, _key_shape, _key_strides}};
+        auto value{::ninetoothed::Tensor{v, _value_shape, _value_strides}};
+
+        NineToothedTensor attn_mask{nullptr, empty_shape, empty_strides};
+        NineToothedTensor is_causal;
+        NineToothedTensor scale{const_cast<double *>(&_scale), nullptr, nullptr};
+        auto output{::ninetoothed::Tensor{out, _query_shape, _output_strides}};
+        NineToothedTensor with_attn_mask;
+        NineToothedTensor causal_variant;
+
+        const auto with_kv_cache_{0};
+        const auto emb_dim_{_query_shape[3]};
+        const auto is_causal_{_is_causal};
+        const auto with_attn_mask_{0};
+        const auto causal_variant_{1};
+        const auto dtype_{_dtype};
+
+        constexpr auto block_size_m_{64};
+        constexpr auto block_size_n_{64};
+
+        launch_flash_attention(stream,
+                               query,
+                               key,
+                               value,
+                               attn_mask,
+                               is_causal,
+                               scale,
+                               output,
+                               with_attn_mask,
+                               causal_variant,
+                               with_kv_cache_,
+                               emb_dim_,
+                               is_causal_,
+                               with_attn_mask_,
+                               causal_variant_,
+                               dtype_,
+                               block_size_m_,
+                               block_size_n_);
+
+        return INFINI_STATUS_SUCCESS;
+    }
+
+    static infiniStatus_t create(infiniopHandle_t handle,
+                                 Descriptor **desc,
+                                 infiniopTensorDescriptor_t out_desc,
+                                 infiniopTensorDescriptor_t q_desc,
+                                 infiniopTensorDescriptor_t k_desc,
+                                 infiniopTensorDescriptor_t v_desc,
+                                 double scale,
+                                 char is_causal) {
+        *desc = new Descriptor{handle, out_desc, q_desc, k_desc, v_desc, scale, is_causal};
+
+        return INFINI_STATUS_SUCCESS;
+    }
+
+private:
+    using Size = ::ninetoothed::Tensor<>::Size;
+
+    using Stride = ::ninetoothed::Tensor<>::Stride;
+
+    std::vector<Size> _query_shape;
+
+    std::vector<Stride> _query_strides;
+
+    std::vector<Size> _key_shape;
+
+    std::vector<Stride> _key_strides;
+
+    std::vector<Size> _value_shape;
+
+    std::vector<Stride> _value_strides;
+
+    std::vector<Stride> _output_strides;
+
+    infiniDtype_t _dtype;
+
+    double _scale;
+
+    char _is_causal;
+};
+
+} // namespace op::flash_attention::ninetoothed
+
+#endif // __FLASH_ATTENTION_DESCRIPTOR_H__
diff --git a/src/infiniop/ops/flash_attention/operator.cc b/src/infiniop/ops/flash_attention/operator.cc
new file mode 100644
index 000000000..e907d3c41
--- /dev/null
+++ b/src/infiniop/ops/flash_attention/operator.cc
@@ -0,0 +1,141 @@
+#include "../../operator.h"
+#include "../../handle.h"
+#include "infiniop/ops/flash_attention.h"
+
+#ifdef ENABLE_CPU_API
+// #include "cpu/flash_attention_cpu.h"
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API)
+#include "ninetoothed/descriptor.h"
+#else
+// #include "nvidia/flash_attention_nvidia.cuh"
+#endif
+#endif
+
+__C infiniStatus_t infiniopCreateFlashAttentionDescriptor(
+    infiniopHandle_t handle,
+    infiniopFlashAttentionDescriptor_t *desc_ptr,
+    infiniopTensorDescriptor_t out_desc,
+    infiniopTensorDescriptor_t q_desc,
+    infiniopTensorDescriptor_t k_desc,
+    infiniopTensorDescriptor_t v_desc,
+    float scale,
+    char is_causal) {
+
+#define CREATE(CASE, NAMESPACE) \
+    case CASE: \
+        return op::flash_attention::NAMESPACE::Descriptor::create( \
+            handle, \
+            reinterpret_cast<op::flash_attention::NAMESPACE::Descriptor **>(desc_ptr), \
+            out_desc, \
+            q_desc, \
+            k_desc, \
+            v_desc, \
+            scale, \
+            is_causal);
+
+    switch (handle->device) {
+
+#ifdef ENABLE_CPU_API
+        // CREATE(INFINI_DEVICE_CPU, cpu);
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API)
+        CREATE(INFINI_DEVICE_NVIDIA, ninetoothed);
+#else
+        // CREATE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+#undef CREATE
+}
+
+__C infiniStatus_t infiniopGetFlashAttentionWorkspaceSize(
+    infiniopFlashAttentionDescriptor_t desc,
+    size_t *size) {
+
+#define GET_SIZE(CASE, NAMESPACE) \
+    case CASE: \
+        *size = reinterpret_cast<const op::flash_attention::NAMESPACE::Descriptor *>(desc) \
+                    ->get_workspace_size(); \
+        return INFINI_STATUS_SUCCESS;
+
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        // GET_SIZE(INFINI_DEVICE_CPU, cpu);
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API)
+        GET_SIZE(INFINI_DEVICE_NVIDIA, ninetoothed);
+#else
+        // GET_SIZE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef GET_SIZE
+}
+
+__C infiniStatus_t infiniopFlashAttention(
+    infiniopFlashAttentionDescriptor_t desc,
+    void *workspace,
+    size_t workspace_size,
+    void *out,
+    const void *q,
+    const void *k,
+    const void *v,
+    void *stream) {
+
+#define CALCULATE(CASE, NAMESPACE) \
+    case CASE: \
+        return reinterpret_cast<const op::flash_attention::NAMESPACE::Descriptor *>(desc) \
+            ->calculate(workspace, workspace_size, out, q, k, v, stream);
+
+    switch (desc->device_type) {
+
+#ifdef ENABLE_CPU_API
+        // CALCULATE(INFINI_DEVICE_CPU, cpu);
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API)
+        CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed);
+#else
+        // CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CALCULATE
+}
+
+__C infiniStatus_t infiniopDestroyFlashAttentionDescriptor(
+    infiniopFlashAttentionDescriptor_t desc) {
+
+#define DESTROY(CASE, NAMESPACE) \
+    case CASE: \
+        delete reinterpret_cast<op::flash_attention::NAMESPACE::Descriptor *>(desc); \
+        return INFINI_STATUS_SUCCESS;
+
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        // DESTROY(INFINI_DEVICE_CPU, cpu);
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API)
+        DESTROY(INFINI_DEVICE_NVIDIA, ninetoothed);
+#else
+        // DESTROY(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+#undef DESTROY
+}
diff --git a/src/infiniop/ops/kv_caching/operator.cc b/src/infiniop/ops/kv_caching/operator.cc
new file mode 100644
index 000000000..65b27a414
--- /dev/null
+++ b/src/infiniop/ops/kv_caching/operator.cc
@@ -0,0 +1,121 @@
+#include "../../operator.h"
+#include "../../handle.h"
+#include "infiniop/ops/kv_caching.h"
+
+#ifdef ENABLE_CPU_API
+// #include "cpu/kv_caching_cpu.h"
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+// #include "nvidia/kv_caching_nvidia.cuh"
+#endif
+
+__C infiniStatus_t infiniopCreateKVCachingDescriptor(
+    infiniopHandle_t handle,
+    infiniopKVCachingDescriptor_t *desc_ptr,
+    infiniopTensorDescriptor_t k_cache,
+    infiniopTensorDescriptor_t v_cache,
+    infiniopTensorDescriptor_t k,
+    infiniopTensorDescriptor_t v,
+    infiniopTensorDescriptor_t past_kv_lengths) {
+
+#define CREATE(CASE, NAMESPACE) \
+    case CASE: \
+        return op::kv_caching::NAMESPACE::Descriptor::create( \
+            handle, \
+            reinterpret_cast<op::kv_caching::NAMESPACE::Descriptor **>(desc_ptr), \
+            k_cache, \
+            v_cache, \
+            k, \
+            v, \
+            past_kv_lengths)
+
+    switch (handle->device) {
+
+#ifdef ENABLE_CPU_API
+        // CREATE(INFINI_DEVICE_CPU, cpu);
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+        // CREATE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CREATE
+}
+
+__C infiniStatus_t infiniopGetKVCachingWorkspaceSize(
+    infiniopKVCachingDescriptor_t desc,
+    size_t *size) {
+
+#define GET_SIZE(CASE, NAMESPACE) \
+    case CASE: \
+        *size = reinterpret_cast<const op::kv_caching::NAMESPACE::Descriptor *>(desc) \
+                    ->get_workspace_size(); \
+        return INFINI_STATUS_SUCCESS;
+
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        // GET_SIZE(INFINI_DEVICE_CPU, cpu);
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+        // GET_SIZE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef GET_SIZE
+}
+
+__C infiniStatus_t infiniopKVCaching(
+    infiniopKVCachingDescriptor_t desc,
+    void *workspace,
+    size_t workspace_size,
+    void *k_cache,
+    void *v_cache,
+    const void *k,
+    const void *v,
+    const void *past_kv_lengths,
+    void *stream) {
+
+#define CALCULATE(CASE, NAMESPACE) \
+    case CASE: \
+        return reinterpret_cast<const op::kv_caching::NAMESPACE::Descriptor *>(desc) \
+            ->calculate(workspace, workspace_size, k_cache, v_cache, k, v, past_kv_lengths, stream)
+
+    switch (desc->device_type) {
+
+#ifdef ENABLE_CPU_API
+        // CALCULATE(INFINI_DEVICE_CPU, cpu);
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+        // CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CALCULATE
+}
+
+__C infiniStatus_t infiniopDestroyKVCachingDescriptor(
+    infiniopKVCachingDescriptor_t desc) {
+
+#define DELETE(CASE, NAMESPACE) \
+    case CASE: \
+        delete reinterpret_cast<op::kv_caching::NAMESPACE::Descriptor *>(desc); \
+        return INFINI_STATUS_SUCCESS;
+
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        // DELETE(INFINI_DEVICE_CPU, cpu);
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+        // DELETE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+#undef DELETE
+}
diff --git a/src/infiniop/ops/random_sample_batched/operator.cc b/src/infiniop/ops/random_sample_batched/operator.cc
new file mode 100644
index 000000000..d0047ad53
--- /dev/null
+++ b/src/infiniop/ops/random_sample_batched/operator.cc
@@ -0,0 +1,128 @@
+#include "../../operator.h"
+#include "../../handle.h"
+#include "infiniop/ops/random_sample_batched.h"
+
+#ifdef ENABLE_CPU_API
+// #include "cpu/random_sample_batched_cpu.h"
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+// #include "nvidia/random_sample_batched_nvidia.cuh"
+#endif
+
+__C infiniStatus_t infiniopCreateRandomSampleBatchedDescriptor(
+    infiniopHandle_t handle,
+    infiniopRandomSampleBatchedDescriptor_t *desc_ptr,
+    infiniopTensorDescriptor_t result,
+    infiniopTensorDescriptor_t probs) {
+
+#define CREATE(CASE, NAMESPACE) \
+    case CASE: \
+        return op::random_sample_batched::NAMESPACE::Descriptor::create( \
+            handle, \
+            reinterpret_cast<op::random_sample_batched::NAMESPACE::Descriptor **>(desc_ptr), \
+            result, \
+            probs)
+
+    switch (handle->device) {
+
+#ifdef ENABLE_CPU_API
+        // CREATE(INFINI_DEVICE_CPU, cpu);
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+        // CREATE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CREATE
+}
+
+__C infiniStatus_t infiniopGetRandomSampleBatchedWorkspaceSize(
+    infiniopRandomSampleBatchedDescriptor_t desc,
+    size_t *size) {
+
+#define GET_SIZE(CASE, NAMESPACE) \
+    case CASE: { \
+        using Ptr = const op::random_sample_batched::NAMESPACE::Descriptor *; \
+        *size = reinterpret_cast<Ptr>(desc)->minWorkspaceSize(); \
+    } \
+        return INFINI_STATUS_SUCCESS
+
+    switch (desc->device_type) {
+
+#ifdef ENABLE_CPU_API
+        // GET_SIZE(INFINI_DEVICE_CPU, cpu);
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+        // GET_SIZE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef GET_SIZE
+}
+
+__C infiniStatus_t infiniopRandomSampleBatched(
+    infiniopRandomSampleBatchedDescriptor_t desc,
+    void *workspace,
+    size_t workspace_size,
+    void *result,
+    const void *probs,
+    const float *random_val,
+    const float *topp,
+    const int *topk,
+    const float *temperature,
+    int batch_size,
+    void *stream) {
+
+#define CALCULATE(CASE, NAMESPACE) \
+    case CASE: \
+        return reinterpret_cast<const op::random_sample_batched::NAMESPACE::Descriptor *>(desc) \
+            ->calculate(workspace, workspace_size, \
+                        result, probs, \
+                        random_val, \
+                        topp, topk, temperature, \
+                        batch_size, \
+                        stream)
+
+    switch (desc->device_type) {
+
+#ifdef ENABLE_CPU_API
+        // CALCULATE(INFINI_DEVICE_CPU, cpu);
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+        // CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CALCULATE
+}
+
+__C infiniStatus_t infiniopDestroyRandomSampleBatchedDescriptor(
+    infiniopRandomSampleBatchedDescriptor_t desc) {
+
+#define DELETE(CASE, NAMESPACE) \
+    case CASE: \
+        delete reinterpret_cast<op::random_sample_batched::NAMESPACE::Descriptor *>(desc); \
+        return INFINI_STATUS_SUCCESS;
+
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        // DELETE(INFINI_DEVICE_CPU, cpu);
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+        // DELETE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef DELETE
+}
diff --git a/src/infiniop/ops/relu/metax/relu_metax.maca b/src/infiniop/ops/relu/metax/relu_metax.maca
index 900fce9e0..2c5104bdd 100644
--- a/src/infiniop/ops/relu/metax/relu_metax.maca
+++ b/src/infiniop/ops/relu/metax/relu_metax.maca
@@ -2,6 +2,7 @@
 
 #include "../../../../../build/ninetoothed/relu.h"
 #include "../../../devices/metax/metax_common.h"
+#include "../../../ninetoothed/utils.h"
 #include "relu_metax.h"
 
 namespace op::relu::metax {
@@ -42,22 +43,10 @@ infiniStatus_t Descriptor::calculate(
     }
 
     const auto &ndim{_info.getNdim()};
-    const auto &x_shape_{_info.getInputShape(0)};
-    const auto &x_strides_{_info.getInputStrides(0)};
-    std::vector<uint64_t> x_shape_vec{x_shape_, x_shape_ + ndim};
-    std::vector<int64_t> x_strides_vec{x_strides_, x_strides_ + ndim};
-    auto x_data{const_cast<void *>(inputs[0])};
-    auto x_shape{x_shape_vec.data()};
-    auto x_strides{x_strides_vec.data()};
-    const NineToothedTensor x{x_data, x_shape, x_strides};
-    const auto &y_shape_{_info.getOutputShape()};
-    const auto &y_strides_{_info.getOutputStrides()};
-    std::vector<uint64_t> y_shape_vec{y_shape_, y_shape_ + ndim};
-    std::vector<int64_t> y_strides_vec{y_strides_, y_strides_ + ndim};
-    auto y_data{output};
-    auto y_shape{y_shape_vec.data()};
-    auto y_strides{y_strides_vec.data()};
-    const NineToothedTensor y{y_data, y_shape, y_strides};
+
+    auto x{ninetoothed::Tensor{inputs[0], _info.getInputShape(0), _info.getInputStrides(0), ndim}};
+    auto y{ninetoothed::Tensor{output, _info.getOutputShape(), _info.getOutputStrides(), ndim}};
+
     constexpr auto block_size{1024};
 
     switch (_dtype) {
diff --git a/src/infiniop/ops/relu/nvidia/relu_nvidia.cu b/src/infiniop/ops/relu/nvidia/relu_nvidia.cu
index 22b85a401..a3c79fb52 100644
--- a/src/infiniop/ops/relu/nvidia/relu_nvidia.cu
+++ b/src/infiniop/ops/relu/nvidia/relu_nvidia.cu
@@ -1,5 +1,6 @@
 #ifdef ENABLE_NINETOOTHED
 #include "../../../../../build/ninetoothed/relu.h"
+#include "../../../ninetoothed/utils.h"
 #endif
 #include "../../../devices/nvidia/nvidia_common.cuh"
 #include "../../../elementwise/nvidia/elementwise_nvidia.cuh"
@@ -45,22 +46,10 @@ infiniStatus_t Descriptor::calculate(
     }
 #ifdef ENABLE_NINETOOTHED
     const auto &ndim{_info.getNdim()};
-    const auto &x_shape_{_info.getInputShape(0)};
-    const auto &x_strides_{_info.getInputStrides(0)};
-    std::vector<uint64_t> x_shape_vec{x_shape_, x_shape_ + ndim};
-    std::vector<int64_t> x_strides_vec{x_strides_, x_strides_ + ndim};
-    auto x_data{const_cast<void *>(inputs[0])};
-    auto x_shape{x_shape_vec.data()};
-    auto x_strides{x_strides_vec.data()};
-    const NineToothedTensor x{x_data, x_shape, x_strides};
-    const auto &y_shape_{_info.getOutputShape()};
-    const auto &y_strides_{_info.getOutputStrides()};
-    std::vector<uint64_t> y_shape_vec{y_shape_, y_shape_ + ndim};
-    std::vector<int64_t> y_strides_vec{y_strides_, y_strides_ + ndim};
-    auto y_data{output};
-    auto y_shape{y_shape_vec.data()};
-    auto y_strides{y_strides_vec.data()};
-    const NineToothedTensor y{y_data, y_shape, y_strides};
+
+    auto x{ninetoothed::Tensor{inputs[0], _info.getInputShape(0), _info.getInputStrides(0), ndim}};
+    auto y{ninetoothed::Tensor{output, _info.getOutputShape(), _info.getOutputStrides(), ndim}};
+
     constexpr auto block_size{1024};
 
     switch (_dtype) {
diff --git a/test/infinicore/ops/scaled_dot_product_attention.py b/test/infinicore/ops/scaled_dot_product_attention.py
index 218420d72..644fb6f99 100644
--- a/test/infinicore/ops/scaled_dot_product_attention.py
+++ b/test/infinicore/ops/scaled_dot_product_attention.py
@@ -11,17 +11,16 @@
 # q/k/v typically have shape (..., seq_len, head_dim) or (batch, seq_len, num_heads, head_dim)
 
 _TEST_CASES_DATA = [
-    ((2, 8, 16), (2, 8, 16), (2, 8, 16), None, 0.0, False),
-    ((1, 4, 32), (1, 4, 32), (1, 4, 32), None, 0.0, False),
-    ((2, 6, 12), (2, 6, 12), (2, 6, 12), None, 0.0, True),
-    ((3, 8, 8), (3, 8, 8), (3, 8, 8), None, 0.0, False),
-    ((2, 4, 16), (2, 4, 16), (2, 4, 16), None, 0.0, True),
-    ((1, 2, 64), (1, 2, 64), (1, 2, 64), None, 0.0, False),
+    ((1, 1, 2, 16), (1, 1, 2, 16), (1, 1, 2, 16), None, 0.0, False),
+    ((1, 2, 8, 16), (1, 2, 8, 16), (1, 2, 8, 16), None, 0.0, False),
+    ((1, 1, 4, 32), (1, 1, 4, 32), (1, 1, 4, 32), None, 0.0, False),
+    ((1, 2, 4, 16), (1, 2, 4, 16), (1, 2, 4, 16), None, 0.0, True),
+    ((1, 1, 2, 64), (1, 1, 2, 64), (1, 1, 2, 64), None, 0.0, False),
 ]
 
 _TOLERANCE_MAP = {
     infinicore.float16: {"atol": 1e-2, "rtol": 1e-2},
-    infinicore.float32: {"atol": 1e-4, "rtol": 1e-4},
+    infinicore.float32: {"atol": 1e-3, "rtol": 1e-3},
 }
 
 _TENSOR_DTYPES = [infinicore.float16, infinicore.float32]
@@ -68,9 +67,8 @@ def get_test_cases(self):
     def torch_operator(self, *args, **kwargs):
         return torch.nn.functional.scaled_dot_product_attention(*args, **kwargs)
 
-    # def infinicore_operator(self, *args, **kwargs):
-    #     """InfiniCore implementation (operator not yet available)."""
-    #     return infinicore.nn.functional.scaled_dot_product_attention(*args, **kwargs)
+    def infinicore_operator(self, *args, **kwargs):
+        return infinicore.nn.functional.scaled_dot_product_attention(*args, **kwargs)
 
 
 def main():