Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 53 additions & 16 deletions nvflare/tool/package_checker/check_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
NVFlareConfig,
NVFlareRole,
check_grpc_server_running,
check_socket_server_running,
construct_dummy_overseer_response,
get_communication_scheme,
try_bind_address,
try_write_dir,
)
Expand Down Expand Up @@ -97,14 +99,32 @@ def _get_primary_sp(sp_list):
return None


class CheckGRPCServerAvailable(CheckRule):
class CheckServerAvailable(CheckRule):
def __init__(self, name: str, role: str):
    """Create a server-availability rule for the given role.

    Args:
        name: Name of the check rule, forwarded to the base rule.
        role: Role of the entity performing the check; must be one of
            NVFlareRole.SERVER, NVFlareRole.CLIENT or NVFlareRole.ADMIN.

    Raises:
        RuntimeError: If ``role`` is not one of the supported roles.
    """
    super().__init__(name)
    supported_roles = (NVFlareRole.SERVER, NVFlareRole.CLIENT, NVFlareRole.ADMIN)
    if role in supported_roles:
        self.role = role
    else:
        raise RuntimeError(f"role {role} is not supported.")

def __call__(self, package_path, data):
"""Check if server is available and accessible.

Args:
package_path: Path to the package directory
data: Additional data (unused)

Returns:
CheckResult indicating success or failure
"""
startup = os.path.join(package_path, "startup")
if self.role == NVFlareRole.SERVER:
nvf_config = NVFlareConfig.SERVER
Expand All @@ -117,27 +137,44 @@ def __call__(self, package_path, data):
with open(fed_config_file, "r") as f:
fed_config = json.load(f)

# Admin has a different config structure - handle separately
if self.role == NVFlareRole.ADMIN:
admin = fed_config["admin"]
host = admin["host"]
port = admin["port"]
overseer_agent_conf = {
"path": "nvflare.ha.dummy_overseer_agent.DummyOverseerAgent",
"args": {"sp_end_point": f"{host}:{port}:{port}"},
}
scheme = admin.get("scheme", "grpc")
else:
# For client/server, get info from overseer agent
overseer_agent_conf = fed_config["overseer_agent"]

resp = construct_dummy_overseer_response(overseer_agent_conf=overseer_agent_conf, role=self.role)
resp = resp.json()
sp_list = resp.get("sp_list", [])
psp = _get_primary_sp(sp_list)
sp_end_point = psp["sp_end_point"]
sp_name, grpc_port, admin_port = sp_end_point.split(":")

if not check_grpc_server_running(startup=startup, host=sp_name, port=int(grpc_port)):
resp = construct_dummy_overseer_response(overseer_agent_conf=overseer_agent_conf, role=self.role)
resp = resp.json()
sp_list = resp.get("sp_list", [])
psp = _get_primary_sp(sp_list)
sp_end_point = psp["sp_end_point"]
host, port, admin_port = sp_end_point.split(":")
port = int(port)

# Determine the communication scheme
scheme = get_communication_scheme(package_path, nvf_config, default_scheme="grpc")

# Check connectivity based on the communication scheme
if scheme in ["grpc", "agrpc"]:
if not check_grpc_server_running(startup=startup, host=host, port=int(port)):
return CheckResult(
f"Can't connect to {scheme} server ({host}:{port})",
"Please check if server is up.",
)
elif scheme in ["http", "https", "tcp", "stcp"]:
# HTTP/HTTPS use WebSocket, TCP/STCP use raw sockets - both checked via socket connection
if not check_socket_server_running(startup=startup, host=host, port=int(port), scheme=scheme):
return CheckResult(
f"Can't connect to {scheme} server ({host}:{port})",
"Please check if server is up.",
)
else:
return CheckResult(
f"Can't connect to grpc server ({sp_name}:{grpc_port})",
"Please check if server is up.",
f"Unsupported communication scheme: {scheme}",
f"Scheme '{scheme}' is not supported for connectivity check.",
)

return CheckResult(CHECK_PASSED, "N/A")
9 changes: 7 additions & 2 deletions nvflare/tool/package_checker/client_package_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import re
import sys

from .check_rule import CheckGRPCServerAvailable
from .check_rule import CheckServerAvailable
from .package_checker import PackageChecker
from .utils import NVFlareConfig, NVFlareRole

Expand All @@ -35,8 +35,13 @@ def should_be_checked(self) -> bool:
return False

def init_rules(self, package_path):
"""Initialize preflight check rules.

The CheckServerAvailable rule automatically detects the communication scheme
(GRPC, HTTP, etc.) and uses the appropriate connectivity check method.
"""
self.rules = [
CheckGRPCServerAvailable(name="Check GRPC server available", role=self.NVF_ROLE),
CheckServerAvailable(name="Check server available", role=self.NVF_ROLE),
]

def get_uid_from_startup_script(self) -> str:
Expand Down
27 changes: 21 additions & 6 deletions nvflare/tool/package_checker/server_package_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from .check_rule import CheckAddressBinding, CheckWriting
from .package_checker import PackageChecker
from .utils import NVFlareConfig
from .utils import NVFlareConfig, get_communication_scheme

SERVER_SCRIPT = "nvflare.private.fed.app.server.server_train"

Expand Down Expand Up @@ -54,12 +54,16 @@ def _get_job_storage_root(package_path: str) -> str:
return job_storage_root


def _get_fl_host_and_port(package_path: str) -> (str, int):
    """Get federated learning service host and port.

    This is the main communication port for FL, which could use GRPC, TCP, or HTTP scheme.
    The port is read from the ``service.target`` entry ("host:port") of the first
    server in the server fed config; the returned host is always "localhost".

    Args:
        package_path: Path to the server package directory.

    Returns:
        A ``("localhost", port)`` tuple.

    Raises:
        KeyError / IndexError: If the config lacks ``servers[0].service.target``.
        ValueError: If the target has no ``:port`` suffix or the port is not an integer.
    """
    fed_config = _get_server_fed_config(package_path)
    target_address = fed_config["servers"][0]["service"]["target"]
    # Split on the LAST colon only: the host part may itself contain colons
    # (e.g. a bracketed IPv6 literal), which made a plain split(":") raise.
    _, port = target_address.rsplit(":", 1)
    return "localhost", int(port)


# Backward-compatible alias for the previous, GRPC-specific name.
_get_grpc_host_and_port = _get_fl_host_and_port


Expand All @@ -77,8 +81,19 @@ def __init__(self):

def init_rules(self, package_path):
self.dry_run_timeout = 3

# Determine the communication scheme
scheme = get_communication_scheme(package_path, NVFlareConfig.SERVER)

supported_schemes = ["grpc", "agrpc", "http", "https", "tcp", "stcp"]
if scheme not in supported_schemes:
raise RuntimeError(
f"Communication scheme '{scheme}' is not supported. "
f"Supported schemes: {', '.join(supported_schemes)}"
)

self.rules = [
CheckAddressBinding(name="Check grpc port binding", get_host_and_port_from_package=_get_grpc_host_and_port),
CheckAddressBinding(name="Check FL port binding", get_host_and_port_from_package=_get_fl_host_and_port),
CheckAddressBinding(
name="Check admin port binding", get_host_and_port_from_package=_get_admin_host_and_port
),
Expand Down
144 changes: 108 additions & 36 deletions nvflare/tool/package_checker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@
import shlex
import shutil
import socket
import ssl
import subprocess
import tempfile
from typing import Any, Dict, Optional

import grpc
from requests import Request, Response, Session
from requests.adapters import HTTPAdapter
from requests import Response


class NVFlareConfig:
Expand Down Expand Up @@ -66,25 +65,6 @@ def try_bind_address(host: str, port: int):
return None


def _create_http_session(ca_path=None, cert_path=None, prv_key_path=None):
    """Build a requests ``Session`` whose "https://" adapter retries at most once.

    Args:
        ca_path: CA bundle used as ``session.verify``; when falsy, no TLS
            settings are applied to the session.
        cert_path: Client certificate path (used only together with ``ca_path``).
        prv_key_path: Client private key path (used only together with ``ca_path``).

    Returns:
        The configured ``Session``.
    """
    http_session = Session()
    http_session.mount("https://", HTTPAdapter(max_retries=1))
    if not ca_path:
        return http_session
    http_session.verify = ca_path
    http_session.cert = (cert_path, prv_key_path)
    return http_session


def _send_request(
    session, api_point, headers: Optional[Dict[str, Any]] = None, payload: Optional[Dict[str, Any]] = None
) -> Response:
    """POST ``payload`` as JSON to ``api_point`` through ``session``.

    Args:
        session: Session used to prepare and send the request.
        api_point: URL the request is sent to.
        headers: Optional HTTP headers for the request.
        payload: Optional JSON body for the request.

    Returns:
        The response returned by ``session.send``.
    """
    post_request = Request("POST", api_point, json=payload, headers=headers)
    return session.send(session.prepare_request(post_request))


def parse_overseer_agent_args(overseer_agent_conf: dict, required_args: list) -> dict:
result = {}
for k in required_args:
Expand Down Expand Up @@ -162,27 +142,26 @@ def _get_conn_sec(startup: str):


def check_grpc_server_running(startup: str, host: str, port: int, token=None) -> bool:
with open(os.path.join(startup, _get_ca_cert_file_name()), "rb") as f:
trusted_certs = f.read()
with open(os.path.join(startup, _get_prv_key_file_name(NVFlareRole.CLIENT)), "rb") as f:
private_key = f.read()
with open(os.path.join(startup, _get_cert_file_name(NVFlareRole.CLIENT)), "rb") as f:
certificate_chain = f.read()

conn_sec = _get_conn_sec(startup)
secure = True
if conn_sec == "clear":
secure = False

call_credentials = grpc.metadata_call_credentials(
lambda context, callback: callback((("x-custom-token", token),), None)
)
credentials = grpc.ssl_channel_credentials(
certificate_chain=certificate_chain, private_key=private_key, root_certificates=trusted_certs
)

composite_credentials = grpc.composite_channel_credentials(credentials, call_credentials)
if secure:
with open(os.path.join(startup, _get_ca_cert_file_name()), "rb") as f:
trusted_certs = f.read()
with open(os.path.join(startup, _get_prv_key_file_name(NVFlareRole.CLIENT)), "rb") as f:
private_key = f.read()
with open(os.path.join(startup, _get_cert_file_name(NVFlareRole.CLIENT)), "rb") as f:
certificate_chain = f.read()
call_credentials = grpc.metadata_call_credentials(
lambda context, callback: callback((("x-custom-token", token),), None)
)
credentials = grpc.ssl_channel_credentials(
certificate_chain=certificate_chain, private_key=private_key, root_certificates=trusted_certs
)
composite_credentials = grpc.composite_channel_credentials(credentials, call_credentials)
channel = grpc.secure_channel(target=f"{host}:{port}", credentials=composite_credentials)
else:
channel = grpc.insecure_channel(target=f"{host}:{port}")
Expand All @@ -194,6 +173,66 @@ def check_grpc_server_running(startup: str, host: str, port: int, token=None) ->
return True


def check_socket_server_running(
    startup: str, host: str, port: int, scheme: str = "https", timeout: float = 10.0
) -> bool:
    """Check if socket-based server (HTTP/HTTPS/TCP/STCP) is running and accessible.

    A TCP connection is opened to ``host:port``. For the "https" and "stcp"
    schemes, unless the package's connection security is "clear", a TLS
    handshake is also performed using the CA certificate, client certificate
    and client private key found in ``startup``.

    Args:
        startup: Path to startup directory containing certificates
        host: Server hostname or IP address
        port: Server port number
        scheme: URL scheme ("http", "https", "tcp", "stcp")
        timeout: Seconds to wait for the connection (and TLS handshake)

    Returns:
        True if server is accessible, False otherwise
    """
    secure = _get_conn_sec(startup) != "clear"

    # Only the TLS-carrying schemes need a handshake, and only in secure mode.
    use_ssl = secure and scheme in ("https", "stcp")

    try:
        context = None
        if use_ssl:
            # Build the TLS context first so a missing/invalid certificate is
            # reported as "not accessible" without touching the network.
            context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
            context.minimum_version = ssl.TLSVersion.TLSv1_2
            # Hostname check may fail for localhost, so disable it for the
            # preflight check; certificate verification itself stays enabled.
            context.check_hostname = False
            context.load_verify_locations(os.path.join(startup, _get_ca_cert_file_name()))
            context.load_cert_chain(
                os.path.join(startup, _get_cert_file_name(NVFlareRole.CLIENT)),
                os.path.join(startup, _get_prv_key_file_name(NVFlareRole.CLIENT)),
            )

        # create_connection resolves the host and tries every returned address,
        # so IPv6-only servers are reachable too. The context managers guarantee
        # both the raw and the TLS-wrapped socket are closed on every path.
        with socket.create_connection((host, port), timeout=timeout) as sock:
            if context is not None:
                # Wrapping an already-connected socket performs the handshake.
                with context.wrap_socket(sock, server_hostname=host):
                    pass
        return True
    except OSError:
        # Covers timeouts, refused/unreachable connections, name-resolution
        # failures, unreadable certificate files and ssl.SSLError (all are
        # OSError subclasses): the server is not accessible.
        return False


def run_command_in_subprocess(command):
new_env = os.environ.copy()
process = subprocess.Popen(
Expand All @@ -206,3 +245,36 @@ def run_command_in_subprocess(command):
universal_newlines=True,
)
return process


def get_communication_scheme(package_path: str, config_name: str, default_scheme: str = "http") -> str:
    """Read the communication scheme from a package's fed config file.

    The scheme is taken from ``servers[0].service.scheme`` in
    ``<package_path>/startup/<config_name>``. No other file (such as
    ``comm_config.json``) is consulted. The lookup is best-effort: a missing,
    unreadable or malformed file, or a config without a non-empty string
    scheme, yields ``default_scheme`` instead of raising.

    Args:
        package_path: Path to the package directory
        config_name: Name of the configuration file (fed_server.json, fed_client.json, fed_admin.json)
        default_scheme: Default scheme to return if no scheme is found

    Returns:
        The lower-cased communication scheme (e.g., "grpc", "http"), or
        ``default_scheme`` unchanged when none is found.
    """
    fed_config_file = os.path.join(package_path, "startup", config_name)
    try:
        with open(fed_config_file, "r", encoding="utf-8") as f:
            fed_config = json.load(f)
    except (OSError, ValueError):
        # OSError: file missing or unreadable; ValueError: invalid JSON or bad encoding.
        return default_scheme

    # Walk servers[0].service.scheme defensively so an unexpected shape falls
    # back to the default rather than relying on a blanket exception handler.
    servers = fed_config.get("servers") if isinstance(fed_config, dict) else None
    first_server = servers[0] if isinstance(servers, list) and servers else None
    service_config = first_server.get("service") if isinstance(first_server, dict) else None
    scheme = service_config.get("scheme") if isinstance(service_config, dict) else None

    if isinstance(scheme, str) and scheme:
        return scheme.lower()
    return default_scheme
5 changes: 0 additions & 5 deletions tests/integration_test/data/projects/dummy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,6 @@ builders:
args:
# config_folder can be set to inform NVIDIA FLARE where to get configuration
config_folder: config
overseer_agent:
path: nvflare.ha.dummy_overseer_agent.DummyOverseerAgent
overseer_exists: false
args:
sp_end_point: localhost0:8002:8003

- path: nvflare.lighter.impl.cert.CertBuilder
- path: nvflare.lighter.impl.signature.SignatureBuilder
Loading
Loading