examples/python/heartbeat_example.py (97 additions, 0 deletions)
@@ -0,0 +1,97 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import multiprocessing
import time

import torch

from nixl._api import nixl_agent, nixl_agent_config
from nixl.logging import get_logger

logger = get_logger(__name__)


def run_target():
"""
Target-mode function that runs in a subprocess.
It posts its metadata to etcd, then sleeps until the parent process kills it.
"""
logger.info("Target subprocess started")

# Enable progress and listener threads; listen on port 5555
config = nixl_agent_config(True, True, 5555)

# Create the agent, then allocate tensors and register them with NIXL
agent = nixl_agent(
"target",
config,
)
tensors = [torch.ones(10, dtype=torch.float32) for _ in range(2)]

logger.info("Target running with tensors: %s", tensors)

reg_descs = agent.register_memory(tensors)
if not reg_descs:
logger.error("Target: Memory registration failed.")
return

agent.send_local_metadata()

logger.info("Waiting to die")

time.sleep(100)

agent.deregister_memory(reg_descs)

logger.info("Target subprocess complete successfully (should have died by now).")


if __name__ == "__main__":
# Start the target process
target_process = multiprocessing.Process(target=run_target)
target_process.start()

logger.info("Subprocess started, pausing...")

time.sleep(5)

# Same thread flags as the target; default listener port
config = nixl_agent_config(True, True)

agent = nixl_agent("initiator", config)

# Fetch remote metadata when it's ready
agent.fetch_remote_metadata("target")

# Ensure remote metadata has arrived from fetch
ready = False
while not ready:
ready = agent.check_remote_metadata("target")

logger.info("Ready to kill, pausing...")

time.sleep(5)
# SIGKILL the target process to test heartbeat failure
target_process.kill()

logger.info("Target process killed, waiting for metadata to be invalidated")

# Wait for metadata to be invalidated
ready = True
while ready:
ready = agent.check_remote_metadata("target")

agent.invalidate_local_metadata()

logger.info("Test Complete.")
src/api/cpp/nixl_params.h (12 additions, 2 deletions)
@@ -61,6 +61,12 @@ class nixlAgentConfig {
*/
std::chrono::microseconds etcdWatchTimeout;

/**
* @var Heartbeat interval in seconds
* Interval of the keepalive heartbeat that keeps this agent's metadata valid for remote peers.
*/
std::chrono::seconds heartbeatInterval;

/**
* @brief Agent configuration constructor for enabling various features.
* @param use_prog_thread flag to determine use of progress thread
@@ -72,6 +78,8 @@
* @param lthr_delay_us Optional delay for listener thread in us
* @param capture_telemetry Optional flag to enable telemetry capture
* @param etcd_watch_timeout Optional timeout for etcd watch operations in microseconds
* @param heartbeat_interval Optional interval at which an agent sends a keepalive
* heartbeat. Currently only supported with etcd.
*/
nixlAgentConfig(const bool use_prog_thread,
const bool use_listen_thread = false,
@@ -82,15 +90,17 @@
const uint64_t lthr_delay_us = 100000,
const bool capture_telemetry = false,
const std::chrono::microseconds &etcd_watch_timeout =
std::chrono::microseconds(5000000))
std::chrono::microseconds(5000000),
const std::chrono::seconds &heartbeat_interval = std::chrono::seconds(2))
: useProgThread(use_prog_thread),
useListenThread(use_listen_thread),
listenPort(port),
syncMode(sync_mode),
captureTelemetry(capture_telemetry),
pthrDelay(pthr_delay_us),
lthrDelay(lthr_delay_us),
etcdWatchTimeout(etcd_watch_timeout) {}
etcdWatchTimeout(etcd_watch_timeout),
heartbeatInterval(heartbeat_interval) {}

/**
* @brief Copy constructor for nixlAgentConfig object
src/core/nixl_listener.cpp (54 additions, 9 deletions)
@@ -23,6 +23,7 @@
#include "common/nixl_log.h"
#if HAVE_ETCD
#include <etcd/SyncClient.hpp>
#include <etcd/KeepAlive.hpp>
#include <etcd/Watcher.hpp>
#include <future>
#endif // HAVE_ETCD
@@ -193,6 +194,11 @@ class nixlEtcdClient {
std::hash<std::string>, strEqual> agentWatchers;
std::chrono::microseconds watchTimeout_;

std::thread heartbeat_thread;
std::atomic<bool> heartbeat_thread_start = false;
std::atomic<bool> heartbeat_thread_stop = false;
std::chrono::seconds heartbeat_interval;

// Helper function to create etcd key
std::string makeKey(const std::string& agent_name,
const std::string& metadata_type) {
@@ -203,8 +209,10 @@

public:
nixlEtcdClient(const std::string &my_agent_name,
const std::chrono::microseconds &timeout = std::chrono::microseconds(5000000))
: watchTimeout_(timeout) {
const std::chrono::microseconds &timeout = std::chrono::microseconds(5000000),
const std::chrono::seconds &heartbeat = std::chrono::seconds(2))
: watchTimeout_(timeout),
heartbeat_interval(heartbeat) {
const char* etcd_endpoints = std::getenv("NIXL_ETCD_ENDPOINTS");
if (!etcd_endpoints || strlen(etcd_endpoints) == 0) {
throw std::runtime_error("No etcd endpoints provided");
@@ -232,6 +240,23 @@
}
}

~nixlEtcdClient() {
heartbeat_thread_stop = true;
if (heartbeat_thread.joinable()) {
heartbeat_thread.join();
}
}

void
startHeartbeatThread(uint64_t lease_id) {
Contributor review comment: I think this needs to be a queue of lease IDs. If an agent can store multiple metadata within etcd we will have multiple lease IDs. We should add all those to a queue (like the commqueue) and the heartbeat thread should just loop over all of them.
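For illustration, a rough sketch of that queue approach (hypothetical, not part of this PR: the lease_ids and lease_mutex members are invented, and the KeepAlive call mirrors the loop in this diff):

// Hypothetical members: one lease per stored metadata key, guarded by a mutex.
std::vector<uint64_t> lease_ids;
std::mutex lease_mutex;

void startHeartbeatThread() {
    while (!heartbeat_thread_stop) {
        std::vector<uint64_t> snapshot;
        {
            // Copy under the lock so the etcd calls run without holding it.
            std::lock_guard<std::mutex> lock(lease_mutex);
            snapshot = lease_ids;
        }
        for (uint64_t lease_id : snapshot) {
            // Same pattern as the loop below: TTL is twice the interval.
            etcd::KeepAlive keepalive(*etcd, heartbeat_interval.count() * 2, lease_id);
            keepalive.Check();
        }
        std::this_thread::sleep_for(heartbeat_interval);
    }
}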

Contributor review comment: Another way would be to add some special known key to deem an agent active and create a lease_id only for that. The agent keeps this key alive, and the other agents' watchers only check this key. If this key is deleted, that means the agent is gone, so all of its metadata can be deleted.
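A minimal sketch of that liveness-key alternative (hypothetical: the "alive" label and stored value are invented, and it assumes the agent name is kept as a member; the leasegrant/put calls follow storeMetadataInEtcd below):

// Hypothetical: lease a single well-known liveness key instead of every
// metadata key; only this one lease needs to be kept alive.
std::string liveness_key = makeKey(my_agent_name, "alive");
etcd::Response lease = etcd->leasegrant(heartbeat_interval.count() * 2);
etcd->put(liveness_key, "1", lease.value().lease());
// Remote agents would watch liveness_key: its deletion means this agent is
// gone, so all of its metadata can be invalidated in one pass.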

while (!heartbeat_thread_stop) {
// Refresh the lease each cycle; the TTL is twice the heartbeat interval so a single delayed beat does not expire it
etcd::KeepAlive keepalive(*etcd, (heartbeat_interval.count()) * 2, lease_id);
keepalive.Check();
std::this_thread::sleep_for(heartbeat_interval);
}
}

// Store metadata in etcd
nixl_status_t storeMetadataInEtcd(const std::string& agent_name,
const std::string& metadata_type,
@@ -243,24 +268,40 @@

try {
std::string metadata_key = makeKey(agent_name, metadata_type);
etcd::Response response = etcd->put(metadata_key, metadata);
etcd::Response response = etcd->leasegrant((heartbeat_interval.count()) * 2);
uint64_t lease_id = response.value().lease();

if (response.is_ok()) {
NIXL_DEBUG << "Successfully leased " << lease_id;
} else {
NIXL_ERROR << "Failed to get lease";
return NIXL_ERR_BACKEND;
}

response = etcd->put(metadata_key, metadata, lease_id);

if (response.is_ok()) {
NIXL_DEBUG << "Successfully stored " << metadata_type
<< " in etcd with key: " << metadata_key << " (rev "
<< response.value().modified_index() << ")";
} else {
NIXL_ERROR << "Failed to store " << metadata_type << " in etcd: " << response.error_message();
return NIXL_ERR_BACKEND;
}

if (!heartbeat_thread_start) {
heartbeat_thread_start = true;
heartbeat_thread =
std::thread(&nixlEtcdClient::startHeartbeatThread, this, lease_id);
}
return NIXL_SUCCESS;
}
catch (const std::exception &e) {
NIXL_ERROR << "Error sending " << metadata_type << " to etcd: " << e.what();
return NIXL_ERR_BACKEND;
}
}

// Remove all agent's metadata from etcd
nixl_status_t removeMetadataFromEtcd(const std::string& agent_name) {
if (!etcd) {
Expand Down Expand Up @@ -383,7 +424,8 @@ class nixlEtcdClient {
}

// Setup a watcher for an agent's metadata invalidation if it doesn't already exist
void setupAgentWatcher(const std::string &agent_name) {
void
setupAgentWatcher(const std::string &agent_name, const std::string metadata_label) {
if (agentWatchers.find(agent_name) != agentWatchers.end()) {
return;
}
@@ -413,7 +455,9 @@
}
};

std::string agent_prefix = makeKey(agent_name, "");
std::string agent_prefix = makeKey(agent_name, metadata_label);
NIXL_DEBUG << "Create watcher for metadata " << agent_prefix;

agentWatchers[agent_name] = std::make_unique<etcd::Watcher>(*etcd, agent_prefix, process_response);
}

@@ -455,7 +499,8 @@ nixlAgentData::commWorkerInternal(nixlAgent *myAgent) {
std::unique_ptr<nixlEtcdClient> etcdClient = nullptr;
// useEtcd is set in nixlAgent constructor and is true if NIXL_ETCD_ENDPOINTS is set
if(useEtcd) {
etcdClient = std::make_unique<nixlEtcdClient>(name, config.etcdWatchTimeout);
etcdClient = std::make_unique<nixlEtcdClient>(
name, config.etcdWatchTimeout, config.heartbeatInterval);
}
#endif // HAVE_ETCD

@@ -588,7 +633,7 @@
}
NIXL_DEBUG << "Successfully loaded metadata for agent: " << remote_agent;

etcdClient->setupAgentWatcher(remote_agent);
etcdClient->setupAgentWatcher(remote_agent, metadata_label);
Contributor review comment: If a remote agent sends multiple metadata, this watch will only be installed for the first MD, since we only allow one watch per agent. Would we have a case where we need to invalidate only a certain MD in this flow?
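For illustration, one hedged way to lift that one-watch-per-agent limit would be to key the existing agentWatchers map by the full prefix instead of the agent name (sketch only, reusing the identifiers visible above):

// Hypothetical variant of setupAgentWatcher(): one watcher per
// (agent, metadata label) prefix.
std::string agent_prefix = makeKey(agent_name, metadata_label);
if (agentWatchers.find(agent_prefix) != agentWatchers.end()) {
    return; // already watching this exact prefix
}
agentWatchers[agent_prefix] =
    std::make_unique<etcd::Watcher>(*etcd, agent_prefix, process_response);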

break;
}
case ETCD_INVAL: