diff --git a/examples/python/heartbeat_example.py b/examples/python/heartbeat_example.py
new file mode 100644
index 000000000..35096f350
--- /dev/null
+++ b/examples/python/heartbeat_example.py
@@ -0,0 +1,106 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import multiprocessing
+import time
+
+import torch
+
+from nixl._api import nixl_agent, nixl_agent_config
+from nixl.logging import get_logger
+
+logger = get_logger(__name__)
+
+# Delay between metadata polls so the wait loops do not spin a CPU core.
+POLL_INTERVAL_SEC = 0.1
+
+
+def run_target():
+    """
+    Target mode function that runs in a subprocess.
+
+    Registers two small tensors with a NIXL agent named "target", posts the
+    agent's local metadata (to etcd in this example), and then sleeps so the
+    parent process can kill it while that metadata is still published.
+    """
+    logger.info("Target subprocess started")
+
+    config = nixl_agent_config(True, True, 5555)
+
+    # Allocate memory and register with NIXL
+    agent = nixl_agent(
+        "target",
+        config,
+    )
+    tensors = [torch.ones(10, dtype=torch.float32) for _ in range(2)]
+
+    logger.info("Target running with tensors: %s", tensors)
+
+    reg_descs = agent.register_memory(tensors)
+    if not reg_descs:
+        logger.error("Target: Memory registration failed.")
+        return
+
+    agent.send_local_metadata()
+
+    logger.info("Waiting to die")
+
+    # The parent is expected to kill this process well before the sleep ends.
+    time.sleep(100)
+
+    # Only reached if the parent never killed this process.
+    agent.deregister_memory(reg_descs)
+
+    logger.info("Target subprocess complete successfully (should have died by now).")
+
+
+if __name__ == "__main__":
+    # Start the target process
+    target_process = multiprocessing.Process(target=run_target)
+    target_process.start()
+
+    logger.info("Subprocess started, pausing...")
+
+    # Give the target a moment to start up before fetching its metadata.
+    time.sleep(5)
+
+    config = nixl_agent_config(True, True)
+
+    agent = nixl_agent("initiator", config)
+
+    # Fetch remote metadata when it's ready
+    agent.fetch_remote_metadata("target")
+
+    # Ensure remote metadata has arrived from fetch
+    while not agent.check_remote_metadata("target"):
+        time.sleep(POLL_INTERVAL_SEC)
+
+    logger.info("Ready to kill, pausing...")
+
+    time.sleep(5)
+    # SIGKILL the target process to test heartbeat failure
+    target_process.kill()
+    # Reap the killed child so it does not linger as a zombie while we wait.
+    target_process.join()
+
+    logger.info("Target process killed, waiting for metadata to be invalidated")
+
+    # Wait for metadata to be invalidated
+    while agent.check_remote_metadata("target"):
+        time.sleep(POLL_INTERVAL_SEC)
+
+    agent.invalidate_local_metadata()
+
+    logger.info("Test Complete.")
diff --git a/src/api/cpp/nixl_params.h b/src/api/cpp/nixl_params.h
index d5869b6bb..66dd15567 100644
--- a/src/api/cpp/nixl_params.h
+++ b/src/api/cpp/nixl_params.h
@@ -61,6 +61,12 @@ class nixlAgentConfig {
      */
     std::chrono::microseconds etcdWatchTimeout;
 
+    /**
+     * @var Heartbeat interval in seconds
+     *      Interval for heartbeat that keeps remote metadata valid.
+ */ + std::chrono::seconds heartbeatInterval; + /** * @brief Agent configuration constructor for enabling various features. * @param use_prog_thread flag to determine use of progress thread @@ -72,6 +78,8 @@ class nixlAgentConfig { * @param lthr_delay_us Optional delay for listener thread in us * @param capture_telemetry Optional flag to enable telemetry capture * @param etcd_watch_timeout Optional timeout for etcd watch operations in microseconds + * @param heartbeat_interval Optional timeout for how often an agent should send a + * keepalive heartbeat. Only supported in ETCD for now. */ nixlAgentConfig(const bool use_prog_thread, const bool use_listen_thread = false, @@ -82,7 +90,8 @@ class nixlAgentConfig { const uint64_t lthr_delay_us = 100000, const bool capture_telemetry = false, const std::chrono::microseconds &etcd_watch_timeout = - std::chrono::microseconds(5000000)) + std::chrono::microseconds(5000000), + const std::chrono::seconds &heartbeat_interval = std::chrono::seconds(2)) : useProgThread(use_prog_thread), useListenThread(use_listen_thread), listenPort(port), @@ -90,7 +99,8 @@ class nixlAgentConfig { captureTelemetry(capture_telemetry), pthrDelay(pthr_delay_us), lthrDelay(lthr_delay_us), - etcdWatchTimeout(etcd_watch_timeout) {} + etcdWatchTimeout(etcd_watch_timeout), + heartbeatInterval(heartbeat_interval) {} /** * @brief Copy constructor for nixlAgentConfig object diff --git a/src/core/nixl_listener.cpp b/src/core/nixl_listener.cpp index 16e3d9ca1..e317343a1 100644 --- a/src/core/nixl_listener.cpp +++ b/src/core/nixl_listener.cpp @@ -23,6 +23,7 @@ #include "common/nixl_log.h" #if HAVE_ETCD #include +#include #include #include #endif // HAVE_ETCD @@ -193,6 +194,10 @@ class nixlEtcdClient { std::hash, strEqual> agentWatchers; std::chrono::microseconds watchTimeout_; + std::thread heartbeat_thread; + std::atomic heartbeat_thread_stop = false; + std::chrono::seconds heartbeat_interval; + // Helper function to create etcd key std::string makeKey(const 
std::string& agent_name,
                         const std::string& metadata_type) {
@@ -203,8 +208,10 @@ class nixlEtcdClient {
 
 public:
     nixlEtcdClient(const std::string &my_agent_name,
-                   const std::chrono::microseconds &timeout = std::chrono::microseconds(5000000))
-        : watchTimeout_(timeout) {
+                   const std::chrono::microseconds &timeout = std::chrono::microseconds(5000000),
+                   const std::chrono::seconds &heartbeat = std::chrono::seconds(2))
+        : watchTimeout_(timeout),
+          heartbeat_interval(heartbeat) {
         const char* etcd_endpoints = std::getenv("NIXL_ETCD_ENDPOINTS");
         if (!etcd_endpoints || strlen(etcd_endpoints) == 0) {
             throw std::runtime_error("No etcd endpoints provided");
@@ -224,14 +231,56 @@ class nixlEtcdClient {
 
         NIXL_DEBUG << "Using etcd namespace for agents: " << namespace_prefix;
 
+        // The lease TTL is twice the heartbeat interval.
+        etcd::Response response = etcd->leasegrant((heartbeat.count()) * 2);
+        if (!response.is_ok()) {
+            throw std::runtime_error("Failed to get lease for agent " + my_agent_name +
+                                     " in etcd: " + response.error_message());
+        }
+
+        // Only read the lease id once the grant is known to have succeeded.
+        uint64_t lease_id = response.value().lease();
+        NIXL_DEBUG << "Successfully leased " << lease_id;
+
         std::string agent_prefix = makeKey(my_agent_name, "");
-        etcd::Response response = etcd->put(agent_prefix, "");
+        response = etcd->put(agent_prefix, "", lease_id);
         if (!response.is_ok()) {
             throw std::runtime_error("Failed to store agent " + my_agent_name +
                                      " prefix key in etcd: " + response.error_message());
         }
+
+        // Start the heartbeat thread last: if this constructor threw while the
+        // std::thread member was joinable, the destructor (and its join) would
+        // never run and destroying the member would call std::terminate.
+        heartbeat_thread = std::thread(&nixlEtcdClient::startHeartbeatThread, this, lease_id);
     }
 
+    ~nixlEtcdClient() {
+        // Ask the heartbeat loop to exit, then wait for the thread to finish.
+        heartbeat_thread_stop = true;
+        if (heartbeat_thread.joinable()) {
+            heartbeat_thread.join();
+        }
+    }
+
+    // Heartbeat thread body: until the destructor sets heartbeat_thread_stop,
+    // creates a keepalive for lease_id and checks it once per heartbeat interval.
+    void
+    startHeartbeatThread(uint64_t lease_id) {
+        while (!heartbeat_thread_stop) {
+            try {
+                // keep alive for twice the heartbeat interval
+                etcd::KeepAlive keepalive(*etcd, (heartbeat_interval.count()) * 2, lease_id);
+                keepalive.Check();
+                std::this_thread::sleep_for(heartbeat_interval);
+            } catch (const std::exception &e) {
+                // An exception escaping a thread function calls std::terminate;
+                // log it and retry after the usual interval instead.
+                NIXL_ERROR << "Heartbeat for lease " << lease_id << " failed: " << e.what();
+                std::this_thread::sleep_for(heartbeat_interval);
+            }
+        }
+    }
+
     // Store metadata in etcd
     nixl_status_t 
storeMetadataInEtcd(const std::string& agent_name, const std::string& metadata_type, @@ -249,18 +285,18 @@ class nixlEtcdClient { NIXL_DEBUG << "Successfully stored " << metadata_type << " in etcd with key: " << metadata_key << " (rev " << response.value().modified_index() << ")"; - return NIXL_SUCCESS; } else { NIXL_ERROR << "Failed to store " << metadata_type << " in etcd: " << response.error_message(); return NIXL_ERR_BACKEND; } + + return NIXL_SUCCESS; } catch (const std::exception &e) { NIXL_ERROR << "Error sending " << metadata_type << " to etcd: " << e.what(); return NIXL_ERR_BACKEND; } } - // Remove all agent's metadata from etcd nixl_status_t removeMetadataFromEtcd(const std::string& agent_name) { if (!etcd) { @@ -383,7 +419,8 @@ class nixlEtcdClient { } // Setup a watcher for an agent's metadata invalidation if it doesn't already exist - void setupAgentWatcher(const std::string &agent_name) { + void + setupAgentWatcher(const std::string &agent_name) { if (agentWatchers.find(agent_name) != agentWatchers.end()) { return; } @@ -414,6 +451,8 @@ class nixlEtcdClient { }; std::string agent_prefix = makeKey(agent_name, ""); + NIXL_DEBUG << "Create watcher for metadata " << agent_prefix; + agentWatchers[agent_name] = std::make_unique(*etcd, agent_prefix, process_response); } @@ -455,7 +494,8 @@ nixlAgentData::commWorkerInternal(nixlAgent *myAgent) { std::unique_ptr etcdClient = nullptr; // useEtcd is set in nixlAgent constructor and is true if NIXL_ETCD_ENDPOINTS is set if(useEtcd) { - etcdClient = std::make_unique(name, config.etcdWatchTimeout); + etcdClient = std::make_unique( + name, config.etcdWatchTimeout, config.heartbeatInterval); } #endif // HAVE_ETCD