Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/infinicore/context/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include "../device.hpp"
#include "../memory.hpp"

#include "../graph/graph.hpp"

#include <infiniop.h>
#include <infinirt.h>

Expand Down Expand Up @@ -40,6 +42,12 @@ void destroyEvent(infinirtEvent_t event);
float elapsedTime(infinirtEvent_t start, infinirtEvent_t end);
void streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event);

// Graph recording APIs
// Returns true while operators are being captured into a graph
// (see INFINICORE_GRAPH_OP_RECORD_OR_RUN, which branches on this).
bool isGraphRecording();
// Begins capture; subsequent graph ops are recorded instead of run.
void startGraphRecording();
// Appends an already-constructed (planned) operator to the graph
// currently being recorded.
void addGraphOperator(std::shared_ptr<graph::GraphOperator> op);
// Ends capture and returns the recorded graph for later replay.
std::shared_ptr<graph::Graph> stopGraphRecording();

} // namespace context

} // namespace infinicore
92 changes: 92 additions & 0 deletions include/infinicore/graph/graph.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#pragma once

#include <memory>
#include <vector>

#include "../tensor.hpp"

namespace infinicore::graph {
// Forward declarations
class GraphManager;

// Tensor flavor used while recording a graph.
// NOTE(review): the single-argument constructor is not `explicit`, so any
// Tensor converts implicitly to GraphTensor — confirm this is intentional.
// Conversion semantics are defined out of line and not visible here.
class GraphTensor : public Tensor {
public:
    GraphTensor(const Tensor &);
};

// Base class for one recorded operation in a graph.
// Concrete ops (declared via INFINICORE_GRAPH_OP_CLASS) populate the
// protected members through device dispatchers at construction time
// (see INFINICORE_GRAPH_OP_DISPATCH).
class GraphOperator {

public:
    // Executes the planned operation; defined out of line — presumably
    // invokes runner_ on planned_meta_ (confirm in the .cc).
    void run() const;
    // NOTE(review): destructor is non-virtual; deleting a derived op
    // through a GraphOperator* would be undefined behavior. shared_ptr
    // created as the derived type destroys correctly regardless, but a
    // virtual destructor would be safer. Presumably releases
    // planned_meta_ via deleter_ — confirm in the .cc.
    ~GraphOperator();

protected:
    // Device-specific entry points resolved from the op's dispatchers.
    using run_schema = void (*)(void *);
    using cleanup_schema = void (*)(void **);
    void *planned_meta_;     // opaque device-specific planned state
    run_schema runner_;      // executes planned_meta_
    cleanup_schema deleter_; // releases planned_meta_
};

// An ordered collection of recorded operators that can be replayed.
class Graph {
public:
    Graph() = default;
    ~Graph() = default;

    // Replays the recorded operators; defined out of line — presumably
    // runs op_list_ in insertion order (confirm in the .cc).
    void run() const;

protected:
    // Appends an operator. Building a graph is restricted: only this
    // class's subclasses and the befriended GraphManager may call it.
    void add_operator(std::shared_ptr<GraphOperator> op);

    std::vector<std::shared_ptr<GraphOperator>> op_list_;

    friend class GraphManager;
};
} // namespace infinicore::graph

// Declares a graph-op class __OP_NAME__ derived from graph::GraphOperator.
// The variadic arguments are the op's full parameter list, shared by:
//   - schema / execute(...): presumably the eager entry point — confirm,
//   - plan_schema / plan_dispatcher(): produces device-specific state,
//   - the constructor, which takes the same parameters.
// run_schema / cleanup_schema come from the GraphOperator base class.
#define INFINICORE_GRAPH_OP_CLASS(__OP_NAME__, ...)                        \
    class __OP_NAME__ : public graph::GraphOperator {                      \
    public:                                                                \
        using schema = void (*)(__VA_ARGS__);                              \
        using plan_schema = void *(*)(__VA_ARGS__);                        \
        static common::OpDispatcher<plan_schema> &plan_dispatcher();       \
        static common::OpDispatcher<run_schema> &run_dispatcher();         \
        static common::OpDispatcher<cleanup_schema> &cleanup_dispatcher(); \
        __OP_NAME__(__VA_ARGS__);                                          \
        static void execute(__VA_ARGS__);                                  \
    };

// Defines the three dispatcher accessors declared by
// INFINICORE_GRAPH_OP_CLASS as function-local statics (thread-safe,
// lazily initialized per-op singletons).
#define INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(__OP_NAME__)                             \
    common::OpDispatcher<__OP_NAME__::plan_schema> &__OP_NAME__::plan_dispatcher() {  \
        static common::OpDispatcher<__OP_NAME__::plan_schema> dispatcher_;            \
        return dispatcher_;                                                           \
    }                                                                                 \
    common::OpDispatcher<__OP_NAME__::run_schema> &__OP_NAME__::run_dispatcher() {    \
        static common::OpDispatcher<__OP_NAME__::run_schema> dispatcher_;             \
        return dispatcher_;                                                           \
    }                                                                                 \
    common::OpDispatcher<__OP_NAME__::cleanup_schema> &__OP_NAME__::cleanup_dispatcher() { \
        static common::OpDispatcher<__OP_NAME__::cleanup_schema> dispatcher_;         \
        return dispatcher_;                                                           \
    }

// Plans the op for __DEVICE_TYPE__ (storing the opaque result in
// planned_meta_) and caches the matching run/cleanup entry points in the
// GraphOperator base members — presumably expanded inside a derived op's
// constructor, where those members are in scope; confirm call sites.
// NOTE(review): expands to three statements and is not do-while(0)
// wrapped — unsafe in an unbraced if/else.
#define INFINICORE_GRAPH_OP_DISPATCH(__DEVICE_TYPE__, ...)                  \
    planned_meta_ = plan_dispatcher().lookup(__DEVICE_TYPE__)(__VA_ARGS__); \
    runner_ = run_dispatcher().lookup(__DEVICE_TYPE__);                     \
    deleter_ = cleanup_dispatcher().lookup(__DEVICE_TYPE__);

// Constructs (and thereby plans) an op instance, then either records it
// into the graph being captured or executes it immediately.
// NOTE(review): declares a local named `op` directly in the caller's
// scope (expanding twice in one scope redefines it) and the expansion is
// not do-while(0) wrapped, so it misbehaves in an unbraced if/else —
// confirm all call sites use it as a braced statement.
#define INFINICORE_GRAPH_OP_RECORD_OR_RUN(__OP_NAME__, ...)    \
    auto op = std::make_shared<__OP_NAME__>(__VA_ARGS__);      \
    if (context::isGraphRecording()) {                         \
        context::addGraphOperator(op);                         \
    } else {                                                   \
        op->run();                                             \
    }

// Registers the plan/run/cleanup implementations of __OP_NAME__ for every
// device type, via a static initializer that runs once at load time.
// The registration flag is token-pasted with the op name so the macro can
// be expanded more than once in the same scope without redefining
// `registered`, and is marked [[maybe_unused]] because nothing reads it.
#define INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(__OP_NAME__, __PLAN_F__, __RUN_F__, __CLEANUP_F__) \
    [[maybe_unused]] static bool registered_##__OP_NAME__ = []() {                                \
        __OP_NAME__::plan_dispatcher().registerAll(__PLAN_F__, false);                            \
        __OP_NAME__::run_dispatcher().registerAll(__RUN_F__, false);                              \
        __OP_NAME__::cleanup_dispatcher().registerAll(__CLEANUP_F__, false);                      \
        return true;                                                                              \
    }();
8 changes: 2 additions & 6 deletions include/infinicore/ops/gemm.hpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"

namespace infinicore::op {

class Gemm {
public:
using schema = void (*)(Tensor, Tensor, Tensor, float, float);
static void execute(Tensor c, Tensor a, Tensor b, float alpha, float beta);
static common::OpDispatcher<schema> &dispatcher();
};
INFINICORE_GRAPH_OP_CLASS(Gemm, Tensor, Tensor, Tensor, float, float);

Tensor gemm(Tensor a, Tensor b, float alpha = 1.0f, float beta = 0.0f);
void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta);
Expand Down
4 changes: 2 additions & 2 deletions include/infinicore/ops/paged_caching.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ namespace infinicore::op {
class PagedCaching {
public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor);
static void execute(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping);
static void execute(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping);
static common::OpDispatcher<schema> &dispatcher();
};

void paged_caching_(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping);
void paged_caching_(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping);

} // namespace infinicore::op
4 changes: 3 additions & 1 deletion include/infinicore/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ class TensorImpl : public std::enable_shared_from_this<TensorImpl> {

void debug() const;

Tensor to_blob() const;

///
/// Data Transfer APIs
///
Expand Down Expand Up @@ -294,7 +296,7 @@ class TensorImpl : public std::enable_shared_from_this<TensorImpl> {

friend class Tensor;

private:
protected:
TensorMetaData meta_;
TensorData data_;
};
Expand Down
16 changes: 8 additions & 8 deletions include/infiniop/ops/paged_caching.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,20 @@ typedef struct InfiniopDescriptor *infiniopPagedCachingDescriptor_t;
*
* @param handle The handle to the InfiniOP library context.
* @param desc_ptr A pointer to store the created descriptor.
* @param k_desc Descriptor for the source key tensor.
* @param v_desc Descriptor for the source value tensor.
* @param k_cache_desc Descriptor for the key cache pool tensor.
* @param v_cache_desc Descriptor for the value cache pool tensor.
* @param k_desc Descriptor for the source key tensor.
* @param v_desc Descriptor for the source value tensor.
* @param slot_mapping_desc Descriptor for the slot mapping tensor.
* @return infiniStatus_t Status code of the operation.
*/
__C __export infiniStatus_t infiniopCreatePagedCachingDescriptor(
infiniopHandle_t handle,
infiniopPagedCachingDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t slot_mapping_desc);

/**
Expand All @@ -46,10 +46,10 @@ __C __export infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
* @param desc The Paged Caching descriptor.
* @param workspace Pointer to the workspace memory.
* @param workspace_size The size of the workspace.
* @param k Pointer to the source key tensor data.
* @param v Pointer to the source value tensor data.
* @param k_cache Pointer to the key cache pool data.
* @param v_cache Pointer to the value cache pool data.
* @param k Pointer to the source key tensor data.
* @param v Pointer to the source value tensor data.
* @param slot_mapping Pointer to the slot mapping data.
* @param stream The CUDA stream for the operation. Can be NULL.
* @return infiniStatus_t Status code of the operation.
Expand All @@ -58,10 +58,10 @@ __C __export infiniStatus_t infiniopPagedCaching(
infiniopPagedCachingDescriptor_t desc,
void *workspace,
size_t workspace_size,
const void *k,
const void *v,
void *k_cache,
void *v_cache,
const void *k,
const void *v,
const void *slot_mapping,
void *stream);

Expand Down
6 changes: 6 additions & 0 deletions python/infinicore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
get_device,
get_device_count,
get_stream,
is_graph_recording,
set_device,
start_graph_recording,
stop_graph_recording,
sync_device,
sync_stream,
)
Expand Down Expand Up @@ -81,6 +84,9 @@
"set_device",
"sync_device",
"sync_stream",
"is_graph_recording",
"start_graph_recording",
"stop_graph_recording",
# Data Types.
"bfloat16",
"bool",
Expand Down
22 changes: 22 additions & 0 deletions python/infinicore/context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import infinicore.device
from infinicore.graph import Graph
from infinicore.lib import _infinicore


Expand Down Expand Up @@ -49,3 +50,24 @@ def get_stream():
stream: The current stream object
"""
return _infinicore.get_stream()


def is_graph_recording():
    """Check whether graph recording is currently active.

    Returns:
        bool: True if operators are currently being recorded into a
            graph, False otherwise.
    """
    return _infinicore.is_graph_recording()


def start_graph_recording(device=None):
    """Start recording subsequently executed operators into a graph.

    Args:
        device: Optional device to make current (via ``set_device``)
            before recording begins. ``None`` leaves the current
            device unchanged.
    """
    if device is not None:
        set_device(device)
    _infinicore.start_graph_recording()


def stop_graph_recording():
    """Stop recording and return the captured graph.

    Returns:
        Graph: Python wrapper around the natively recorded
            ``_infinicore.Graph`` instance.
    """
    return Graph(_infinicore.stop_graph_recording())
18 changes: 18 additions & 0 deletions python/infinicore/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from infinicore.lib import _infinicore


class Graph:
    """Thin Python-facing wrapper around a native ``_infinicore.Graph``."""

    def __init__(self, graph: _infinicore.Graph):
        # Fail fast on anything that is not the native graph type, so
        # misuse surfaces at construction time rather than at run().
        if not isinstance(graph, _infinicore.Graph):
            raise TypeError("Expected _infinicore.Graph")
        self._graph = graph

    def run(self):
        """Execute the underlying native graph and return its result."""
        return self._graph.run()

    def __repr__(self):
        return f"<Graph wrapper of {self._graph!r}>"
8 changes: 4 additions & 4 deletions python/infinicore/ops/paged_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@


def paged_caching(
k: Tensor,
v: Tensor,
k_cache: Tensor,
v_cache: Tensor,
k: Tensor,
v: Tensor,
slot_mapping: Tensor,
):
Tensor(
_infinicore.paged_caching_(
k._underlying,
v._underlying,
k_cache._underlying,
v_cache._underlying,
k._underlying,
v._underlying,
slot_mapping._underlying,
)
)
Expand Down
6 changes: 6 additions & 0 deletions src/infinicore/context/allocators/device_pinned_allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,18 @@ DevicePinnedHostAllocator::~DevicePinnedHostAllocator() {
}

// Allocates `size` bytes of device-pinned host memory via the runtime.
// Zero-byte requests return nullptr without touching the runtime;
// deallocate() ignores null pointers symmetrically.
// Runtime failures are surfaced through INFINICORE_CHECK_ERROR.
std::byte *DevicePinnedHostAllocator::allocate(size_t size) {
    if (size == 0) {
        return nullptr;
    }
    void *ptr;
    INFINICORE_CHECK_ERROR(infinirtMallocHost(&ptr, size));
    // static_cast instead of a C-style cast: void* -> object pointer is a
    // well-defined static conversion.
    return static_cast<std::byte *>(ptr);
}

void DevicePinnedHostAllocator::deallocate(std::byte *ptr) {
if (ptr == nullptr) {
return;
}
if (owner_ == context::getDevice()) {
INFINICORE_CHECK_ERROR(infinirtFreeHost(ptr));
gc();
Expand Down
6 changes: 6 additions & 0 deletions src/infinicore/context/allocators/host_allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,16 @@

namespace infinicore {
// Allocates `size` bytes of plain host memory with std::malloc.
// Zero-byte requests return nullptr directly, since malloc(0) may return
// either nullptr or a unique pointer (implementation-defined); callers
// get a consistent null that deallocate() ignores.
std::byte *HostAllocator::allocate(size_t size) {
    if (size == 0) {
        return nullptr;
    }
    // static_cast instead of a C-style cast: void* -> object pointer is a
    // well-defined static conversion. May return nullptr on exhaustion,
    // matching the original behavior.
    return static_cast<std::byte *>(std::malloc(size));
}

// Releases memory previously returned by allocate(). Null pointers
// (e.g. from zero-size allocations) are ignored.
void HostAllocator::deallocate(std::byte *ptr) {
    if (ptr != nullptr) {
        std::free(ptr);
    }
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "pinnable_block_allocator.hpp"

#include "../context_impl.hpp"

#include "../../utils.hpp"

#include <algorithm>
Expand Down Expand Up @@ -35,6 +37,9 @@ PinnableBlockAllocator::PinnableBlockAllocator(Device device)

// ------------------- allocate -------------------
std::byte *PinnableBlockAllocator::allocate(size_t size) {
if (size == 0) {
return nullptr;
}
std::lock_guard<std::mutex> lock(mutex_);

// Align size to 256 bytes for GPU
Expand Down Expand Up @@ -92,7 +97,7 @@ std::byte *PinnableBlockAllocator::allocate(size_t size) {

// ------------------- deallocate -------------------
void PinnableBlockAllocator::deallocate(std::byte *ptr) {
if (!ptr) {
if (ptr == nullptr) {
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

#include "memory_allocator.hpp"

#include "../context_impl.hpp"

#include <mutex>
#include <unordered_map>
#include <vector>
Expand All @@ -25,7 +23,7 @@ class PinnableBlockAllocator : public MemoryAllocator {
};

public:
explicit PinnableBlockAllocator(Device device);
PinnableBlockAllocator(Device device);
~PinnableBlockAllocator();

std::byte *allocate(size_t size) override;
Expand Down
Loading