diff --git a/include/infinicore.hpp b/include/infinicore.hpp index 480ab6bf8..95e4243d9 100644 --- a/include/infinicore.hpp +++ b/include/infinicore.hpp @@ -1,5 +1,6 @@ #pragma once +#include "infinicore/device_event.hpp" #include "infinicore/nn.hpp" #include "infinicore/ops.hpp" #include "infinicore/tensor.hpp" diff --git a/include/infinicore/context/context.hpp b/include/infinicore/context/context.hpp index 093004565..a7fd4e378 100644 --- a/include/infinicore/context/context.hpp +++ b/include/infinicore/context/context.hpp @@ -30,6 +30,16 @@ void memcpyD2H(void *dst, const void *src, size_t size); void memcpyD2D(void *dst, const void *src, size_t size); void memcpyH2H(void *dst, const void *src, size_t size); +// Timing APIs for performance measurement +infinirtEvent_t createEvent(); +infinirtEvent_t createEventWithFlags(uint32_t flags); +void recordEvent(infinirtEvent_t event, infinirtStream_t stream = nullptr); +bool queryEvent(infinirtEvent_t event); +void synchronizeEvent(infinirtEvent_t event); +void destroyEvent(infinirtEvent_t event); +float elapsedTime(infinirtEvent_t start, infinirtEvent_t end); +void streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event); + } // namespace context } // namespace infinicore diff --git a/include/infinicore/device_event.hpp b/include/infinicore/device_event.hpp new file mode 100644 index 000000000..8ce38d645 --- /dev/null +++ b/include/infinicore/device_event.hpp @@ -0,0 +1,125 @@ +#pragma once + +#include "device.hpp" +#include "infinirt.h" +#include +#include + +namespace infinicore { + +/** + * @brief A device event for timing operations and synchronization across devices. + * + * Similar to torch.cuda.Event, this class provides functionality to: + * - Record events on specific device streams + * - Synchronize with events + * - Measure elapsed time between events + * - Query event completion status + * - Make streams wait for events + */ +class DeviceEvent { +private: + infinirtEvent_t event_; // Underlying event handle + Device device_; // Device where this event was created + bool is_recorded_; // Whether the event has been recorded + +public: + /** + * @brief Construct a new DeviceEvent on the current device. + */ + DeviceEvent(); + + /** + * @brief Construct a new DeviceEvent on the current device with specific flags. + * @param flags Event creation flags (e.g., for timing, blocking sync) + */ + explicit DeviceEvent(uint32_t flags); + + /** + * @brief Construct a new DeviceEvent on a specific device. + * @param device Target device for this event + */ + explicit DeviceEvent(Device device); + + /** + * @brief Construct a new DeviceEvent on a specific device with flags. + * @param device Target device for this event + * @param flags Event creation flags + */ + DeviceEvent(Device device, uint32_t flags); + + // Disallow copying + DeviceEvent(const DeviceEvent &) = delete; + DeviceEvent &operator=(const DeviceEvent &) = delete; + + /** + * @brief Move constructor. + */ + DeviceEvent(DeviceEvent &&other) noexcept; + + /** + * @brief Move assignment operator. + */ + DeviceEvent &operator=(DeviceEvent &&other) noexcept; + + /** + * @brief Destroy the DeviceEvent and release underlying resources. + */ + ~DeviceEvent(); + + /** + * @brief Record the event on the current stream of its device. + */ + void record(); + + /** + * @brief Record the event on a specific stream. + * @param stream Stream to record the event on + */ + void record(infinirtStream_t stream); + + /** + * @brief Wait for the event to complete (blocking). 
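+     *
+     * Typical timing use (an illustrative sketch; assumes both events were
+     * created with timing enabled, and `launch_work()` stands in for any
+     * queued device work):
+     * @code
+     *   DeviceEvent start, end;
+     *   start.record();          // mark the beginning on the current stream
+     *   launch_work();           // enqueue device work
+     *   end.record();            // mark the end
+     *   end.synchronize();       // block the host until `end` completes
+     *   float ms = start.elapsed_time(end);
+     * @endcode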
+ */ + void synchronize(); + + /** + * @brief Check if the event has been completed. + * @return true if completed, false otherwise + */ + bool query() const; + + /** + * @brief Calculate elapsed time between this event and another event (in milliseconds). + * @param other The other event to compare with + * @return Elapsed time in milliseconds + * @throws std::runtime_error if events are on different devices or not recorded + */ + float elapsed_time(const DeviceEvent &other) const; + + /** + * @brief Make a stream wait for this event to complete. + * @param stream Stream to make wait for this event (nullptr for current stream) + */ + void wait(infinirtStream_t stream = nullptr) const; + + /** + * @brief Get the device where this event was created. + * @return Device associated with this event + */ + Device device() const { return device_; } + + /** + * @brief Get the underlying event handle. + * @return Raw event handle + */ + infinirtEvent_t get() const { return event_; } + + /** + * @brief Check if the event has been recorded. + * @return true if recorded, false otherwise + */ + bool is_recorded() const { return is_recorded_; } +}; + +} // namespace infinicore diff --git a/include/infinirt.h b/include/infinirt.h index ffecfef80..ba16c19b2 100644 --- a/include/infinirt.h +++ b/include/infinirt.h @@ -2,6 +2,7 @@ #define __INFINIRT_API_H__ #include "infinicore.h" +#include typedef void *infinirtStream_t; typedef void *infinirtEvent_t; @@ -27,11 +28,20 @@ typedef enum { INFINIRT_EVENT_NOT_READY = 1, } infinirtEventStatus_t; +// Event flags for precise timing +typedef enum { + INFINIRT_EVENT_DEFAULT = 0x0, // Default event creation flags + INFINIRT_EVENT_DISABLE_TIMING = 0x1, // Event will not record timing data + INFINIRT_EVENT_BLOCKING_SYNC = 0x2, // Event uses blocking synchronization +} infinirtEventFlags_t; + __C __export infiniStatus_t infinirtEventCreate(infinirtEvent_t *event_ptr); +__C __export infiniStatus_t infinirtEventCreateWithFlags(infinirtEvent_t *event_ptr, uint32_t flags); __C __export infiniStatus_t infinirtEventRecord(infinirtEvent_t event, infinirtStream_t stream); __C __export infiniStatus_t infinirtEventQuery(infinirtEvent_t event, infinirtEventStatus_t *status_ptr); __C __export infiniStatus_t infinirtEventSynchronize(infinirtEvent_t event); __C __export infiniStatus_t infinirtEventDestroy(infinirtEvent_t event); +__C __export infiniStatus_t infinirtEventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end); // Memory typedef enum { diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index 014b58e97..0af940275 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -1,7 +1,18 @@ import contextlib import infinicore.nn as nn + +# Import context functions +from infinicore.context import ( + get_device, + get_device_count, + get_stream, + set_device, + sync_device, + sync_stream, +) from infinicore.device import device +from infinicore.device_event import DeviceEvent from infinicore.dtype import ( bfloat16, bool, @@ -50,8 +61,16 @@ "nn", # Classes. "device", + "DeviceEvent", "dtype", "Tensor", + # Context functions. + "get_device", + "get_device_count", + "get_stream", + "set_device", + "sync_device", + "sync_stream", # Data Types. 
"bfloat16", "bool", diff --git a/python/infinicore/context.py b/python/infinicore/context.py new file mode 100644 index 000000000..5e8f71028 --- /dev/null +++ b/python/infinicore/context.py @@ -0,0 +1,50 @@ +from infinicore.lib import _infinicore + + +def get_device(): + """Get the current active device. + + Returns: + device: The current active device object + """ + return _infinicore.get_device() + + +def get_device_count(device_type): + """Get the number of available devices of a specific type. + + Args: + device_type (str): The type of device to count (e.g., "cuda", "cpu", "npu") + + Returns: + int: The number of available devices of the specified type + """ + return _infinicore.get_device_count(device_type) + + +def set_device(device): + """Set the current active device. + + Args: + device: The device to set as active + """ + _infinicore.set_device(device._underlying) + + +def sync_stream(): + """Synchronize the current stream.""" + _infinicore.sync_stream() + + +def sync_device(): + """Synchronize the current device.""" + _infinicore.sync_device() + + +def get_stream(): + """Get the current stream. + + Returns: + stream: The current stream object + """ + return _infinicore.get_stream() diff --git a/python/infinicore/device_event.py b/python/infinicore/device_event.py new file mode 100644 index 000000000..5d1ef46eb --- /dev/null +++ b/python/infinicore/device_event.py @@ -0,0 +1,95 @@ +import infinicore.device +from infinicore.lib import _infinicore + + +class DeviceEvent: + """A device event for timing operations and synchronization across devices. + + Similar to torch.cuda.Event, this class provides functionality to: + - Record events on specific device streams + - Synchronize with events + - Measure elapsed time between events + - Query event completion status + - Make streams wait for events + + Args: + device: Target device for this event. If None, uses current device. + flags: Event creation flags (e.g., for timing, blocking sync). Default is 0. + enable_timing: Whether the event should be created with timing enabled. + """ + + def __init__(self, device=None, enable_timing=True, flags=0): + if not enable_timing: + # You might want to handle this differently based on your flag system + flags = flags # Adjust flags if timing is disabled + + if device is None: + # Use current device + if flags == 0: + self._underlying = _infinicore.DeviceEvent() + else: + self._underlying = _infinicore.DeviceEvent(flags) + elif flags == 0: + # Construct with device only + self._underlying = _infinicore.DeviceEvent(device._underlying) + else: + # Construct with both device and flags + self._underlying = _infinicore.DeviceEvent(device._underlying, flags) + + def record(self, stream=None): + """Record the event. + + Args: + stream: Stream to record the event on. If None, uses current stream. + """ + if stream is None: + self._underlying.record() + else: + self._underlying.record(stream) + + def synchronize(self): + """Wait for the event to complete (blocking).""" + self._underlying.synchronize() + + def query(self): + """Check if the event has been completed. + + Returns: + bool: True if completed, False otherwise. + """ + return self._underlying.query() + + def elapsed_time(self, other): + """Calculate elapsed time between this event and another event. 
+
+        Args:
+            other: The other DeviceEvent to compare with
+
+        Returns:
+            float: Elapsed time in milliseconds between this event and the other event
+
+        Raises:
+            RuntimeError: If events are on different devices or not recorded
+        """
+        return self._underlying.elapsed_time(other._underlying)
+
+    def wait(self, stream=None):
+        """Make a stream wait for this event to complete.
+
+        Args:
+            stream: Stream to make wait for this event. If None, uses current stream.
+        """
+        self._underlying.wait(stream)
+
+    @property
+    def device(self):
+        """Get the device where this event was created."""
+        return infinicore.device._from_infinicore_device(self._underlying.device)
+
+    @property
+    def is_recorded(self):
+        """Check if the event has been recorded."""
+        return self._underlying.is_recorded
+
+    def __repr__(self):
+        return f"DeviceEvent(device={self.device}, recorded={self.is_recorded})"
diff --git a/src/infinicore/context/context_impl.cc b/src/infinicore/context/context_impl.cc
index c7a96d163..9c567c53e 100644
--- a/src/infinicore/context/context_impl.cc
+++ b/src/infinicore/context/context_impl.cc
@@ -58,7 +58,7 @@ ContextImpl &ContextImpl::singleton() {
}

ContextImpl::ContextImpl() {
-    std::vector<int> device_counter(size_t(Device::Type::COUNT));
+    std::vector<int> device_counter(static_cast<size_t>(Device::Type::COUNT));
    INFINICORE_CHECK_ERROR(infinirtGetAllDeviceCount(device_counter.data()));

    // Reserve runtime slot for all devices.
@@ -145,6 +145,39 @@ void memcpyH2H(void *dst, const void *src, size_t size) {
    return ContextImpl::singleton().getCpuRuntime()->memcpyD2D(dst, src, size);
}

+// Timing API implementations
+infinirtEvent_t createEvent() {
+    return ContextImpl::singleton().getCurrentRuntime()->createEvent();
+}
+
+infinirtEvent_t createEventWithFlags(uint32_t flags) {
+    return ContextImpl::singleton().getCurrentRuntime()->createEventWithFlags(flags);
+}
+
+void recordEvent(infinirtEvent_t event, infinirtStream_t stream) {
+    ContextImpl::singleton().getCurrentRuntime()->recordEvent(event, stream);
+}
+
+bool queryEvent(infinirtEvent_t event) {
+    return ContextImpl::singleton().getCurrentRuntime()->queryEvent(event);
+}
+
+void synchronizeEvent(infinirtEvent_t event) {
+    ContextImpl::singleton().getCurrentRuntime()->synchronizeEvent(event);
+}
+
+void destroyEvent(infinirtEvent_t event) {
+    ContextImpl::singleton().getCurrentRuntime()->destroyEvent(event);
+}
+
+float elapsedTime(infinirtEvent_t start, infinirtEvent_t end) {
+    return ContextImpl::singleton().getCurrentRuntime()->elapsedTime(start, end);
+}
+
+void streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) {
+    ContextImpl::singleton().getCurrentRuntime()->streamWaitEvent(stream, event);
+}
+
} // namespace context
} // namespace infinicore
diff --git a/src/infinicore/context/runtime/runtime.cc b/src/infinicore/context/runtime/runtime.cc
index 1f192011d..5e87eabf2 100644
--- a/src/infinicore/context/runtime/runtime.cc
+++ b/src/infinicore/context/runtime/runtime.cc
@@ -88,6 +88,54 @@ void Runtime::memcpyD2D(void *dst, const void *src, size_t size) {
    INFINICORE_CHECK_ERROR(infinirtMemcpyAsync(dst, src, size, INFINIRT_MEMCPY_D2D, stream_));
}

+// Timing method implementations
+infinirtEvent_t Runtime::createEvent() {
+    infinirtEvent_t event;
+    INFINICORE_CHECK_ERROR(infinirtEventCreate(&event));
+    return event;
+}
+
+infinirtEvent_t Runtime::createEventWithFlags(uint32_t flags) {
+    infinirtEvent_t event;
+    INFINICORE_CHECK_ERROR(infinirtEventCreateWithFlags(&event, flags));
+    return event;
+}
+
+void Runtime::recordEvent(infinirtEvent_t
event, infinirtStream_t stream) { + if (stream == nullptr) { + stream = stream_; + } + INFINICORE_CHECK_ERROR(infinirtEventRecord(event, stream)); +} + +bool Runtime::queryEvent(infinirtEvent_t event) { + infinirtEventStatus_t status; + INFINICORE_CHECK_ERROR(infinirtEventQuery(event, &status)); + return status == INFINIRT_EVENT_COMPLETE; +} + +void Runtime::synchronizeEvent(infinirtEvent_t event) { + INFINICORE_CHECK_ERROR(infinirtEventSynchronize(event)); +} + +void Runtime::destroyEvent(infinirtEvent_t event) { + INFINICORE_CHECK_ERROR(infinirtEventDestroy(event)); +} + +float Runtime::elapsedTime(infinirtEvent_t start, infinirtEvent_t end) { + float ms; + INFINICORE_CHECK_ERROR(infinirtEventElapsedTime(&ms, start, end)); + return ms; +} + +void Runtime::streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) { + // Use current stream if no specific stream is provided + if (stream == nullptr) { + stream = stream_; + } + INFINICORE_CHECK_ERROR(infinirtStreamWaitEvent(stream, event)); +} + std::string Runtime::toString() const { return fmt::format("Runtime({})", device_.toString()); } diff --git a/src/infinicore/context/runtime/runtime.hpp b/src/infinicore/context/runtime/runtime.hpp index 4e0ba7abc..a2d8c8115 100644 --- a/src/infinicore/context/runtime/runtime.hpp +++ b/src/infinicore/context/runtime/runtime.hpp @@ -38,6 +38,16 @@ class Runtime { void memcpyD2H(void *dst, const void *src, size_t size); void memcpyD2D(void *dst, const void *src, size_t size); + // Timing methods + infinirtEvent_t createEvent(); + infinirtEvent_t createEventWithFlags(uint32_t flags); + void recordEvent(infinirtEvent_t event, infinirtStream_t stream = nullptr); + bool queryEvent(infinirtEvent_t event); + void synchronizeEvent(infinirtEvent_t event); + void destroyEvent(infinirtEvent_t event); + float elapsedTime(infinirtEvent_t start, infinirtEvent_t end); + void streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event); + std::string toString() const; friend class ContextImpl; diff --git a/src/infinicore/device_event.cc b/src/infinicore/device_event.cc new file mode 100644 index 000000000..347e7effd --- /dev/null +++ b/src/infinicore/device_event.cc @@ -0,0 +1,180 @@ +#include "infinicore.hpp" + +namespace infinicore { + +DeviceEvent::DeviceEvent() + : device_(context::getDevice()), is_recorded_(false) { + event_ = context::createEvent(); +} + +DeviceEvent::DeviceEvent(uint32_t flags) + : device_(context::getDevice()), is_recorded_(false) { + event_ = context::createEventWithFlags(flags); +} + +DeviceEvent::DeviceEvent(Device device) + : device_(device), is_recorded_(false) { + // Switch to target device for event creation + Device current_device = context::getDevice(); + context::setDevice(device_); + event_ = context::createEvent(); + // Restore original device + context::setDevice(current_device); +} + +DeviceEvent::DeviceEvent(Device device, uint32_t flags) + : device_(device), is_recorded_(false) { + // Switch to target device for event creation + Device current_device = context::getDevice(); + context::setDevice(device_); + event_ = context::createEventWithFlags(flags); + // Restore original device + context::setDevice(current_device); +} + +DeviceEvent::DeviceEvent(DeviceEvent &&other) noexcept + : event_(other.event_), device_(other.device_), is_recorded_(other.is_recorded_) { + other.event_ = nullptr; + other.is_recorded_ = false; +} + +DeviceEvent &DeviceEvent::operator=(DeviceEvent &&other) noexcept { + if (this != &other) { + // Clean up current resources + if (event_ != nullptr) 
{ + context::destroyEvent(event_); + } + + // Transfer ownership + event_ = other.event_; + device_ = other.device_; + is_recorded_ = other.is_recorded_; + + // Reset source + other.event_ = nullptr; + other.is_recorded_ = false; + } + return *this; +} + +DeviceEvent::~DeviceEvent() { + if (event_ != nullptr) { + context::destroyEvent(event_); + } +} + +void DeviceEvent::record() { + Device current_device = context::getDevice(); + + // Ensure we're on the correct device + if (current_device != device_) { + context::setDevice(device_); + } + + context::recordEvent(event_); + is_recorded_ = true; + + // Restore original device if we changed it + if (current_device != device_) { + context::setDevice(current_device); + } +} + +void DeviceEvent::record(infinirtStream_t stream) { + Device current_device = context::getDevice(); + + // Ensure we're on the correct device + if (current_device != device_) { + context::setDevice(device_); + } + + context::recordEvent(event_, stream); + is_recorded_ = true; + + // Restore original device if we changed it + if (current_device != device_) { + context::setDevice(current_device); + } +} + +void DeviceEvent::synchronize() { + Device current_device = context::getDevice(); + + // Ensure we're on the correct device + if (current_device != device_) { + context::setDevice(device_); + } + + context::synchronizeEvent(event_); + + // Restore original device if we changed it + if (current_device != device_) { + context::setDevice(current_device); + } +} + +bool DeviceEvent::query() const { + Device current_device = context::getDevice(); + bool result = false; + + // Ensure we're on the correct device + if (current_device != device_) { + context::setDevice(device_); + } + + result = context::queryEvent(event_); + + // Restore original device if we changed it + if (current_device != device_) { + context::setDevice(current_device); + } + + return result; +} + +float DeviceEvent::elapsed_time(const DeviceEvent &other) const { + // Both events must be on the same device + if (device_ != other.device_) { + throw std::runtime_error("Cannot measure elapsed time between events on different devices"); + } + + // Both events must be recorded + if (!is_recorded_ || !other.is_recorded_) { + throw std::runtime_error("Both events must be recorded before measuring elapsed time"); + } + + Device current_device = context::getDevice(); + + // Switch to the device where events reside + if (current_device != device_) { + context::setDevice(device_); + } + + float elapsed_ms = context::elapsedTime(event_, other.event_); + + // Restore original device if we changed it + if (current_device != device_) { + context::setDevice(current_device); + } + + return elapsed_ms; +} + +void DeviceEvent::wait(infinirtStream_t stream) const { + Device current_device = context::getDevice(); + + // Ensure we're on the correct device + if (current_device != device_) { + context::setDevice(device_); + } + + // Make the stream wait for this event + context::streamWaitEvent(stream, event_); + + // Restore original device if we changed it + if (current_device != device_) { + context::setDevice(current_device); + } +} + +} // namespace infinicore diff --git a/src/infinicore/pybind11/context.hpp b/src/infinicore/pybind11/context.hpp index 2f215f9a1..774074e4f 100644 --- a/src/infinicore/pybind11/context.hpp +++ b/src/infinicore/pybind11/context.hpp @@ -9,8 +9,21 @@ namespace py = pybind11; namespace infinicore::context { inline void bind(py::module &m) { - m.def("get_device", &getDevice); - 
m.def("get_device_count", &getDeviceCount); + // Device management + m.def("get_device", &getDevice, "Get the current active device"); + m.def("get_device_count", &getDeviceCount, + "Get the number of available devices of a specific type", + py::arg("device_type")); + m.def("set_device", &setDevice, + "Set the current active device", + py::arg("device")); + + // Stream and handle management + m.def("get_stream", &getStream, "Get the current stream"); + + // Synchronization + m.def("sync_stream", &syncStream, "Synchronize the current stream"); + m.def("sync_device", &syncDevice, "Synchronize the current device"); } -} // namespace infinicore::context +} // namespace infinicore::context \ No newline at end of file diff --git a/src/infinicore/pybind11/device_event.hpp b/src/infinicore/pybind11/device_event.hpp new file mode 100644 index 000000000..f482422bb --- /dev/null +++ b/src/infinicore/pybind11/device_event.hpp @@ -0,0 +1,42 @@ +#pragma once + +#include "infinicore.hpp" +#include + +namespace py = pybind11; + +namespace infinicore::device_event { + +inline void bind(py::module &m) { + py::class_(m, "DeviceEvent") + .def(py::init<>(), "Construct a DeviceEvent on the current device") + .def(py::init(), "Construct a DeviceEvent with specific flags", py::arg("flags")) + .def(py::init(), "Construct a DeviceEvent on a specific device", py::arg("device")) + .def(py::init(), "Construct a DeviceEvent on a specific device with flags", + py::arg("device"), py::arg("flags")) + + .def("record", py::overload_cast<>(&DeviceEvent::record), + "Record the event on the current stream of its device") + .def("record", py::overload_cast(&DeviceEvent::record), + "Record the event on a specific stream", py::arg("stream")) + + .def("synchronize", &DeviceEvent::synchronize, + "Wait for the event to complete (blocking)") + .def("query", &DeviceEvent::query, + "Check if the event has been completed") + + .def("elapsed_time", &DeviceEvent::elapsed_time, + "Calculate elapsed time between this event and another event (in milliseconds)", + py::arg("other")) + + .def("wait", &DeviceEvent::wait, + "Make a stream wait for this event to complete", + py::arg("stream") = nullptr) + + .def_property_readonly("device", &DeviceEvent::device, + "Get the device where this event was created") + .def_property_readonly("is_recorded", &DeviceEvent::is_recorded, + "Check if the event has been recorded"); +} + +} // namespace infinicore::device_event diff --git a/src/infinicore/pybind11/infinicore.cc b/src/infinicore/pybind11/infinicore.cc index 981c727d6..32a6c419e 100644 --- a/src/infinicore/pybind11/infinicore.cc +++ b/src/infinicore/pybind11/infinicore.cc @@ -3,6 +3,7 @@ #include "context.hpp" #include "device.hpp" +#include "device_event.hpp" #include "dtype.hpp" #include "ops.hpp" #include "tensor.hpp" @@ -12,6 +13,7 @@ namespace infinicore { PYBIND11_MODULE(_infinicore, m) { context::bind(m); device::bind(m); + device_event::bind(m); dtype::bind(m); ops::bind(m); tensor::bind(m); diff --git a/src/infinirt/ascend/infinirt_ascend.cc b/src/infinirt/ascend/infinirt_ascend.cc index 4512d7746..4731f086a 100644 --- a/src/infinirt/ascend/infinirt_ascend.cc +++ b/src/infinirt/ascend/infinirt_ascend.cc @@ -64,6 +64,10 @@ infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) { return INFINI_STATUS_SUCCESS; } +infiniStatus_t eventCreateWithFlags(infinirtEvent_t *event_ptr, uint32_t flags) { + return INFINI_STATUS_NOT_IMPLEMENTED; +} + infiniStatus_t eventRecord(infinirtEvent_t event, infinirtStream_t stream) { 
    CHECK_ACLRT(aclrtRecordEvent((aclrtEvent)event, (aclrtStream)stream));
    return INFINI_STATUS_SUCCESS;
@@ -90,6 +94,10 @@ infiniStatus_t eventDestroy(infinirtEvent_t event) {
    return INFINI_STATUS_SUCCESS;
}

+infiniStatus_t eventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end) {
+    return INFINI_STATUS_NOT_IMPLEMENTED;
+}
+
infiniStatus_t mallocDevice(void **p_ptr, size_t size) {
    CHECK_ACLRT(aclrtMallocAlign32(p_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST));
    return INFINI_STATUS_SUCCESS;
diff --git a/src/infinirt/bang/infinirt_bang.cc b/src/infinirt/bang/infinirt_bang.cc
index c04c45584..bccbb7c19 100644
--- a/src/infinirt/bang/infinirt_bang.cc
+++ b/src/infinirt/bang/infinirt_bang.cc
@@ -51,6 +51,10 @@ infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) {
    return INFINI_STATUS_SUCCESS;
}

+infiniStatus_t eventCreateWithFlags(infinirtEvent_t *event_ptr, uint32_t flags) {
+    return INFINI_STATUS_NOT_IMPLEMENTED;
+}
+
infiniStatus_t eventRecord(infinirtEvent_t event, infinirtStream_t stream) {
    CHECK_BANGRT(cnrtPlaceNotifier((cnrtNotifier_t)event, (cnrtQueue_t)stream));
    return INFINI_STATUS_SUCCESS;
@@ -78,6 +82,10 @@ infiniStatus_t eventDestroy(infinirtEvent_t event) {
    return INFINI_STATUS_SUCCESS;
}

+infiniStatus_t eventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end) {
+    return INFINI_STATUS_NOT_IMPLEMENTED;
+}
+
infiniStatus_t mallocDevice(void **p_ptr, size_t size) {
    CHECK_BANGRT(cnrtMalloc(p_ptr, size));
    return INFINI_STATUS_SUCCESS;
diff --git a/src/infinirt/cpu/infinirt_cpu.cc b/src/infinirt/cpu/infinirt_cpu.cc
index ea46deb02..c8709b1d4 100644
--- a/src/infinirt/cpu/infinirt_cpu.cc
+++ b/src/infinirt/cpu/infinirt_cpu.cc
@@ -1,4 +1,5 @@
#include "infinirt_cpu.h"
+#include <chrono>
#include
#include

@@ -34,23 +35,50 @@ infiniStatus_t streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) {
}

infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) {
-    return INFINI_STATUS_NOT_IMPLEMENTED;
+    // For CPU implementation, we use a simple timestamp as event
+    auto now = std::chrono::steady_clock::now();
+    auto *timestamp = new std::chrono::steady_clock::time_point(now);
+    *event_ptr = timestamp;
+    return INFINI_STATUS_SUCCESS;
+}
+
+infiniStatus_t eventCreateWithFlags(infinirtEvent_t *event_ptr, uint32_t flags) {
+    // CPU implementation ignores flags for simplicity
+    return eventCreate(event_ptr);
}

infiniStatus_t eventRecord(infinirtEvent_t event, infinirtStream_t stream) {
-    return INFINI_STATUS_NOT_IMPLEMENTED;
+    // Update the event timestamp
+    auto *timestamp = static_cast<std::chrono::steady_clock::time_point *>(event);
+    *timestamp = std::chrono::steady_clock::now();
+    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t eventQuery(infinirtEvent_t event, infinirtEventStatus_t *status_ptr) {
-    return INFINI_STATUS_NOT_IMPLEMENTED;
+    // CPU events are always complete immediately
+    *status_ptr = INFINIRT_EVENT_COMPLETE;
+    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t eventSynchronize(infinirtEvent_t event) {
-    return INFINI_STATUS_NOT_IMPLEMENTED;
+    // CPU events are synchronized immediately
+    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t eventDestroy(infinirtEvent_t event) {
-    return INFINI_STATUS_NOT_IMPLEMENTED;
+    auto *timestamp = static_cast<std::chrono::steady_clock::time_point *>(event);
+    delete timestamp;
+    return INFINI_STATUS_SUCCESS;
+}
+
+infiniStatus_t eventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end) {
+    auto *start_time = static_cast<std::chrono::steady_clock::time_point *>(start);
+    auto *end_time = static_cast<std::chrono::steady_clock::time_point *>(end);
+
+    auto duration = std::chrono::duration_cast<std::chrono::microseconds>(*end_time - *start_time);
+    *ms_ptr = static_cast<float>(duration.count()) / 1000.0f;
// Convert microseconds to milliseconds
+
+    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t mallocDevice(void **p_ptr, size_t size) {
diff --git a/src/infinirt/cuda/infinirt_cuda.cu b/src/infinirt/cuda/infinirt_cuda.cu
index cc41617ac..d6fd9ccbe 100644
--- a/src/infinirt/cuda/infinirt_cuda.cu
+++ b/src/infinirt/cuda/infinirt_cuda.cu
@@ -53,7 +53,24 @@ infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) {
    return INFINI_STATUS_SUCCESS;
}

+infiniStatus_t eventCreateWithFlags(infinirtEvent_t *event_ptr, uint32_t flags) {
+    cudaEvent_t event;
+    unsigned int cuda_flags = cudaEventDefault;
+
+    // Convert infinirt flags to CUDA flags
+    if (flags & INFINIRT_EVENT_DISABLE_TIMING) {
+        cuda_flags |= cudaEventDisableTiming;
+    }
+    if (flags & INFINIRT_EVENT_BLOCKING_SYNC) {
+        cuda_flags |= cudaEventBlockingSync;
+    }
+
+    CHECK_CUDART(cudaEventCreateWithFlags(&event, cuda_flags));
+    *event_ptr = event;
+    return INFINI_STATUS_SUCCESS;
+}
+
infiniStatus_t eventRecord(infinirtEvent_t event, infinirtStream_t stream) {
    CHECK_CUDART(cudaEventRecord((cudaEvent_t)event, (cudaStream_t)stream));
    return INFINI_STATUS_SUCCESS;
}
@@ -80,6 +100,11 @@ infiniStatus_t eventDestroy(infinirtEvent_t event) {
    return INFINI_STATUS_SUCCESS;
}

+infiniStatus_t eventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end) {
+    CHECK_CUDART(cudaEventElapsedTime(ms_ptr, (cudaEvent_t)start, (cudaEvent_t)end));
+    return INFINI_STATUS_SUCCESS;
+}
+
infiniStatus_t mallocDevice(void **p_ptr, size_t size) {
    CHECK_CUDART(cudaMalloc(p_ptr, size));
    return INFINI_STATUS_SUCCESS;
diff --git a/src/infinirt/infinirt.cc b/src/infinirt/infinirt.cc
index 119771475..c2f50b027 100644
--- a/src/infinirt/infinirt.cc
+++ b/src/infinirt/infinirt.cc
@@ -126,6 +126,10 @@ __C infiniStatus_t infinirtEventCreate(infinirtEvent_t *event_ptr) {
    INFINIRT_CALL_DEVICE_API(eventCreate, (event_ptr));
}

+__C infiniStatus_t infinirtEventCreateWithFlags(infinirtEvent_t *event_ptr, uint32_t flags) {
+    INFINIRT_CALL_DEVICE_API(eventCreateWithFlags, (event_ptr, flags));
+}
+
__C infiniStatus_t infinirtEventRecord(infinirtEvent_t event, infinirtStream_t stream) {
    INFINIRT_CALL_DEVICE_API(eventRecord, (event, stream));
}
@@ -142,6 +146,10 @@ __C infiniStatus_t infinirtEventDestroy(infinirtEvent_t event) {
    INFINIRT_CALL_DEVICE_API(eventDestroy, (event));
}

+__C infiniStatus_t infinirtEventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end) {
+    INFINIRT_CALL_DEVICE_API(eventElapsedTime, (ms_ptr, start, end));
+}
+
__C infiniStatus_t infinirtMalloc(void **p_ptr, size_t size) {
    INFINIRT_CALL_DEVICE_API(mallocDevice, (p_ptr, size));
}
diff --git a/src/infinirt/infinirt_impl.h b/src/infinirt/infinirt_impl.h
index 0d6f8cf05..9a426c040 100644
--- a/src/infinirt/infinirt_impl.h
+++ b/src/infinirt/infinirt_impl.h
@@ -1,6 +1,7 @@
#ifndef __INFINIRT_IMPL_H__
#define __INFINIRT_IMPL_H__
#include "infinirt.h"
+#include <cstdint>

#define INFINIRT_DEVICE_API(IMPL, COUNT) \
    infiniStatus_t getDeviceCount(int *count) COUNT; \
@@ -13,10 +14,12 @@
    infiniStatus_t streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) IMPL; \
    \
    infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) IMPL; \
+   infiniStatus_t eventCreateWithFlags(infinirtEvent_t *event_ptr, uint32_t flags) IMPL; \
    infiniStatus_t eventRecord(infinirtEvent_t event, infinirtStream_t stream) IMPL; \
    infiniStatus_t eventQuery(infinirtEvent_t event, infinirtEventStatus_t *status_ptr) IMPL; \
infiniStatus_t eventSynchronize(infinirtEvent_t event) IMPL; \ infiniStatus_t eventDestroy(infinirtEvent_t event) IMPL; \ + infiniStatus_t eventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end) IMPL; \ \ infiniStatus_t mallocDevice(void **p_ptr, size_t size) IMPL; \ infiniStatus_t mallocHost(void **p_ptr, size_t size) IMPL; \ diff --git a/src/infinirt/kunlun/infinirt_kunlun.cc b/src/infinirt/kunlun/infinirt_kunlun.cc index 726a67f8c..f2fe43680 100644 --- a/src/infinirt/kunlun/infinirt_kunlun.cc +++ b/src/infinirt/kunlun/infinirt_kunlun.cc @@ -55,6 +55,10 @@ infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) { return INFINI_STATUS_SUCCESS; } +infiniStatus_t eventCreateWithFlags(infinirtEvent_t *event_ptr, uint32_t flags) { + return INFINI_STATUS_NOT_IMPLEMENTED; +} + infiniStatus_t eventRecord(infinirtEvent_t event, infinirtStream_t stream) { CHECK_KUNLUNRT(xpu_event_record((kunlunEvent_t)event, (kunlunStream_t)stream)); return INFINI_STATUS_SUCCESS; @@ -75,6 +79,10 @@ infiniStatus_t eventDestroy(infinirtEvent_t event) { return INFINI_STATUS_SUCCESS; } +infiniStatus_t eventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end) { + return INFINI_STATUS_NOT_IMPLEMENTED; +} + infiniStatus_t mallocDevice(void **p_ptr, size_t size) { CHECK_KUNLUNRT(xpu_malloc(p_ptr, static_cast(size))); return INFINI_STATUS_SUCCESS; diff --git a/src/infinirt/metax/infinirt_metax.cc b/src/infinirt/metax/infinirt_metax.cc index 362a7d7ca..9cdabfcf1 100644 --- a/src/infinirt/metax/infinirt_metax.cc +++ b/src/infinirt/metax/infinirt_metax.cc @@ -50,6 +50,10 @@ infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) { return INFINI_STATUS_SUCCESS; } +infiniStatus_t eventCreateWithFlags(infinirtEvent_t *event_ptr, uint32_t flags) { + return INFINI_STATUS_NOT_IMPLEMENTED; +} + infiniStatus_t eventRecord(infinirtEvent_t event, infinirtStream_t stream) { CHECK_MACART(hcEventRecord((hcEvent_t)event, (hcStream_t)stream)); return INFINI_STATUS_SUCCESS; @@ -70,6 +74,10 @@ infiniStatus_t eventDestroy(infinirtEvent_t event) { return INFINI_STATUS_SUCCESS; } +infiniStatus_t eventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end) { + return INFINI_STATUS_NOT_IMPLEMENTED; +} + infiniStatus_t mallocDevice(void **p_ptr, size_t size) { CHECK_MACART(hcMalloc(p_ptr, size)); return INFINI_STATUS_SUCCESS; diff --git a/src/infinirt/moore/infinirt_moore.cc b/src/infinirt/moore/infinirt_moore.cc index e805958d5..18c966801 100644 --- a/src/infinirt/moore/infinirt_moore.cc +++ b/src/infinirt/moore/infinirt_moore.cc @@ -50,6 +50,10 @@ infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) { return INFINI_STATUS_SUCCESS; } +infiniStatus_t eventCreateWithFlags(infinirtEvent_t *event_ptr, uint32_t flags) { + return INFINI_STATUS_NOT_IMPLEMENTED; +} + infiniStatus_t eventRecord(infinirtEvent_t event, infinirtStream_t stream) { CHECK_MUSART(musaEventRecord((musaEvent_t)event, (musaStream_t)stream)); return INFINI_STATUS_SUCCESS; @@ -77,6 +81,10 @@ infiniStatus_t eventDestroy(infinirtEvent_t event) { return INFINI_STATUS_SUCCESS; } +infiniStatus_t eventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end) { + return INFINI_STATUS_NOT_IMPLEMENTED; +} + infiniStatus_t mallocDevice(void **p_ptr, size_t size) { CHECK_MUSART(musaMalloc(p_ptr, size)); return INFINI_STATUS_SUCCESS; diff --git a/test/infinicore/device_event.py b/test/infinicore/device_event.py new file mode 100644 index 000000000..acbc8e058 --- /dev/null +++ b/test/infinicore/device_event.py @@ -0,0 +1,534 @@ 
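+"""Tests for infinicore.DeviceEvent: timing, query, synchronization, and
+stream-wait behavior. The suite assumes at least one CUDA device; the
+multi-device cases are skipped when fewer than two devices are present."""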
+import infinicore
+
+
+def test_device_event_timing():
+    """Test DeviceEvent for timing operations - using instance method API"""
+    print("\nTesting DeviceEvent timing...")
+
+    # Create events
+    start_event = infinicore.DeviceEvent()
+    end_event = infinicore.DeviceEvent()
+
+    # Create test tensors
+    shape = [1000, 1000]
+    device = infinicore.device("cuda", 0)
+
+    # Time tensor creation and operations
+    start_event.record()
+
+    # Perform some operations
+    t1 = infinicore.ones(shape, dtype=infinicore.float32, device=device)
+    t2 = infinicore.zeros(shape, dtype=infinicore.float32, device=device)
+
+    # Simulate some computation by multiple operations
+    for _ in range(10):
+        t1 = t1.permute([1, 0])
+        t2 = t2.permute([1, 0])
+
+    end_event.record()
+
+    # Wait for operations to complete
+    end_event.synchronize()
+
+    # Calculate elapsed time - USING INSTANCE METHOD (torch-compatible)
+    elapsed_time = start_event.elapsed_time(end_event)
+
+    print(f"✓ DeviceEvent timing test passed - Elapsed time: {elapsed_time:.3f} ms")
+    assert elapsed_time >= 0, "Elapsed time should be non-negative"
+
+    return elapsed_time
+
+
+def test_device_event_query():
+    """Test DeviceEvent query functionality"""
+    print("\nTesting DeviceEvent query...")
+
+    event = infinicore.DeviceEvent()
+
+    # Event should not be completed before recording
+    assert not event.is_recorded, "Event should not be recorded initially"
+
+    # Record the event
+    event.record()
+    assert event.is_recorded, "Event should be recorded after record()"
+
+    # Query completion (might be immediate for simple cases)
+    completed = event.query()
+    print(f"✓ DeviceEvent query test passed - Event completed: {completed}")
+
+    # Ensure synchronization works
+    event.synchronize()
+    assert event.query(), "Event should be completed after synchronize()"
+
+
+def test_multiple_devices():
+    """Test operations across multiple devices"""
+    print("\nTesting multiple devices...")
+
+    cuda_count = infinicore.get_device_count("cuda")
+
+    if cuda_count > 1:
+        # Test operations on different devices
+        shape = [100, 100]
+
+        # Create events for timing
+        event0_start = infinicore.DeviceEvent(device=infinicore.device("cuda", 0))
+        event0_end = infinicore.DeviceEvent(device=infinicore.device("cuda", 0))
+        event1_start = infinicore.DeviceEvent(device=infinicore.device("cuda", 1))
+        event1_end = infinicore.DeviceEvent(device=infinicore.device("cuda", 1))
+
+        # Create tensors on different devices
+        event0_start.record()
+        t_device0 = infinicore.ones(
+            shape, dtype=infinicore.float32, device=infinicore.device("cuda", 0)
+        )
+        event0_end.record()
+
+        event1_start.record()
+        t_device1 = infinicore.zeros(
+            shape, dtype=infinicore.float32, device=infinicore.device("cuda", 1)
+        )
+        event1_end.record()
+
+        # Synchronize both devices
+        event0_end.synchronize()
+        event1_end.synchronize()
+
+        # Calculate elapsed times
+        time_device0 = event0_start.elapsed_time(event0_end)
+        time_device1 = event1_start.elapsed_time(event1_end)
+
+        print(f"✓ Multiple devices test passed")
+        print(f"  Device 0 tensor creation time: {time_device0:.3f} ms")
+        print(f"  Device 1 tensor creation time: {time_device1:.3f} ms")
+
+        # Test operations timing
+        event0_start.record()
+        for _ in range(20):
+            t_device0 = t_device0.permute([1, 0])
+        event0_end.record()
+
+        event1_start.record()
+        for _ in range(20):
+            t_device1 = t_device1.permute([1, 0])
+        event1_end.record()
+
+        # Synchronize again
+        event0_end.synchronize()
+        event1_end.synchronize()
+
+        # Calculate operation times
+        op_time_device0 = event0_start.elapsed_time(event0_end)
+        
op_time_device1 = event1_start.elapsed_time(event1_end) + + print(f" Device 0 operations time: {op_time_device0:.3f} ms") + print(f" Device 1 operations time: {op_time_device1:.3f} ms") + + # Test cross-device operations if supported + try: + # Try to create an event that measures cross-device operations + cross_start = infinicore.DeviceEvent(device=infinicore.device("cuda", 0)) + cross_end = infinicore.DeviceEvent(device=infinicore.device("cuda", 0)) + + cross_start.record() + # Perform operations on both devices + for _ in range(10): + t_device0 = t_device0.permute([1, 0]) + # Note: Actual cross-device operations would require explicit synchronization + cross_end.record() + cross_end.synchronize() + + cross_time = cross_start.elapsed_time(cross_end) + print(f" Cross-device operations time: {cross_time:.3f} ms") + + except Exception as e: + print(f" Cross-device timing skipped: {e}") + + else: + print("⚠ Skipping multiple devices test (only 1 CUDA device available)") + + +def test_event_flags(): + """Test DeviceEvent with different flags""" + print("\nTesting DeviceEvent flags...") + + try: + # Test with default flags (0) + event_default = infinicore.DeviceEvent(flags=0) + event_default.record() + event_default.synchronize() + + # Test with different flag values (adjust based on available flags) + event_with_flags = infinicore.DeviceEvent(flags=1) # Example flag + event_with_flags.record() + event_with_flags.synchronize() + + print("✓ DeviceEvent flags test passed") + except Exception as e: + print(f"⚠ DeviceEvent flags test skipped: {e}") + + +def test_event_stream(): + """Test DeviceEvent with different streams""" + print("\nTesting DeviceEvent with streams...") + + try: + # Get default stream + default_stream = None + if hasattr(infinicore, "get_stream"): + default_stream = infinicore.get_stream() + else: + print("⚠ infinicore.get_stream() not available, using default stream") + + # Create event and record + event = infinicore.DeviceEvent() + if default_stream is not None: + event.record(stream=default_stream) + else: + event.record() + + event.synchronize() + + print("✓ DeviceEvent stream test passed") + except Exception as e: + print(f"⚠ DeviceEvent stream test skipped: {e}") + + +def test_concurrent_events(): + """Test multiple concurrent events""" + print("\nTesting concurrent events...") + + # Create multiple events + events = [] + for i in range(5): + events.append(infinicore.DeviceEvent()) + + # Record events with small delays + for i, event in enumerate(events): + event.record() + # Small operation + temp = infinicore.ones( + [10, 10], dtype=infinicore.float32, device=infinicore.device("cuda", 0) + ) + temp = temp.permute([1, 0]) + + # Synchronize all events + for event in events: + event.synchronize() + assert event.query(), "All events should be completed" + + print("✓ Concurrent events test passed") + + +def test_torch_style_usage(): + """Test that our API matches torch.cuda.Event usage pattern""" + print("\nTesting torch.cuda.Event style usage...") + + # This should work exactly like torch.cuda.Event + start = infinicore.DeviceEvent() + end = infinicore.DeviceEvent() + + # Record events + start.record() + + # Some operations + tensor = infinicore.ones( + [100, 100], dtype=infinicore.float32, device=infinicore.device("cuda", 0) + ) + for _ in range(5): + tensor = tensor.permute([1, 0]) + + end.record() + end.synchronize() + + # This is the torch-compatible API + time_taken = start.elapsed_time(end) + + print(f"✓ Torch-style usage test passed - Time: {time_taken:.3f} ms") + + 
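+# The canonical timing pattern exercised above, for reference (a minimal
+# sketch mirroring torch.cuda.Event; `queued_device_work()` is a placeholder
+# for any operation launched on the current stream):
+#
+#     start, end = infinicore.DeviceEvent(), infinicore.DeviceEvent()
+#     start.record()                    # mark the beginning
+#     queued_device_work()              # enqueue device work
+#     end.record()                      # mark the end
+#     end.synchronize()                 # block until `end` completes
+#     elapsed_ms = start.elapsed_time(end)
+
+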
+def test_event_synchronization():
+    """Test event synchronization behavior"""
+    print("\nTesting event synchronization...")
+
+    event1 = infinicore.DeviceEvent()
+    event2 = infinicore.DeviceEvent()
+
+    # Record events in sequence
+    event1.record()
+
+    # Some work
+    temp = infinicore.zeros(
+        [50, 50], dtype=infinicore.float32, device=infinicore.device("cuda", 0)
+    )
+
+    event2.record()
+
+    # event2 should complete after event1
+    event2.synchronize()
+    assert event2.query(), "event2 should be completed"
+    assert event1.query(), "event1 should also be completed after event2 sync"
+
+    print("✓ Event synchronization test passed")
+
+
+def test_event_wait_functionality():
+    """Test the wait functionality of DeviceEvent"""
+    print("\nTesting DeviceEvent wait functionality...")
+
+    # Create events
+    event1 = infinicore.DeviceEvent()
+    event2 = infinicore.DeviceEvent()
+
+    # Record first event
+    event1.record()
+
+    # Perform some work
+    tensor1 = infinicore.ones(
+        [500, 500], dtype=infinicore.float32, device=infinicore.device("cuda", 0)
+    )
+    for _ in range(10):
+        tensor1 = tensor1.permute([1, 0])
+
+    # Record second event
+    event2.record()
+
+    # Make the current stream wait for event2 (wait() enqueues the dependency)
+    event2.wait()
+
+    # Both events should be completed now
+    assert event1.query(), "event1 should be completed"
+    assert event2.query(), "event2 should be completed after waiting"
+
+    print("✓ Event wait functionality test passed")
+
+
+def test_stream_wait_event():
+    """Test stream waiting for events"""
+    print("\nTesting stream wait event functionality...")
+
+    try:
+        # Get the current stream
+        current_stream = infinicore.get_stream()
+
+        # Create events
+        dependency_event = infinicore.DeviceEvent()
+        dependent_event = infinicore.DeviceEvent()
+
+        # Record dependency event
+        dependency_event.record()
+
+        # Perform some work that creates a dependency
+        tensor = infinicore.ones(
+            [300, 300], dtype=infinicore.float32, device=infinicore.device("cuda", 0)
+        )
+        for _ in range(5):
+            tensor = tensor.permute([1, 0])
+
+        # Make the stream wait for the dependency event before recording dependent event
+        dependency_event.wait(current_stream)
+
+        # Record dependent event after the wait
+        dependent_event.record()
+
+        # Synchronize and verify
+        dependent_event.synchronize()
+        assert dependency_event.query(), "Dependency event should be completed"
+        assert dependent_event.query(), "Dependent event should be completed"
+
+        print("✓ Stream wait event test passed")
+
+    except Exception as e:
+        print(f"⚠ Stream wait event test skipped: {e}")
+
+
+def test_multiple_stream_synchronization():
+    """Test event-based synchronization between multiple streams"""
+    print("\nTesting multiple stream synchronization...")
+
+    try:
+        # This test simulates a producer-consumer pattern using events
+        producer_event = infinicore.DeviceEvent()
+        consumer_event = infinicore.DeviceEvent()
+
+        # Producer work
+        producer_event.record()
+
+        # Simulate producer work (data generation)
+        data_tensor = infinicore.ones(
+            [200, 200], dtype=infinicore.float32, device=infinicore.device("cuda", 0)
+        )
+        for _ in range(8):
+            data_tensor = data_tensor.permute([1, 0])
+
+        # Make consumer wait for producer to finish
+        producer_event.wait()  # Wait on current stream
+
+        # Consumer work (depends on producer's output)
+        processed_tensor = data_tensor.permute([1, 0])  # Consumer operation
+        consumer_event.record()
+
+        # Verify the synchronization worked
+        consumer_event.synchronize()
+        assert producer_event.query(), "Producer event should be completed"
+        assert
consumer_event.query(), "Consumer event should be completed" + + print("✓ Multiple stream synchronization test passed") + + except Exception as e: + print(f"⚠ Multiple stream synchronization test skipped: {e}") + + +def test_event_wait_with_specific_stream(): + """Test waiting on specific streams""" + print("\nTesting event wait with specific streams...") + + try: + # Get current stream + main_stream = infinicore.get_stream() + + # Create events + compute_event = infinicore.DeviceEvent() + transfer_event = infinicore.DeviceEvent() + + # Record compute event after some computation + compute_event.record() + + # Simulate computation + compute_tensor = infinicore.ones( + [150, 150], dtype=infinicore.float32, device=infinicore.device("cuda", 0) + ) + for _ in range(6): + compute_tensor = compute_tensor.permute([1, 0]) + + # Make data transfer wait for computation to complete + compute_event.wait(main_stream) + + # Record transfer event + transfer_event.record() + + # Verify synchronization + transfer_event.synchronize() + assert compute_event.query(), "Compute event should be completed" + assert transfer_event.query(), "Transfer event should be completed" + + print("✓ Event wait with specific stream test passed") + + except Exception as e: + print(f"⚠ Event wait with specific stream test skipped: {e}") + + +def test_complex_dependency_chain(): + """Test complex dependency chains using events""" + print("\nTesting complex dependency chains...") + + try: + # Create multiple events for a dependency chain + event_a = infinicore.DeviceEvent() + event_b = infinicore.DeviceEvent() + event_c = infinicore.DeviceEvent() + event_d = infinicore.DeviceEvent() + + # Stage A + event_a.record() + tensor_a = infinicore.ones( + [100, 100], dtype=infinicore.float32, device=infinicore.device("cuda", 0) + ) + for _ in range(3): + tensor_a = tensor_a.permute([1, 0]) + + # Stage B depends on A + event_a.wait() + event_b.record() + tensor_b = tensor_a.permute([1, 0]) # Depends on tensor_a + for _ in range(3): + tensor_b = tensor_b.permute([1, 0]) + + # Stage C depends on B + event_b.wait() + event_c.record() + tensor_c = tensor_b.permute([1, 0]) # Depends on tensor_b + for _ in range(3): + tensor_c = tensor_c.permute([1, 0]) + + # Stage D depends on C + event_c.wait() + event_d.record() + tensor_d = tensor_c.permute([1, 0]) # Depends on tensor_c + + # Final synchronization + event_d.synchronize() + + # Verify all events completed in order + assert event_a.query(), "Event A should be completed" + assert event_b.query(), "Event B should be completed" + assert event_c.query(), "Event C should be completed" + assert event_d.query(), "Event D should be completed" + + print("✓ Complex dependency chain test passed") + + except Exception as e: + print(f"⚠ Complex dependency chain test skipped: {e}") + + +def test_wait_before_record(): + """Test waiting for an event that hasn't been recorded yet""" + print("\nTesting wait before record behavior...") + + try: + event = infinicore.DeviceEvent() + + # This should not crash, but the behavior depends on the underlying implementation + # In most systems, waiting for an unrecorded event is undefined behavior + # We're testing that our API handles this gracefully + event.wait() + + print( + "✓ Wait before record test completed (behavior may vary by implementation)" + ) + + except Exception as e: + print(f"⚠ Wait before record test encountered expected behavior: {e}") + + +def run_all_tests(): + """Run all device-related tests""" + print("Starting DeviceEvent and device tests...") + 
print("=" * 50) + + try: + # Basic functionality tests + test_device_event_timing() + test_device_event_query() + test_torch_style_usage() + test_event_synchronization() + test_concurrent_events() + + # Wait functionality tests (new) + test_event_wait_functionality() + test_stream_wait_event() + test_multiple_stream_synchronization() + test_event_wait_with_specific_stream() + test_complex_dependency_chain() + test_wait_before_record() + + # Optional tests (may depend on system capabilities) + test_multiple_devices() + test_event_flags() + test_event_stream() + + print("\n" + "=" * 50) + print("🎉 All tests passed successfully!") + print("DeviceEvent wait functionality is working correctly!") + + except Exception as e: + print(f"\n❌ Test failed: {e}") + import traceback + + traceback.print_exc() + raise + + +if __name__ == "__main__": + run_all_tests() diff --git a/test/infinicore/framework/__init__.py b/test/infinicore/framework/__init__.py index dcbc57595..b794e7d30 100644 --- a/test/infinicore/framework/__init__.py +++ b/test/infinicore/framework/__init__.py @@ -1,4 +1,13 @@ from .base import TestConfig, TestRunner, TestCase, BaseOperatorTest +from .benchmark import BenchmarkUtils, BenchmarkResult +from .config import ( + get_args, + get_hardware_args_group, + get_test_devices, +) +from .datatypes import to_torch_dtype, to_infinicore_dtype +from .devices import InfiniDeviceEnum, InfiniDeviceNames, torch_device_map +from .runner import GenericTestRunner from .tensor import TensorSpec, TensorInitializer from .utils import ( compare_results, @@ -6,21 +15,12 @@ debug, get_tolerance, infinicore_tensor_from_torch, - profile_operation, rearrange_tensor, convert_infinicore_to_torch, is_integer_dtype, is_complex_dtype, is_floating_dtype, ) -from .config import ( - get_args, - get_hardware_args_group, - get_test_devices, -) -from .devices import InfiniDeviceEnum, InfiniDeviceNames, torch_device_map -from .datatypes import to_torch_dtype, to_infinicore_dtype -from .runner import GenericTestRunner __all__ = [ # Core types and classes @@ -43,7 +43,6 @@ "get_test_devices", "get_tolerance", "infinicore_tensor_from_torch", - "profile_operation", "rearrange_tensor", # Utility functions "to_infinicore_dtype", @@ -53,4 +52,7 @@ "is_integer_dtype", "is_complex_dtype", "is_floating_dtype", + # Benchmarking utilities + "BenchmarkUtils", + "BenchmarkResult", ] diff --git a/test/infinicore/framework/base.py b/test/infinicore/framework/base.py index 87e62088f..0124276fb 100644 --- a/test/infinicore/framework/base.py +++ b/test/infinicore/framework/base.py @@ -11,8 +11,8 @@ from .utils import ( create_test_comparator, infinicore_tensor_from_torch, - profile_operation, ) +from .benchmark import BenchmarkUtils @dataclass @@ -21,8 +21,10 @@ class TestResult: success: bool return_code: int # 0: success, -1: failure, -2: skipped, -3: partial - torch_time: float = 0.0 - infini_time: float = 0.0 + torch_host_time: float = 0.0 + torch_device_time: float = 0.0 + infini_host_time: float = 0.0 + infini_device_time: float = 0.0 error_message: str = "" test_case: Any = None device: Any = None @@ -202,8 +204,10 @@ def __init__(self, test_cases, test_config): ) # Track passed tests (both operators implemented and passed) # Add benchmark timing statistics self.benchmark_times = { - "torch_total": 0.0, - "infinicore_total": 0.0, + "torch_host_total": 0.0, + "torch_device_total": 0.0, + "infinicore_host_total": 0.0, + "infinicore_device_total": 0.0, "per_test_case": {}, # Store timing per test case } # Store test results @@ -329,8 
+333,10 @@ def print_summary(self): # Print benchmark summary if benchmarking was enabled if self.config.bench and ( - self.benchmark_times["torch_total"] > 0 - or self.benchmark_times["infinicore_total"] > 0 + self.benchmark_times["torch_host_total"] > 0 + or self.benchmark_times["torch_device_total"] > 0 + or self.benchmark_times["infinicore_host_total"] > 0 + or self.benchmark_times["infinicore_device_total"] > 0 ): self._print_benchmark_summary() @@ -342,19 +348,30 @@ def _print_benchmark_summary(self): print(f"{'-'*60}") print("BENCHMARK SUMMARY") - torch_total = self.benchmark_times["torch_total"] - infinicore_total = self.benchmark_times["infinicore_total"] + torch_host_total = self.benchmark_times["torch_host_total"] + torch_device_total = self.benchmark_times["torch_device_total"] + infinicore_host_total = self.benchmark_times["infinicore_host_total"] + infinicore_device_total = self.benchmark_times["infinicore_device_total"] + + if torch_host_total > 0: + print(f"PyTorch Host Total Time: {torch_host_total * 1000:.3f} ms") + if torch_device_total > 0: + print(f"PyTorch Device Total Time: {torch_device_total * 1000:.3f} ms") + if infinicore_host_total > 0: + print(f"InfiniCore Host Total Time: {infinicore_host_total * 1000:.3f} ms") + if infinicore_device_total > 0: + print( + f"InfiniCore Device Total Time: {infinicore_device_total * 1000:.3f} ms" + ) - if torch_total > 0: - print(f"PyTorch Total Time: {torch_total * 1000:.3f} ms") - if infinicore_total > 0: - print(f"InfiniCore Total Time: {infinicore_total * 1000:.3f} ms") + # Calculate speedups + if torch_host_total > 0 and infinicore_host_total > 0: + host_speedup = torch_host_total / infinicore_host_total + print(f"Host Speedup (PyTorch/InfiniCore): {host_speedup:.2f}x") - if torch_total > 0 and infinicore_total > 0: - speedup = ( - torch_total / infinicore_total if infinicore_total > 0 else float("inf") - ) - print(f"Speedup (PyTorch/InfiniCore): {speedup:.2f}x") + if torch_device_total > 0 and infinicore_device_total > 0: + device_speedup = torch_device_total / infinicore_device_total + print(f"Device Speedup (PyTorch/InfiniCore): {device_speedup:.2f}x") def get_test_results(self): """Get all test results""" @@ -593,20 +610,27 @@ def run_test(self, device, test_case, config): test_result.return_code = -3 # Partial # Run benchmarking for partial tests if enabled if config.bench: - torch_time, infini_time = self._run_benchmarking( - config, - device_str, - torch_implemented, - infini_implemented, - inputs, - kwargs, - infini_inputs, - infini_kwargs, - test_case.output_count, - comparison_target, + torch_host, torch_device, infini_host, infini_device = ( + BenchmarkUtils.run_benchmarking( + config, + device_str, + torch_implemented, + infini_implemented, + self.torch_operator, + self.infinicore_operator, + inputs, + kwargs, + infini_inputs, + infini_kwargs, + test_case.output_count, + comparison_target, + bench_mode=config.bench, + ) ) - test_result.torch_time = torch_time - test_result.infini_time = infini_time + test_result.torch_host_time = torch_host + test_result.torch_device_time = torch_device + test_result.infini_host_time = infini_host + test_result.infini_device_time = infini_device return test_result # ========================================================================== # MULTIPLE OUTPUTS COMPARISON LOGIC @@ -716,109 +740,43 @@ def run_test(self, device, test_case, config): # UNIFIED BENCHMARKING LOGIC # ========================================================================== if config.bench: - torch_time, 
infini_time = self._run_benchmarking( - config, - device_str, - True, - True, - inputs, - kwargs, - infini_inputs, - infini_kwargs, - test_case.output_count, - comparison_target, + torch_host, torch_device, infini_host, infini_device = ( + BenchmarkUtils.run_benchmarking( + config, + device_str, + True, + True, + self.torch_operator, + self.infinicore_operator, + inputs, + kwargs, + infini_inputs, + infini_kwargs, + test_case.output_count, + comparison_target, + bench_mode=config.bench, + ) ) - test_result.torch_time = torch_time - test_result.infini_time = infini_time + test_result.torch_host_time = torch_host + test_result.torch_device_time = torch_device + test_result.infini_host_time = infini_host + test_result.infini_device_time = infini_device + + # Store timing information in the test runner + if hasattr(config, "_test_runner") and config._test_runner: + # Accumulate total times + config._test_runner.benchmark_times["torch_host_total"] += torch_host + config._test_runner.benchmark_times[ + "torch_device_total" + ] += torch_device + config._test_runner.benchmark_times[ + "infinicore_host_total" + ] += infini_host + config._test_runner.benchmark_times[ + "infinicore_device_total" + ] += infini_device # Test passed successfully test_result.success = True test_result.return_code = 0 return test_result - - def _run_benchmarking( - self, - config, - device_str, - torch_implemented, - infini_implemented, - inputs, - kwargs, - infini_inputs, - infini_kwargs, - output_count, - comparison_target, - ): - """ - Unified benchmarking logic with timing accumulation - - Returns: - tuple: (torch_time, infini_time) timing results - """ - # Initialize timing variables - torch_time = 0.0 - infini_time = 0.0 - - if torch_implemented: - if output_count > 1: - # For multiple outputs, just call the operator - def torch_op(): - return self.torch_operator(*inputs, **kwargs) - - else: - if comparison_target is None: - # Out-of-place benchmarking - def torch_op(): - return self.torch_operator(*inputs, **kwargs) - - else: - # In-place benchmarking - def torch_op(): - self.torch_operator(*inputs, **kwargs) - return ( - kwargs.get("out") - if "out" in kwargs - else inputs[comparison_target] - ) - - torch_time = profile_operation( - "PyTorch ", - torch_op, - device_str, - config.num_prerun, - config.num_iterations, - total=True, - ) - - if infini_implemented: - if comparison_target is None: - # Out-of-place benchmarking - def infini_op(): - return self.infinicore_operator(*infini_inputs, **infini_kwargs) - - else: - # In-place benchmarking - def infini_op(): - self.infinicore_operator(*infini_inputs, **infini_kwargs) - return ( - infini_kwargs.get("out") - if "out" in infini_kwargs - else infini_inputs[comparison_target] - ) - - infini_time = profile_operation( - "InfiniCore", - infini_op, - device_str, - config.num_prerun, - config.num_iterations, - total=True, - ) - - # Store timing information in the test runner - if hasattr(config, "_test_runner") and config._test_runner: - # Accumulate total times - config._test_runner.benchmark_times["torch_total"] += torch_time - config._test_runner.benchmark_times["infinicore_total"] += infini_time - - return torch_time, infini_time diff --git a/test/infinicore/framework/benchmark.py b/test/infinicore/framework/benchmark.py new file mode 100644 index 000000000..2bfa7f656 --- /dev/null +++ b/test/infinicore/framework/benchmark.py @@ -0,0 +1,287 @@ +""" +Benchmarking utilities for the InfiniCore testing framework +""" + +import time +import torch +import infinicore +from 
diff --git a/test/infinicore/framework/benchmark.py b/test/infinicore/framework/benchmark.py
new file mode 100644
index 000000000..2bfa7f656
--- /dev/null
+++ b/test/infinicore/framework/benchmark.py
@@ -0,0 +1,287 @@
+"""
+Benchmarking utilities for the InfiniCore testing framework
+"""
+
+import time
+import torch
+import infinicore
+from .utils import synchronize_device
+
+
+class BenchmarkUtils:
+    """Utility class for benchmarking operations"""
+
+    @staticmethod
+    def profile_operation(
+        desc,
+        func,
+        torch_device,
+        num_prerun,
+        num_iterations,
+        host_time=True,
+        device_time=True,
+        total=False,
+    ):
+        """
+        Performance profiling workflow with both host and device timing
+
+        Args:
+            desc: Operation description for display
+            func: Function to profile
+            torch_device: Torch device string
+            num_prerun: Number of warm-up runs
+            num_iterations: Number of iterations for timing
+            host_time: Whether to measure host (CPU) time
+            device_time: Whether to measure device time
+            total: Whether to return total time instead of per-iteration time
+
+        Returns:
+            tuple: (host_time, device_time) timing results in seconds
+        """
+        # Warm-up runs
+        for _ in range(num_prerun):
+            func()
+
+        # Timed execution
+        host_elapsed = 0.0
+        device_elapsed = 0.0
+
+        if host_time:
+            host_elapsed = BenchmarkUtils.timed_op_host(
+                func, num_iterations, torch_device
+            )
+
+        if device_time:
+            device_elapsed = BenchmarkUtils.timed_op_device(
+                func, num_iterations, torch_device
+            )
+
+        # Print results
+        if host_time and device_time:
+            print(
+                f" {desc} time - Host: {host_elapsed / num_iterations * 1000:.6f} ms, "
+                f"Device: {device_elapsed / num_iterations * 1000:.6f} ms"
+            )
+        elif host_time:
+            print(
+                f" {desc} time - Host: {host_elapsed / num_iterations * 1000:.6f} ms"
+            )
+        elif device_time:
+            print(
+                f" {desc} time - Device: {device_elapsed / num_iterations * 1000:.6f} ms"
+            )
+
+        if total:
+            return host_elapsed, device_elapsed
+        else:
+            return host_elapsed / num_iterations, device_elapsed / num_iterations
+
+    @staticmethod
+    def timed_op_host(func, num_iterations, device):
+        """
+        Execute function multiple times and measure total host execution time
+
+        Args:
+            func: Function to execute
+            num_iterations: Number of iterations
+            device: Torch device string for synchronization
+
+        Returns:
+            float: Total host execution time in seconds
+        """
+        synchronize_device(device)
+        start = time.time()
+        for _ in range(num_iterations):
+            func()
+        synchronize_device(device)
+        return time.time() - start
+
+    @staticmethod
+    def timed_op_device(func, num_iterations, device):
+        """
+        Execute function multiple times and measure total device execution time using DeviceEvent
+
+        Args:
+            func: Function to execute
+            num_iterations: Number of iterations
+            device: Torch device string for synchronization
+
+        Returns:
+            float: Total device execution time in seconds
+        """
+        # Only use DeviceEvent for GPU devices
+        if device == "cpu":
+            return 0.0
+
+        # Create DeviceEvents for timing the entire loop
+        start_event = infinicore.DeviceEvent()
+        end_event = infinicore.DeviceEvent()
+
+        # Record start event
+        start_event.record()
+
+        # Execute the function multiple times
+        for _ in range(num_iterations):
+            func()
+
+        # Record end event
+        end_event.record()
+
+        # Synchronize to ensure all operations are complete
+        end_event.synchronize()
+
+        # Calculate total elapsed time in milliseconds
+        total_device_time = start_event.elapsed_time(end_event)
+
+        return total_device_time / 1000.0  # Convert to seconds
+
+    @staticmethod
+    def run_benchmarking(
+        config,
+        device_str,
+        torch_implemented,
+        infini_implemented,
+        torch_operator,
+        infini_operator,
+        inputs,
+        kwargs,
+        infini_inputs,
+        infini_kwargs,
+        output_count,
+        comparison_target,
+        bench_mode="both",
+    ):
+        """
+        Unified benchmarking logic with timing accumulation
+
+        Args:
+            config: Test configuration
+            device_str: Torch device string
+            torch_implemented: Whether PyTorch operator is implemented
+            infini_implemented: Whether InfiniCore operator is implemented
+            torch_operator: PyTorch operator function
+            infini_operator: InfiniCore operator function
+            inputs: PyTorch operator inputs
+            kwargs: PyTorch operator keyword arguments
+            infini_inputs: InfiniCore operator inputs
+            infini_kwargs: InfiniCore operator keyword arguments
+            output_count: Number of outputs
+            comparison_target: Comparison target specification
+            bench_mode: Benchmark mode - "host", "device", or "both"
+
+        Returns:
+            tuple: (torch_host_time, torch_device_time, infini_host_time, infini_device_time), all in seconds
+        """
+        # Determine what to time based on bench_mode
+        host_time = bench_mode in ["host", "both"]
+        device_time = bench_mode in ["device", "both"]
+
+        # Initialize timing variables
+        torch_host_time = 0.0
+        torch_device_time = 0.0
+        infini_host_time = 0.0
+        infini_device_time = 0.0
+
+        if torch_implemented:
+            if output_count > 1:
+                # For multiple outputs, just call the operator
+                def torch_op():
+                    return torch_operator(*inputs, **kwargs)
+
+            else:
+                if comparison_target is None:
+                    # Out-of-place benchmarking
+                    def torch_op():
+                        return torch_operator(*inputs, **kwargs)
+
+                else:
+                    # In-place benchmarking
+                    def torch_op():
+                        torch_operator(*inputs, **kwargs)
+                        return (
+                            kwargs.get("out")
+                            if "out" in kwargs
+                            else inputs[comparison_target]
+                        )
+
+            torch_host, torch_device = BenchmarkUtils.profile_operation(
+                "PyTorch   ",
+                torch_op,
+                device_str,
+                config.num_prerun,
+                config.num_iterations,
+                host_time=host_time,
+                device_time=device_time,
+                total=True,
+            )
+            torch_host_time = torch_host
+            torch_device_time = torch_device
+
+        if infini_implemented:
+            if comparison_target is None:
+                # Out-of-place benchmarking
+                def infini_op():
+                    return infini_operator(*infini_inputs, **infini_kwargs)
+
+            else:
+                # In-place benchmarking
+                def infini_op():
+                    infini_operator(*infini_inputs, **infini_kwargs)
+                    return (
+                        infini_kwargs.get("out")
+                        if "out" in infini_kwargs
+                        else infini_inputs[comparison_target]
+                    )
+
+            infini_host, infini_device = BenchmarkUtils.profile_operation(
+                "InfiniCore",
+                infini_op,
+                device_str,
+                config.num_prerun,
+                config.num_iterations,
+                host_time=host_time,
+                device_time=device_time,
+                total=True,
+            )
+            infini_host_time = infini_host
+            infini_device_time = infini_device
+
+        return torch_host_time, torch_device_time, infini_host_time, infini_device_time
+
+
+class BenchmarkResult:
+    """Container for benchmark results"""
+
+    def __init__(self):
+        self.torch_host_total = 0.0
+        self.torch_device_total = 0.0
+        self.infinicore_host_total = 0.0
+        self.infinicore_device_total = 0.0
+        self.per_test_case = {}
+
+    def add_timing(
+        self, test_case_name, torch_host, torch_device, infini_host, infini_device
+    ):
+        """Add timing for a specific test case"""
+        self.per_test_case[test_case_name] = {
+            "torch_host_time": torch_host,
+            "torch_device_time": torch_device,
+            "infini_host_time": infini_host,
+            "infini_device_time": infini_device,
+        }
+        self.torch_host_total += torch_host
+        self.torch_device_total += torch_device
+        self.infinicore_host_total += infini_host
+        self.infinicore_device_total += infini_device
+
+    def get_host_speedup(self):
+        """Calculate host speedup ratio"""
+        if self.infinicore_host_total > 0:
+            return self.torch_host_total / self.infinicore_host_total
+        return float("inf")
+
+    def get_device_speedup(self):
+        """Calculate device speedup ratio"""
+        if self.infinicore_device_total > 0:
+            return self.torch_device_total / self.infinicore_device_total
+        return float("inf")
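A minimal usage sketch of the new module follows. Assumptions, not part of the patch: it runs from `test/infinicore/` so `framework.benchmark` is importable, and it targets a CPU device, where `timed_op_device` deliberately returns 0.0; `dummy_op` is a stand-in workload:

```python
import torch

from framework.benchmark import BenchmarkResult, BenchmarkUtils


def dummy_op():
    # Stand-in workload; any callable the framework can invoke repeatedly works.
    return torch.ones(1024) * 2.0


# total=True returns times summed over all iterations (host seconds, device seconds).
host_s, device_s = BenchmarkUtils.profile_operation(
    "Example   ",
    dummy_op,
    "cpu",  # on "cpu", timed_op_device short-circuits to 0.0
    num_prerun=2,
    num_iterations=10,
    total=True,
)

result = BenchmarkResult()
result.add_timing(
    "example_case",
    torch_host=host_s,
    torch_device=0.0,
    infini_host=host_s / 2,  # synthetic value for illustration
    infini_device=0.0,
)
print(f"Host speedup: {result.get_host_speedup():.2f}x")
```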
diff --git a/test/infinicore/framework/config.py b/test/infinicore/framework/config.py
index ccbff88e6..1beed00cd 100644
--- a/test/infinicore/framework/config.py
+++ b/test/infinicore/framework/config.py
@@ -54,9 +54,15 @@ def get_args():
         # Run all tests on CPU only
         python test_operator.py --cpu
 
-        # Run with benchmarking on NVIDIA GPU
+        # Run with benchmarking on NVIDIA GPU (both host and device timing)
         python test_operator.py --nvidia --bench
 
+        # Run with benchmarking - host timing only
+        python test_operator.py --nvidia --bench host
+
+        # Run with benchmarking - device timing only
+        python test_operator.py --nvidia --bench device
+
         # Run with debug mode on multiple devices
         python test_operator.py --cpu --nvidia --debug
@@ -72,8 +78,11 @@ def get_args():
     # Core testing options
     parser.add_argument(
         "--bench",
-        action="store_true",
-        help="Enable performance benchmarking mode",
+        nargs="?",
+        const="both",
+        choices=["host", "device", "both"],
+        help="Enable performance benchmarking mode. "
+        "Options: host (CPU time only), device (GPU time only), both (default)",
     )
     parser.add_argument(
         "--num_prerun",
diff --git a/test/infinicore/framework/utils.py b/test/infinicore/framework/utils.py
index 051a30321..540015484 100644
--- a/test/infinicore/framework/utils.py
+++ b/test/infinicore/framework/utils.py
@@ -15,35 +15,6 @@ def synchronize_device(torch_device):
         torch.mlu.synchronize()
 
 
-def timed_op(func, num_iterations, device):
-    """Timed operation"""
-    synchronize_device(device)
-    start = time.time()
-    for _ in range(num_iterations):
-        func()
-    synchronize_device(device)
-    return time.time() - start
-
-
-def profile_operation(
-    desc, func, torch_device, num_prerun, num_iterations, total=False
-):
-    """
-    Performance profiling workflow
-    """
-    # Warm-up runs
-    for _ in range(num_prerun):
-        func()
-
-    # Timed execution
-    elapsed = timed_op(lambda: func(), num_iterations, torch_device)
-    print(f" {desc} time: {elapsed / num_iterations * 1000 :6f} ms")
-    if total:
-        return elapsed
-    else:
-        return elapsed / num_iterations
-
-
 def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
     """
     Debug function to compare two tensors and print differences
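The `--bench` change above leans on argparse's optional-value pattern. The behavior sketched below follows from `nargs="?"` and `const` themselves, not from code in this patch:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--bench",
    nargs="?",      # the value after the flag is optional
    const="both",   # a bare --bench falls back to const
    choices=["host", "device", "both"],
)

print(parser.parse_args([]).bench)                     # None -> benchmarking off
print(parser.parse_args(["--bench"]).bench)            # "both"
print(parser.parse_args(["--bench", "device"]).bench)  # "device"
```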
diff --git a/test/infinicore/run.py b/test/infinicore/run.py
index ff642b8db..b68468a16 100644
--- a/test/infinicore/run.py
+++ b/test/infinicore/run.py
@@ -109,14 +109,18 @@ def import_operator_test(test_file_path):
         return False, f"Error importing {test_file_path}: {str(e)}"
 
 
-def run_all_op_tests(ops_dir=None, specific_ops=None, bench=False, verbose=False):
+def run_all_op_tests(
+    ops_dir=None, specific_ops=None, bench=False, bench_mode="both", verbose=False
+):
     """
     Run all operator test scripts in the ops directory using direct import.
 
     Args:
         ops_dir (str, optional): Path to the ops directory. If None, uses auto-detection.
         specific_ops (list, optional): List of specific operator names to test.
-        extra_args (list, optional): Extra command line arguments to pass to test scripts.
+        bench (bool): Whether benchmarking is enabled
+        bench_mode (str): Benchmark mode - "host", "device", or "both"
+        verbose (bool): Whether verbose mode is enabled
 
     Returns:
         dict: Results dictionary with test names as keys and (success, test_runner, stdout, stderr) as values.
@@ -174,8 +178,10 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, bench=False, verbose=False
     results = {}
     cumulative_timing = {
-        "total_torch_time": 0.0,
-        "total_infinicore_time": 0.0,
+        "total_torch_host_time": 0.0,
+        "total_torch_device_time": 0.0,
+        "total_infinicore_host_time": 0.0,
+        "total_infinicore_device_time": 0.0,
         "operators_tested": 0,
     }
 
@@ -191,8 +197,10 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, bench=False, verbose=False
             results[test_name] = {
                 "success": False,
                 "return_code": -1,
-                "torch_time": 0.0,
-                "infini_time": 0.0,
+                "torch_host_time": 0.0,
+                "torch_device_time": 0.0,
+                "infini_host_time": 0.0,
+                "infini_device_time": 0.0,
                 "error_message": test_instance_or_error,
                 "test_runner": None,
                 "stdout": "",
@@ -207,8 +215,10 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, bench=False, verbose=False
             results[test_name] = {
                 "success": False,
                 "return_code": -1,
-                "torch_time": 0.0,
-                "infini_time": 0.0,
+                "torch_host_time": 0.0,
+                "torch_device_time": 0.0,
+                "infini_host_time": 0.0,
+                "infini_device_time": 0.0,
                 "error_message": "No GenericTestRunner found",
                 "test_runner": None,
                 "stdout": "",
@@ -287,15 +297,25 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, bench=False, verbose=False
                 status_icon = "❌"
                 status_text = "FAILED"
 
-            # Calculate timing
-            torch_time = sum(result.torch_time for result in test_results)
-            infini_time = sum(result.infini_time for result in test_results)
+            # Calculate timing for all four metrics
+            torch_host_time = sum(result.torch_host_time for result in test_results)
+            torch_device_time = sum(
+                result.torch_device_time for result in test_results
+            )
+            infini_host_time = sum(
+                result.infini_host_time for result in test_results
+            )
+            infini_device_time = sum(
+                result.infini_device_time for result in test_results
+            )
 
             results[test_name] = {
                 "success": test_success,
                 "return_code": return_code,
-                "torch_time": torch_time,
-                "infini_time": infini_time,
+                "torch_host_time": torch_host_time,
+                "torch_device_time": torch_device_time,
+                "infini_host_time": infini_host_time,
+                "infini_device_time": infini_device_time,
                 "error_message": "",
                 "test_runner": test_runner,
                 "stdout": stdout_output,
@@ -308,8 +328,12 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, bench=False, verbose=False
 
             # Extract benchmark timing if in bench mode
             if bench and test_success and return_code == 0:
-                cumulative_timing["total_torch_time"] += torch_time
-                cumulative_timing["total_infinicore_time"] += infini_time
+                cumulative_timing["total_torch_host_time"] += torch_host_time
+                cumulative_timing["total_torch_device_time"] += torch_device_time
+                cumulative_timing["total_infinicore_host_time"] += infini_host_time
+                cumulative_timing[
+                    "total_infinicore_device_time"
+                ] += infini_device_time
                 cumulative_timing["operators_tested"] += 1
 
         except Exception as e:
@@ -327,8 +351,10 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, bench=False, verbose=False
             results[test_name] = {
                 "success": False,
                 "return_code": -1,
-                "torch_time": 0.0,
-                "infini_time": 0.0,
+                "torch_host_time": 0.0,
+                "torch_device_time": 0.0,
+                "infini_host_time": 0.0,
+                "infini_device_time": 0.0,
                 "error_message": str(e),
                 "test_runner": None,
                 "stdout": "",
@@ -348,7 +374,11 @@ def run_all_op_tests(ops_dir=None, specific_ops=None, bench=False, verbose=False
 
 def print_summary(
-    results, verbose=False, total_expected_tests=0, cumulative_timing=None
+    results,
+    verbose=False,
+    total_expected_tests=0,
+    cumulative_timing=None,
+    bench_mode="both",
 ):
     """Print a comprehensive summary of test results including benchmark data."""
     print(f"\n{'='*80}")
@@ -405,12 +435,24 @@ def print_summary(
         print(f"{'-'*40}")
         print("BENCHMARK SUMMARY:")
         print(f"  Operators Tested: {cumulative_timing['operators_tested']}")
-        print(
-            f"  PyTorch Total Time: {cumulative_timing['total_torch_time'] * 1000:12.3f} ms"
-        )
-        print(
-            f"  InfiniCore Total Time: {cumulative_timing['total_infinicore_time'] * 1000:12.3f} ms"
-        )
+
+        # Display timing based on bench_mode
+        if bench_mode in ["host", "both"]:
+            print(
+                f"  PyTorch Host Total Time: {cumulative_timing['total_torch_host_time'] * 1000:12.3f} ms"
+            )
+            print(
+                f"  InfiniCore Host Total Time: {cumulative_timing['total_infinicore_host_time'] * 1000:12.3f} ms"
+            )
+
+        if bench_mode in ["device", "both"]:
+            print(
+                f"  PyTorch Device Total Time: {cumulative_timing['total_torch_device_time'] * 1000:12.3f} ms"
+            )
+            print(
+                f"  InfiniCore Device Total Time: {cumulative_timing['total_infinicore_device_time'] * 1000:12.3f} ms"
+            )
+
         print(f"{'-'*40}")
 
     # Display passed operators
@@ -528,9 +570,15 @@ def generate_help_epilog(ops_dir):
     )
     epilog_parts.append("  python run.py --cpu --nvidia --verbose")
     epilog_parts.append("")
-    epilog_parts.append("  # Run with benchmarking to get cumulative timing")
+    epilog_parts.append("  # Run with benchmarking (both host and device timing)")
     epilog_parts.append("  python run.py --cpu --bench")
     epilog_parts.append("")
+    epilog_parts.append("  # Run with host timing only")
+    epilog_parts.append("  python run.py --nvidia --bench host")
+    epilog_parts.append("")
+    epilog_parts.append("  # Run with device timing only")
+    epilog_parts.append("  python run.py --nvidia --bench device")
+    epilog_parts.append("")
     epilog_parts.append("  # List available tests without running")
     epilog_parts.append("  python run.py --list")
     epilog_parts.append("")
@@ -559,6 +607,9 @@ def generate_help_epilog(ops_dir):
     epilog_parts.append(
         "  - --bench mode now shows cumulative timing across all operators"
     )
+    epilog_parts.append(
+        "  - --bench host/device/both controls host/device timing measurement"
+    )
     epilog_parts.append(
         "  - --verbose mode stops execution on first error and shows full traceback"
     )
@@ -599,8 +650,11 @@ def main():
     )
     parser.add_argument(
         "--bench",
-        action="store_true",
-        help="Enable bench mode to show performance data",
+        nargs="?",
+        const="both",
+        choices=["host", "device", "both"],
+        help="Enable performance benchmarking mode. "
+        "Options: host (CPU time only), device (GPU time only), both (default)",
     )
 
     get_hardware_args_group(parser)
@@ -641,6 +695,10 @@ def main():
     if args.verbose:
         print(f"Verbose mode: ENABLED (will stop on first error with full traceback)")
 
+    if args.bench:
+        bench_mode = args.bench
+        print(f"Benchmark mode: {bench_mode.upper()} timing")
+
     if args.ops:
         # Validate requested operators
         valid_ops = []
@@ -671,13 +729,18 @@ def main():
     results, cumulative_timing = run_all_op_tests(
         ops_dir=ops_dir,
         specific_ops=args.ops,
-        bench=args.bench,
+        bench=bool(args.bench),
+        bench_mode=args.bench if args.bench else "both",
         verbose=args.verbose,
     )
 
     # Print summary and exit with appropriate code
     all_passed = print_summary(
-        results, args.verbose, total_expected_tests, cumulative_timing
+        results,
+        args.verbose,
+        total_expected_tests,
+        cumulative_timing,
+        bench_mode=args.bench if args.bench else "both",
     )
 
     # Check if there were any tests with missing implementations
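Finally, a minimal end-to-end sketch of the event-timing path these changes are built on. Assumptions: an InfiniCore build with the new Python bindings and a non-CPU device active; `run_op` stands in for any device workload and is hypothetical:

```python
import infinicore


def time_device_op(run_op, iters=100):
    # Mirrors BenchmarkUtils.timed_op_device: bracket the whole loop with events.
    start, end = infinicore.DeviceEvent(), infinicore.DeviceEvent()
    start.record()                  # enqueue start marker on the current stream
    for _ in range(iters):
        run_op()                    # enqueue device work
    end.record()                    # enqueue end marker
    end.synchronize()               # block until the end marker completes
    return start.elapsed_time(end)  # elapsed_time() reports milliseconds


# CLI equivalents added by this patch:
#   python run.py --nvidia --bench          # host + device timing
#   python run.py --nvidia --bench device   # device timing only
```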