Skip to content

Commit 98ec87d

Browse files
[2.7] Cherry pick of Fix preflight check and ci (#3917) (#3929)
Preflight check tool was hardcoded to test against GRPC communication scheme. We have added more schemes and now our default is HTTP so we should change accordingly - Update preflight check - Remove overseer test - Update preflight check tests <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Quick tests passed locally by running `./runtest.sh`. - [ ] In-line docstrings updated. - [ ] Documentation updated. --------- Fixes # . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Quick tests passed locally by running `./runtest.sh`. - [ ] In-line docstrings updated. - [ ] Documentation updated. --------- Co-authored-by: Chester Chen <[email protected]>
1 parent 2d4fdd4 commit 98ec87d

File tree

7 files changed

+334
-124
lines changed

7 files changed

+334
-124
lines changed

nvflare/tool/package_checker/check_rule.py

Lines changed: 56 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,16 @@
2020
NVFlareConfig,
2121
NVFlareRole,
2222
check_grpc_server_running,
23+
check_socket_server_running,
2324
construct_dummy_overseer_response,
25+
get_communication_scheme,
2426
try_bind_address,
2527
try_write_dir,
2628
)
2729

2830
CHECK_PASSED = "PASSED"
31+
# please refer to nvflare/lighter/impl/static_file.py
32+
DEFAULT_SCHEME = "http"
2933

3034

3135
class CheckResult:
@@ -97,14 +101,32 @@ def _get_primary_sp(sp_list):
97101
return None
98102

99103

100-
class CheckGRPCServerAvailable(CheckRule):
104+
class CheckServerAvailable(CheckRule):
101105
def __init__(self, name: str, role: str):
106+
"""Initialize server availability checker.
107+
108+
This rule automatically detects the communication scheme (GRPC, HTTP, etc.)
109+
and uses the appropriate method to check server connectivity.
110+
111+
Args:
112+
name: Name of the check rule
113+
role: Role of the entity performing the check (server/client/admin)
114+
"""
102115
super().__init__(name)
103116
if role not in [NVFlareRole.SERVER, NVFlareRole.CLIENT, NVFlareRole.ADMIN]:
104117
raise RuntimeError(f"role {role} is not supported.")
105118
self.role = role
106119

107120
def __call__(self, package_path, data):
121+
"""Check if server is available and accessible.
122+
123+
Args:
124+
package_path: Path to the package directory
125+
data: Additional data (unused)
126+
127+
Returns:
128+
CheckResult indicating success or failure
129+
"""
108130
startup = os.path.join(package_path, "startup")
109131
if self.role == NVFlareRole.SERVER:
110132
nvf_config = NVFlareConfig.SERVER
@@ -117,27 +139,44 @@ def __call__(self, package_path, data):
117139
with open(fed_config_file, "r") as f:
118140
fed_config = json.load(f)
119141

142+
# Admin has a different config structure - handle separately
120143
if self.role == NVFlareRole.ADMIN:
121144
admin = fed_config["admin"]
122145
host = admin["host"]
123-
port = admin["port"]
124-
overseer_agent_conf = {
125-
"path": "nvflare.ha.dummy_overseer_agent.DummyOverseerAgent",
126-
"args": {"sp_end_point": f"{host}:{port}:{port}"},
127-
}
146+
port = int(admin["port"])
147+
scheme = admin.get("scheme", DEFAULT_SCHEME)
128148
else:
149+
# For client/server, get info from overseer agent
129150
overseer_agent_conf = fed_config["overseer_agent"]
130-
131-
resp = construct_dummy_overseer_response(overseer_agent_conf=overseer_agent_conf, role=self.role)
132-
resp = resp.json()
133-
sp_list = resp.get("sp_list", [])
134-
psp = _get_primary_sp(sp_list)
135-
sp_end_point = psp["sp_end_point"]
136-
sp_name, grpc_port, admin_port = sp_end_point.split(":")
137-
138-
if not check_grpc_server_running(startup=startup, host=sp_name, port=int(grpc_port)):
151+
resp = construct_dummy_overseer_response(overseer_agent_conf=overseer_agent_conf, role=self.role)
152+
resp = resp.json()
153+
sp_list = resp.get("sp_list", [])
154+
psp = _get_primary_sp(sp_list)
155+
sp_end_point = psp["sp_end_point"]
156+
host, port, admin_port = sp_end_point.split(":")
157+
port = int(port)
158+
159+
# Determine the communication scheme
160+
scheme = get_communication_scheme(package_path, nvf_config, default_scheme=DEFAULT_SCHEME)
161+
162+
# Check connectivity based on the communication scheme
163+
if scheme in ["grpc", "agrpc"]:
164+
if not check_grpc_server_running(startup=startup, host=host, port=port):
165+
return CheckResult(
166+
f"Can't connect to {scheme} server ({host}:{port})",
167+
"Please check if server is up.",
168+
)
169+
elif scheme in ["http", "https", "tcp", "stcp"]:
170+
# HTTP/HTTPS use WebSocket, TCP/STCP use raw sockets - both checked via socket connection
171+
if not check_socket_server_running(startup=startup, host=host, port=port, scheme=scheme):
172+
return CheckResult(
173+
f"Can't connect to {scheme} server ({host}:{port})",
174+
"Please check if server is up.",
175+
)
176+
else:
139177
return CheckResult(
140-
f"Can't connect to grpc server ({sp_name}:{grpc_port})",
141-
"Please check if server is up.",
178+
f"Unsupported communication scheme: {scheme}",
179+
f"Scheme '{scheme}' is not supported for connectivity check.",
142180
)
181+
143182
return CheckResult(CHECK_PASSED, "N/A")

nvflare/tool/package_checker/client_package_checker.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import re
1717
import sys
1818

19-
from .check_rule import CheckGRPCServerAvailable
19+
from .check_rule import CheckServerAvailable
2020
from .package_checker import PackageChecker
2121
from .utils import NVFlareConfig, NVFlareRole
2222

@@ -35,8 +35,13 @@ def should_be_checked(self) -> bool:
3535
return False
3636

3737
def init_rules(self, package_path):
38+
"""Initialize preflight check rules.
39+
40+
The CheckServerAvailable rule automatically detects the communication scheme
41+
(GRPC, HTTP, etc.) and uses the appropriate connectivity check method.
42+
"""
3843
self.rules = [
39-
CheckGRPCServerAvailable(name="Check GRPC server available", role=self.NVF_ROLE),
44+
CheckServerAvailable(name="Check server available", role=self.NVF_ROLE),
4045
]
4146

4247
def get_uid_from_startup_script(self) -> str:

nvflare/tool/package_checker/server_package_checker.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from .check_rule import CheckAddressBinding, CheckWriting
2121
from .package_checker import PackageChecker
22-
from .utils import NVFlareConfig
22+
from .utils import NVFlareConfig, get_communication_scheme
2323

2424
SERVER_SCRIPT = "nvflare.private.fed.app.server.server_train"
2525

@@ -54,12 +54,16 @@ def _get_job_storage_root(package_path: str) -> str:
5454
return job_storage_root
5555

5656

57-
def _get_grpc_host_and_port(package_path: str) -> (str, int):
57+
def _get_fl_host_and_port(package_path: str) -> (str, int):
58+
"""Get federated learning service host and port.
59+
60+
This is the main communication port for FL, which could use GRPC, TCP, or HTTP scheme.
61+
"""
5862
fed_config = _get_server_fed_config(package_path)
5963
server_conf = fed_config["servers"][0]
60-
grpc_service_config = server_conf["service"]
61-
grpc_target_address = grpc_service_config["target"]
62-
_, port = grpc_target_address.split(":")
64+
service_config = server_conf["service"]
65+
target_address = service_config["target"]
66+
_, port = target_address.split(":")
6367
return "localhost", int(port)
6468

6569

@@ -77,8 +81,19 @@ def __init__(self):
7781

7882
def init_rules(self, package_path):
7983
self.dry_run_timeout = 3
84+
85+
# Determine the communication scheme
86+
scheme = get_communication_scheme(package_path, NVFlareConfig.SERVER)
87+
88+
supported_schemes = ["grpc", "agrpc", "http", "https", "tcp", "stcp"]
89+
if scheme not in supported_schemes:
90+
raise RuntimeError(
91+
f"Communication scheme '{scheme}' is not supported. "
92+
f"Supported schemes: {', '.join(supported_schemes)}"
93+
)
94+
8095
self.rules = [
81-
CheckAddressBinding(name="Check grpc port binding", get_host_and_port_from_package=_get_grpc_host_and_port),
96+
CheckAddressBinding(name="Check FL port binding", get_host_and_port_from_package=_get_fl_host_and_port),
8297
CheckAddressBinding(
8398
name="Check admin port binding", get_host_and_port_from_package=_get_admin_host_and_port
8499
),

nvflare/tool/package_checker/utils.py

Lines changed: 108 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,12 @@
1717
import shlex
1818
import shutil
1919
import socket
20+
import ssl
2021
import subprocess
2122
import tempfile
22-
from typing import Any, Dict, Optional
2323

2424
import grpc
25-
from requests import Request, Response, Session
26-
from requests.adapters import HTTPAdapter
25+
from requests import Response
2726

2827

2928
class NVFlareConfig:
@@ -66,25 +65,6 @@ def try_bind_address(host: str, port: int):
6665
return None
6766

6867

69-
def _create_http_session(ca_path=None, cert_path=None, prv_key_path=None):
70-
session = Session()
71-
adapter = HTTPAdapter(max_retries=1)
72-
session.mount("https://", adapter)
73-
if ca_path:
74-
session.verify = ca_path
75-
session.cert = (cert_path, prv_key_path)
76-
return session
77-
78-
79-
def _send_request(
80-
session, api_point, headers: Optional[Dict[str, Any]] = None, payload: Optional[Dict[str, Any]] = None
81-
) -> Response:
82-
req = Request("POST", api_point, json=payload, headers=headers)
83-
prepared = session.prepare_request(req)
84-
resp = session.send(prepared)
85-
return resp
86-
87-
8868
def parse_overseer_agent_args(overseer_agent_conf: dict, required_args: list) -> dict:
8969
result = {}
9070
for k in required_args:
@@ -162,27 +142,26 @@ def _get_conn_sec(startup: str):
162142

163143

164144
def check_grpc_server_running(startup: str, host: str, port: int, token=None) -> bool:
165-
with open(os.path.join(startup, _get_ca_cert_file_name()), "rb") as f:
166-
trusted_certs = f.read()
167-
with open(os.path.join(startup, _get_prv_key_file_name(NVFlareRole.CLIENT)), "rb") as f:
168-
private_key = f.read()
169-
with open(os.path.join(startup, _get_cert_file_name(NVFlareRole.CLIENT)), "rb") as f:
170-
certificate_chain = f.read()
171145

172146
conn_sec = _get_conn_sec(startup)
173147
secure = True
174148
if conn_sec == "clear":
175149
secure = False
176150

177-
call_credentials = grpc.metadata_call_credentials(
178-
lambda context, callback: callback((("x-custom-token", token),), None)
179-
)
180-
credentials = grpc.ssl_channel_credentials(
181-
certificate_chain=certificate_chain, private_key=private_key, root_certificates=trusted_certs
182-
)
183-
184-
composite_credentials = grpc.composite_channel_credentials(credentials, call_credentials)
185151
if secure:
152+
with open(os.path.join(startup, _get_ca_cert_file_name()), "rb") as f:
153+
trusted_certs = f.read()
154+
with open(os.path.join(startup, _get_prv_key_file_name(NVFlareRole.CLIENT)), "rb") as f:
155+
private_key = f.read()
156+
with open(os.path.join(startup, _get_cert_file_name(NVFlareRole.CLIENT)), "rb") as f:
157+
certificate_chain = f.read()
158+
call_credentials = grpc.metadata_call_credentials(
159+
lambda context, callback: callback((("x-custom-token", token),), None)
160+
)
161+
credentials = grpc.ssl_channel_credentials(
162+
certificate_chain=certificate_chain, private_key=private_key, root_certificates=trusted_certs
163+
)
164+
composite_credentials = grpc.composite_channel_credentials(credentials, call_credentials)
186165
channel = grpc.secure_channel(target=f"{host}:{port}", credentials=composite_credentials)
187166
else:
188167
channel = grpc.insecure_channel(target=f"{host}:{port}")
@@ -194,6 +173,66 @@ def check_grpc_server_running(startup: str, host: str, port: int, token=None) ->
194173
return True
195174

196175

176+
def check_socket_server_running(startup: str, host: str, port: int, scheme: str = "https") -> bool:
177+
"""Check if socket-based server (HTTP/HTTPS/TCP/STCP) is running and accessible.
178+
179+
This function performs a socket connection test with optional SSL/TLS.
180+
It's used for HTTP/WebSocket and TCP-based FL servers.
181+
182+
Args:
183+
startup: Path to startup directory containing certificates
184+
host: Server hostname or IP address
185+
port: Server port number
186+
scheme: URL scheme ("http", "https", "tcp", "stcp")
187+
188+
Returns:
189+
True if server is accessible, False otherwise
190+
"""
191+
conn_sec = _get_conn_sec(startup)
192+
secure = True
193+
if conn_sec == "clear":
194+
secure = False
195+
196+
# Determine if we need SSL based on scheme
197+
use_ssl = secure and scheme in ["https", "stcp"]
198+
199+
# Try a socket connection to check if port is reachable
200+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
201+
sock.settimeout(10)
202+
203+
try:
204+
if use_ssl:
205+
# For secure connection, wrap socket with SSL and use client certificates
206+
context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
207+
context.minimum_version = ssl.TLSVersion.TLSv1_2
208+
ca_path = os.path.join(startup, _get_ca_cert_file_name())
209+
cert_path = os.path.join(startup, _get_cert_file_name(NVFlareRole.CLIENT))
210+
prv_key_path = os.path.join(startup, _get_prv_key_file_name(NVFlareRole.CLIENT))
211+
212+
context.load_verify_locations(ca_path)
213+
context.load_cert_chain(cert_path, prv_key_path)
214+
# Check hostname may fail for localhost, so disable it for preflight check
215+
context.check_hostname = False
216+
217+
ssl_sock = context.wrap_socket(sock, server_hostname=host)
218+
ssl_sock.connect((host, port))
219+
ssl_sock.close()
220+
else:
221+
# For insecure connection, just check if we can connect
222+
sock.connect((host, port))
223+
sock.close()
224+
225+
return True
226+
except (socket.timeout, socket.error, ssl.SSLError, OSError, ConnectionRefusedError):
227+
# Connection failed - server is not accessible
228+
return False
229+
finally:
230+
try:
231+
sock.close()
232+
except Exception:
233+
pass
234+
235+
197236
def run_command_in_subprocess(command):
198237
new_env = os.environ.copy()
199238
process = subprocess.Popen(
@@ -206,3 +245,36 @@ def run_command_in_subprocess(command):
206245
universal_newlines=True,
207246
)
208247
return process
248+
249+
250+
def get_communication_scheme(package_path: str, config_name: str, default_scheme: str = "http") -> str:
251+
"""Read the communication scheme from package configuration files.
252+
253+
This function checks multiple sources to determine the communication scheme:
254+
1. For servers: fed_server.json (service.scheme)
255+
2. For all packages: comm_config.json in local/ or startup/ directories
256+
257+
Args:
258+
package_path: Path to the package directory
259+
config_name: Name of the configuration file (fed_server.json, fed_client.json, fed_admin.json)
260+
default_scheme: Default scheme to return if no scheme is found
261+
262+
Returns:
263+
The communication scheme (e.g., "grpc", "http")
264+
"""
265+
# First try to read from fed_xxx.json
266+
startup = os.path.join(package_path, "startup")
267+
fed_config_file = os.path.join(startup, config_name)
268+
if os.path.exists(fed_config_file):
269+
try:
270+
with open(fed_config_file, "r") as f:
271+
fed_config = json.load(f)
272+
server_conf = fed_config.get("servers", [{}])[0]
273+
service_config = server_conf.get("service", {})
274+
scheme = service_config.get("scheme")
275+
if scheme:
276+
return scheme.lower()
277+
except Exception:
278+
pass
279+
280+
return default_scheme

tests/integration_test/data/projects/dummy.yml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,6 @@ builders:
2525
args:
2626
# config_folder can be set to inform NVIDIA FLARE where to get configuration
2727
config_folder: config
28-
overseer_agent:
29-
path: nvflare.ha.dummy_overseer_agent.DummyOverseerAgent
30-
overseer_exists: false
31-
args:
32-
sp_end_point: localhost0:8002:8003
3328

3429
- path: nvflare.lighter.impl.cert.CertBuilder
3530
- path: nvflare.lighter.impl.signature.SignatureBuilder

0 commit comments

Comments
 (0)