Skip to content

Commit 677f888

Browse files
committed
issue/571 - DeviceEvent implementation and interface
1 parent cc8784d commit 677f888

File tree

15 files changed

+1082
-4
lines changed

15 files changed

+1082
-4
lines changed

include/infinicore.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include "infinicore/device_event.hpp"
34
#include "infinicore/nn.hpp"
45
#include "infinicore/ops.hpp"
56
#include "infinicore/tensor.hpp"

include/infinicore/context/context.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ bool queryEvent(infinirtEvent_t event);
3838
void synchronizeEvent(infinirtEvent_t event);
3939
void destroyEvent(infinirtEvent_t event);
4040
float elapsedTime(infinirtEvent_t start, infinirtEvent_t end);
41+
void streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event);
4142

4243
} // namespace context
4344

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
#pragma once

#include "device.hpp"
#include "infinirt.h"
#include <cstdint> // uint32_t is used in the constructor signatures below
#include <memory>
#include <stdexcept>

namespace infinicore {

/**
 * @brief A device event for timing operations and synchronization across devices.
 *
 * Similar to torch.cuda.Event, this class provides functionality to:
 * - Record events on specific device streams
 * - Synchronize with events
 * - Measure elapsed time between events
 * - Query event completion status
 * - Make streams wait for events
 *
 * The class is move-only: copy construction and copy assignment are deleted,
 * while move construction and move assignment are declared noexcept.
 */
class DeviceEvent {
private:
    infinirtEvent_t event_; // Underlying event handle
    Device device_;         // Device where this event was created
    bool is_recorded_;      // Whether the event has been recorded

public:
    /**
     * @brief Construct a new DeviceEvent on the current device.
     */
    DeviceEvent();

    /**
     * @brief Construct a new DeviceEvent on the current device with specific flags.
     * @param flags Event creation flags (e.g., for timing, blocking sync)
     */
    explicit DeviceEvent(uint32_t flags);

    /**
     * @brief Construct a new DeviceEvent on a specific device.
     * @param device Target device for this event
     */
    explicit DeviceEvent(Device device);

    /**
     * @brief Construct a new DeviceEvent on a specific device with flags.
     * @param device Target device for this event
     * @param flags Event creation flags
     */
    DeviceEvent(Device device, uint32_t flags);

    // Disallow copying
    DeviceEvent(const DeviceEvent &) = delete;
    DeviceEvent &operator=(const DeviceEvent &) = delete;

    /**
     * @brief Move constructor.
     * @param other Event to move from
     */
    DeviceEvent(DeviceEvent &&other) noexcept;

    /**
     * @brief Move assignment operator.
     * @param other Event to move from
     * @return Reference to this object
     */
    DeviceEvent &operator=(DeviceEvent &&other) noexcept;

    /**
     * @brief Destroy the DeviceEvent and release underlying resources.
     */
    ~DeviceEvent();

    /**
     * @brief Record the event on the current stream of its device.
     */
    void record();

    /**
     * @brief Record the event on a specific stream.
     * @param stream Stream to record the event on
     */
    void record(infinirtStream_t stream);

    /**
     * @brief Wait for the event to complete (blocking).
     */
    void synchronize();

    /**
     * @brief Check if the event has been completed.
     * @return true if completed, false otherwise
     */
    bool query() const;

    /**
     * @brief Calculate elapsed time between this event and another event (in milliseconds).
     * @param other The other event to compare with
     * @return Elapsed time in milliseconds
     * @throws std::runtime_error if events are on different devices or not recorded
     * @note NOTE(review): which of the two events is treated as the start is
     *       decided by the implementation, which is not visible here - confirm.
     */
    float elapsed_time(const DeviceEvent &other) const;

    /**
     * @brief Make a stream wait for this event to complete.
     * @param stream Stream to make wait for this event (nullptr for current stream)
     */
    void wait(infinirtStream_t stream = nullptr) const;

    /**
     * @brief Get the device where this event was created.
     * @return Device associated with this event (returned by value)
     */
    Device device() const { return device_; }

    /**
     * @brief Get the underlying event handle.
     * @return Raw event handle stored in this object
     */
    infinirtEvent_t get() const { return event_; }

    /**
     * @brief Check if the event has been recorded.
     * @return true if recorded, false otherwise
     */
    bool is_recorded() const { return is_recorded_; }
};

} // namespace infinicore

python/infinicore/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,18 @@
11
import contextlib
22

33
import infinicore.nn as nn
4+
5+
# Import context functions
6+
from infinicore.context import (
7+
get_device,
8+
get_device_count,
9+
get_stream,
10+
set_device,
11+
sync_device,
12+
sync_stream,
13+
)
414
from infinicore.device import device
15+
from infinicore.device_event import DeviceEvent
516
from infinicore.dtype import (
617
bfloat16,
718
bool,
@@ -47,8 +58,16 @@
4758
"nn",
4859
# Classes.
4960
"device",
61+
"DeviceEvent",
5062
"dtype",
5163
"Tensor",
64+
# Context functions.
65+
"get_device",
66+
"get_device_count",
67+
"get_stream",
68+
"set_device",
69+
"sync_device",
70+
"sync_stream",
5271
# Data Types.
5372
"bfloat16",
5473
"bool",

python/infinicore/context.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from infinicore.lib import _infinicore
2+
3+
4+
def get_device():
    """Get the current active device.

    Returns:
        The value produced by ``_infinicore.get_device()``, returned as-is.

    NOTE(review): ``set_device`` in this module reads ``device._underlying``,
    so confirm whether callers expect this result to be wrapped before it is
    handed back to ``set_device``.
    """
    return _infinicore.get_device()
11+
12+
13+
def get_device_count(device_type):
    """Get the number of available devices of a specific type.

    Args:
        device_type: The type of device to count. It is forwarded unchanged to
            ``_infinicore.get_device_count``. Originally documented as a
            string such as "cuda", "cpu" or "npu" - TODO confirm the binding
            accepts strings rather than an enum value.

    Returns:
        int: The number of available devices of the specified type
    """
    return _infinicore.get_device_count(device_type)
23+
24+
25+
def set_device(device):
    """Set the current active device.

    Args:
        device: The device to set as active. It must expose an
            ``_underlying`` attribute, which is what gets passed to
            ``_infinicore.set_device``.
    """
    _infinicore.set_device(device._underlying)
32+
33+
34+
def sync_stream():
    """Synchronize the current stream by calling ``_infinicore.sync_stream()``."""
    _infinicore.sync_stream()
37+
38+
39+
def sync_device():
    """Synchronize the current device by calling ``_infinicore.sync_device()``."""
    _infinicore.sync_device()
42+
43+
44+
def get_stream():
    """Get the current stream.

    Returns:
        The value produced by ``_infinicore.get_stream()``, returned as-is.
    """
    return _infinicore.get_stream()

python/infinicore/device_event.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import infinicore.device
2+
from infinicore.lib import _infinicore
3+
4+
5+
class DeviceEvent:
    """A device event for timing operations and synchronization across devices.

    Similar to torch.cuda.Event, this class provides functionality to:
    - Record events on specific device streams
    - Synchronize with events
    - Measure elapsed time between events
    - Query event completion status
    - Make streams wait for events

    Args:
        device: Target device for this event. If None, uses current device.
            When given, its ``_underlying`` attribute is passed to the binding.
        enable_timing: Accepted for signature familiarity with
            ``torch.cuda.Event``. It is currently NOT translated into any
            creation flag and therefore has no effect; use ``flags`` to
            control event creation.
        flags: Event creation flags (e.g., for timing, blocking sync).
            Default is 0, in which case no flags argument is passed at all.
    """

    def __init__(self, device=None, enable_timing=True, flags=0):
        # `enable_timing` is deliberately not consulted: this module has no
        # mapping from it to a flag value, and the previous self-assignment
        # (`flags = flags`) only suggested otherwise.

        # Build the positional arguments for the binding constructor:
        #   ()                -> current device, default flags
        #   (flags,)          -> current device, explicit flags
        #   (device,)         -> specific device, default flags
        #   (device, flags)   -> specific device, explicit flags
        ctor_args = []
        if device is not None:
            ctor_args.append(device._underlying)
        if flags != 0:
            ctor_args.append(flags)
        self._underlying = _infinicore.DeviceEvent(*ctor_args)

    def record(self, stream=None):
        """Record the event.

        Args:
            stream: Stream to record the event on. If None, uses current stream.
        """
        if stream is None:
            self._underlying.record()
        else:
            self._underlying.record(stream)

    def synchronize(self):
        """Wait for the event to complete (blocking)."""
        self._underlying.synchronize()

    def query(self):
        """Check if the event has been completed.

        Returns:
            bool: True if completed, False otherwise.
        """
        return self._underlying.query()

    def elapsed_time(self, other):
        """Calculate elapsed time between this event and another event.

        Args:
            other: The other DeviceEvent to compare with

        Returns:
            float: Elapsed time in milliseconds between this event and the other event

        Raises:
            RuntimeError: If events are on different devices or not recorded
        """
        return self._underlying.elapsed_time(other._underlying)

    def wait(self, stream=None):
        """Make a stream wait for this event to complete.

        Args:
            stream: Stream to make wait for this event. The value, including
                None, is passed straight to the binding; None is intended to
                mean the current stream.
        """
        self._underlying.wait(stream)

    @property
    def device(self):
        """Get the device where this event was created."""
        return infinicore.device._from_infinicore_device(self._underlying.device)

    @property
    def is_recorded(self):
        """Check if the event has been recorded."""
        return self._underlying.is_recorded

    def __repr__(self):
        return f"DeviceEvent(device={self.device}, recorded={self.is_recorded})"

src/infinicore/context/context_impl.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ ContextImpl &ContextImpl::singleton() {
5858
}
5959

6060
ContextImpl::ContextImpl() {
61-
std::vector<int> device_counter(size_t(Device::Type::COUNT));
61+
std::vector<int> device_counter(static_cast<size_t>(Device::Type::COUNT));
6262
INFINICORE_CHECK_ERROR(infinirtGetAllDeviceCount(device_counter.data()));
6363

6464
// Reserve runtime slot for all devices.
@@ -168,6 +168,10 @@ float elapsedTime(infinirtEvent_t start, infinirtEvent_t end) {
168168
return ContextImpl::singleton().getCurrentRuntime()->elapsedTime(start, end);
169169
}
170170

171+
// Forwards the request to the current runtime obtained from the ContextImpl
// singleton, passing `stream` and `event` through unchanged.
void streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) {
    ContextImpl::singleton().getCurrentRuntime()->streamWaitEvent(stream, event);
}
174+
171175
} // namespace context
172176

173177
} // namespace infinicore

src/infinicore/context/runtime/runtime.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,14 @@ float Runtime::elapsedTime(infinirtEvent_t start, infinirtEvent_t end) {
128128
return ms;
129129
}
130130

131+
// Makes a stream wait on `event` via infinirtStreamWaitEvent. A null `stream`
// selects this runtime's own stream_ instead.
void Runtime::streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) {
    infinirtStream_t target = (stream != nullptr) ? stream : stream_;
    INFINICORE_CHECK_ERROR(infinirtStreamWaitEvent(target, event));
}
138+
131139
// Returns a description of the form "Runtime(<device>)", where <device> is
// the result of device_.toString().
std::string Runtime::toString() const {
    return fmt::format("Runtime({})", device_.toString());
}

src/infinicore/context/runtime/runtime.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class Runtime {
4646
void synchronizeEvent(infinirtEvent_t event);
4747
void destroyEvent(infinirtEvent_t event);
4848
float elapsedTime(infinirtEvent_t start, infinirtEvent_t end);
49+
void streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event);
4950

5051
std::string toString() const;
5152

0 commit comments

Comments
 (0)