
Commit 01a4a0c

Merge pull request #882 from InfiniTensor/issue/810
issue/810 static compute graph infra
2 parents: 3883f32 + 39f9c34

29 files changed: +566 additions, -79 deletions


include/infinicore/context/context.hpp

Lines changed: 8 additions & 0 deletions
@@ -3,6 +3,8 @@
 #include "../device.hpp"
 #include "../memory.hpp"
 
+#include "../graph/graph.hpp"
+
 #include <infiniop.h>
 #include <infinirt.h>
 
@@ -40,6 +42,12 @@ void destroyEvent(infinirtEvent_t event);
 float elapsedTime(infinirtEvent_t start, infinirtEvent_t end);
 void streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event);
 
+// Graph recording APIs
+bool isGraphRecording();
+void startGraphRecording();
+void addGraphOperator(std::shared_ptr<graph::GraphOperator> op);
+std::shared_ptr<graph::Graph> stopGraphRecording();
+
 } // namespace context
 
 } // namespace infinicore
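
For orientation, a minimal usage sketch of the new recording entry points (not part of this diff; include paths and the graph-aware gemm_ call site are assumptions based on the headers in this commit):

    #include "infinicore/context/context.hpp"
    #include "infinicore/ops/gemm.hpp"

    void capture_and_replay(infinicore::Tensor c, infinicore::Tensor a, infinicore::Tensor b) {
        namespace ctx = infinicore::context;
        ctx::startGraphRecording();                  // graph-aware ops are now recorded instead of executed
        infinicore::op::gemm_(c, a, b, 1.0f, 0.0f);  // captured as a GraphOperator while recording is active
        auto graph = ctx::stopGraphRecording();      // std::shared_ptr<infinicore::graph::Graph>
        graph->run();                                // replay the recorded operators
    }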

include/infinicore/graph/graph.hpp

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "../tensor.hpp"
+
+namespace infinicore::graph {
+// Forward declarations
+class GraphManager;
+
+class GraphTensor : public Tensor {
+public:
+    GraphTensor(const Tensor &);
+};
+
+class GraphOperator {
+
+public:
+    void run() const;
+    ~GraphOperator();
+
+protected:
+    using run_schema = void (*)(void *);
+    using cleanup_schema = void (*)(void **);
+    void *planned_meta_;
+    run_schema runner_;
+    cleanup_schema deleter_;
+};
+
+class Graph {
+public:
+    Graph() = default;
+    ~Graph() = default;
+
+    void run() const;
+
+protected:
+    void add_operator(std::shared_ptr<GraphOperator> op);
+
+    std::vector<std::shared_ptr<GraphOperator>> op_list_;
+
+    friend class GraphManager;
+};
+} // namespace infinicore::graph
+
+#define INFINICORE_GRAPH_OP_CLASS(__OP_NAME__, ...) \
+    class __OP_NAME__ : public graph::GraphOperator { \
+    public: \
+        using schema = void (*)(__VA_ARGS__); \
+        using plan_schema = void *(*)(__VA_ARGS__); \
+        static common::OpDispatcher<plan_schema> &plan_dispatcher(); \
+        static common::OpDispatcher<run_schema> &run_dispatcher(); \
+        static common::OpDispatcher<cleanup_schema> &cleanup_dispatcher(); \
+        __OP_NAME__(__VA_ARGS__); \
+        static void execute(__VA_ARGS__); \
+    };
+
+#define INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(__OP_NAME__) \
+    common::OpDispatcher<__OP_NAME__::plan_schema> &__OP_NAME__::plan_dispatcher() { \
+        static common::OpDispatcher<__OP_NAME__::plan_schema> dispatcher_; \
+        return dispatcher_; \
+    } \
+    common::OpDispatcher<__OP_NAME__::run_schema> &__OP_NAME__::run_dispatcher() { \
+        static common::OpDispatcher<__OP_NAME__::run_schema> dispatcher_; \
+        return dispatcher_; \
+    } \
+    common::OpDispatcher<__OP_NAME__::cleanup_schema> &__OP_NAME__::cleanup_dispatcher() { \
+        static common::OpDispatcher<__OP_NAME__::cleanup_schema> dispatcher_; \
+        return dispatcher_; \
+    }
+
+#define INFINICORE_GRAPH_OP_DISPATCH(__DEVICE_TYPE__, ...) \
+    planned_meta_ = plan_dispatcher().lookup(__DEVICE_TYPE__)(__VA_ARGS__); \
+    runner_ = run_dispatcher().lookup(__DEVICE_TYPE__); \
+    deleter_ = cleanup_dispatcher().lookup(__DEVICE_TYPE__);
+
+#define INFINICORE_GRAPH_OP_RECORD_OR_RUN(__OP_NAME__, ...) \
+    auto op = std::make_shared<__OP_NAME__>(__VA_ARGS__); \
+    if (context::isGraphRecording()) { \
+        context::addGraphOperator(op); \
+    } else { \
+        op->run(); \
+    }
+
+#define INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(__OP_NAME__, __PLAN_F__, __RUN_F__, __CLEANUP_F__) \
+    static bool registered = []() { \
+        __OP_NAME__::plan_dispatcher().registerAll(__PLAN_F__, false); \
+        __OP_NAME__::run_dispatcher().registerAll(__RUN_F__, false); \
+        __OP_NAME__::cleanup_dispatcher().registerAll(__CLEANUP_F__, false); \
+        return true; \
+    }();
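
As a reading aid (not part of the diff), expanding INFINICORE_GRAPH_OP_CLASS for the Gemm declaration in include/infinicore/ops/gemm.hpp below yields roughly:

    class Gemm : public graph::GraphOperator {
    public:
        using schema = void (*)(Tensor, Tensor, Tensor, float, float);
        using plan_schema = void *(*)(Tensor, Tensor, Tensor, float, float);
        // run_schema and cleanup_schema are inherited from graph::GraphOperator
        static common::OpDispatcher<plan_schema> &plan_dispatcher();
        static common::OpDispatcher<run_schema> &run_dispatcher();
        static common::OpDispatcher<cleanup_schema> &cleanup_dispatcher();
        Gemm(Tensor, Tensor, Tensor, float, float);
        static void execute(Tensor, Tensor, Tensor, float, float);
    };

Judging from the types, the plan dispatcher appears to produce the opaque planned_meta_ pointer that the run and cleanup callbacks later receive.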

include/infinicore/ops/gemm.hpp

Lines changed: 2 additions & 6 deletions
@@ -1,16 +1,12 @@
 #pragma once
 
 #include "../device.hpp"
+#include "../graph/graph.hpp"
 #include "common/op.hpp"
 
 namespace infinicore::op {
 
-class Gemm {
-public:
-    using schema = void (*)(Tensor, Tensor, Tensor, float, float);
-    static void execute(Tensor c, Tensor a, Tensor b, float alpha, float beta);
-    static common::OpDispatcher<schema> &dispatcher();
-};
+INFINICORE_GRAPH_OP_CLASS(Gemm, Tensor, Tensor, Tensor, float, float);
 
 Tensor gemm(Tensor a, Tensor b, float alpha = 1.0f, float beta = 0.0f);
 void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta);
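
The corresponding source file is among the 29 changed files but is not shown in this excerpt. A plausible sketch of how the macros could be wired up for Gemm on the .cc side (the device-type accessor and registered function names are hypothetical, not taken from this commit):

    // gemm.cc (sketch only)
    namespace infinicore::op {

    INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Gemm)

    Gemm::Gemm(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
        auto device_type = context::getDevice().getType(); // hypothetical accessor for the current device type
        // Plan once, then cache the device-specific runner and deleter on this operator.
        INFINICORE_GRAPH_OP_DISPATCH(device_type, c, a, b, alpha, beta);
    }

    void Gemm::execute(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
        // Record while a graph is being captured, otherwise run eagerly.
        INFINICORE_GRAPH_OP_RECORD_OR_RUN(Gemm, c, a, b, alpha, beta);
    }

    void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
        Gemm::execute(c, a, b, alpha, beta);
    }

    } // namespace infinicore::op

A backend would then register its plan/run/cleanup callbacks, e.g. INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Gemm, plan_gemm, run_gemm, cleanup_gemm), where the three function names are likewise hypothetical.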

include/infinicore/tensor.hpp

Lines changed: 3 additions & 1 deletion
@@ -133,6 +133,8 @@ class TensorImpl : public std::enable_shared_from_this<TensorImpl> {
 
     void debug() const;
 
+    Tensor to_blob() const;
+
     ///
     /// Data Transfer APIs
     ///
@@ -294,7 +296,7 @@ class TensorImpl : public std::enable_shared_from_this<TensorImpl> {
 
     friend class Tensor;
 
-private:
+protected:
     TensorMetaData meta_;
     TensorData data_;
 };

python/infinicore/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -8,7 +8,10 @@
     get_device,
     get_device_count,
     get_stream,
+    is_graph_recording,
     set_device,
+    start_graph_recording,
+    stop_graph_recording,
     sync_device,
     sync_stream,
 )
@@ -81,6 +84,9 @@
     "set_device",
     "sync_device",
     "sync_stream",
+    "is_graph_recording",
+    "start_graph_recording",
+    "stop_graph_recording",
     # Data Types.
     "bfloat16",
     "bool",

python/infinicore/context.py

Lines changed: 22 additions & 0 deletions
@@ -1,4 +1,5 @@
 import infinicore.device
+from infinicore.graph import Graph
 from infinicore.lib import _infinicore
 
 
@@ -49,3 +50,24 @@ def get_stream():
         stream: The current stream object
     """
     return _infinicore.get_stream()
+
+
+def is_graph_recording():
+    """Check if the current graph is recording.
+
+    Returns:
+        bool: True if the current graph is recording, False otherwise
+    """
+    return _infinicore.is_graph_recording()
+
+
+def start_graph_recording(device=None):
+    """Start recording the current graph."""
+    if device is not None:
+        set_device(device)
+    _infinicore.start_graph_recording()
+
+
+def stop_graph_recording():
+    """Stop recording the current graph."""
+    return Graph(_infinicore.stop_graph_recording())

python/infinicore/graph.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+from infinicore.lib import _infinicore
+
+
+class Graph:
+    """
+    Python wrapper around a InfiniCore Graph instance.
+    """
+
+    def __init__(self, graph: _infinicore.Graph):
+        if not isinstance(graph, _infinicore.Graph):
+            raise TypeError("Expected _infinicore.Graph")
+        self._graph = graph
+
+    def run(self):
+        return self._graph.run()
+
+    def __repr__(self):
+        return f"<Graph wrapper of {self._graph!r}>"

src/infinicore/context/allocators/device_pinned_allocator.cc

Lines changed: 6 additions & 0 deletions
@@ -12,12 +12,18 @@ DevicePinnedHostAllocator::~DevicePinnedHostAllocator() {
 }
 
 std::byte *DevicePinnedHostAllocator::allocate(size_t size) {
+    if (size == 0) {
+        return nullptr;
+    }
     void *ptr;
     INFINICORE_CHECK_ERROR(infinirtMallocHost(&ptr, size));
     return (std::byte *)ptr;
 }
 
 void DevicePinnedHostAllocator::deallocate(std::byte *ptr) {
+    if (ptr == nullptr) {
+        return;
+    }
     if (owner_ == context::getDevice()) {
         INFINICORE_CHECK_ERROR(infinirtFreeHost(ptr));
         gc();

src/infinicore/context/allocators/host_allocator.cc

Lines changed: 6 additions & 0 deletions
@@ -4,10 +4,16 @@
 
 namespace infinicore {
 std::byte *HostAllocator::allocate(size_t size) {
+    if (size == 0) {
+        return nullptr;
+    }
     return (std::byte *)std::malloc(size);
 }
 
 void HostAllocator::deallocate(std::byte *ptr) {
+    if (ptr == nullptr) {
+        return;
+    }
     std::free(ptr);
 }

src/infinicore/context/allocators/pinnable_block_allocator.cc

Lines changed: 6 additions & 1 deletion
@@ -1,5 +1,7 @@
 #include "pinnable_block_allocator.hpp"
 
+#include "../context_impl.hpp"
+
 #include "../../utils.hpp"
 
 #include <algorithm>
@@ -35,6 +37,9 @@ PinnableBlockAllocator::PinnableBlockAllocator(Device device)
 
 // ------------------- allocate -------------------
 std::byte *PinnableBlockAllocator::allocate(size_t size) {
+    if (size == 0) {
+        return nullptr;
+    }
     std::lock_guard<std::mutex> lock(mutex_);
 
     // Align size to 256 bytes for GPU
@@ -92,7 +97,7 @@ std::byte *PinnableBlockAllocator::allocate(size_t size) {
 
 // ------------------- deallocate -------------------
 void PinnableBlockAllocator::deallocate(std::byte *ptr) {
-    if (!ptr) {
+    if (ptr == nullptr) {
        return;
    }