
Commit 224d658

Convert JobAPI to Recipe for Kaplan-Meier example (#3894)
Fixes # .

### Description
Convert KM example from JobAPI to Recipe, also add production instructions with provisioned HE context

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Quick tests passed locally by running `./runtest.sh`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated.

Co-authored-by: Chester Chen <512707+chesterxgchen@users.noreply.github.com>
1 parent 39a9763 commit 224d658

File tree

10 files changed: +638 −170 lines changed


examples/advanced/kaplan-meier-he/README.md

Lines changed: 203 additions & 12 deletions
@@ -2,7 +2,7 @@

This example illustrates two features:
* How to perform Kaplan-Meier survival analysis in federated setting without and with secure features via time-binning and Homomorphic Encryption (HE).
-* How to use the Flare ModelController API to contract a workflow to facilitate HE under simulator mode.
+* How to use the Recipe API with Flare ModelController for job configuration and execution in both simulation and production environments.

## Basics of Kaplan-Meier Analysis
Kaplan-Meier survival analysis is a non-parametric statistic used to estimate the survival function from lifetime data. It is used to analyze the time it takes for an event of interest to occur. For example, during a clinical trial, the Kaplan-Meier estimator can be used to estimate the proportion of patients who survive a certain amount of time after treatment.
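The estimator described above can be sketched in a few lines of plain Python (a minimal illustration with toy inputs, not the example's actual code — the example itself relies on the `lifelines` library):

```python
from collections import Counter

def kaplan_meier(times, events):
    """Kaplan-Meier estimate: S(t) = prod over event times t_i <= t of
    (1 - d_i / n_i), with d_i events and n_i subjects still at risk."""
    deaths = Counter(t for t, e in zip(times, events) if e)  # d_i per time
    exits = Counter(times)                                   # all exits per time
    at_risk = len(times)                                     # n_i, shrinks over time
    survival, s = {}, 1.0
    for t in sorted(exits):
        if t in deaths:
            s *= 1.0 - deaths[t] / at_risk
        survival[t] = s
        at_risk -= exits[t]  # events and censored both leave the risk set
    return survival

# Toy cohort: event=1 means the event was observed, 0 means censored
curve = kaplan_meier([1, 2, 3, 4, 5], [1, 0, 1, 1, 0])
```

Censored subjects shrink the risk set without contributing an event, which is exactly why the survival curve only steps down at observed event times.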
@@ -62,7 +62,7 @@ To run the baseline script, simply execute:
```commandline
python utils/baseline_kaplan_meier.py
```
-By default, this will generate a KM curve image `km_curve_baseline.png` under `/tmp` directory. The resulting KM curve is shown below:
+By default, this will generate a KM curve image `km_curve_baseline.png` under the `/tmp/nvflare/baseline` directory. The resulting KM curve is shown below:

![KM survival baseline](figs/km_curve_baseline.png)

Here, we show the survival curve for both daily (without binning) and weekly binning. The two curves align well with each other, while the weekly-binned curve has lower resolution.
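Weekly binning as discussed here is essentially integer division of day-valued event times (a toy sketch; `bin_times` is an illustrative name — the actual binning lives in `utils/prepare_data.py`):

```python
def bin_times(day_times, bin_days=7):
    # Map each day-valued event time to its bin index (weekly for bin_days=7)
    return [t // bin_days for t in day_times]

weekly = bin_times([0, 3, 7, 13, 14, 20])
```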

@@ -72,41 +72,232 @@ We make use of FLARE ModelController API to implement the federated Kaplan-Meier

The Flare ModelController API (`ModelController`) provides the functionality of flexible FLModel payloads for each round of federated analysis. This gives us the flexibility of transmitting various information needed by our scheme at different stages of federated learning.

-Our [existing HE examples](../cifar10/cifar10-real-world) uses data filter mechanism for HE, provisioning the HE context information (specs and keys) for both client and server of the federated job under [CKKS](../../../nvflare/app_opt/he/model_encryptor.py) scheme. In this example, we would like to illustrate ModelController's capability in supporting customized needs beyond the existing HE functionalities (designed mainly for encrypting deep learning models).
-- different HE schemes (BFV) rather than CKKS
-- different content at different rounds of federated learning, and only specific payload needs to be encrypted
+Our [existing HE examples](../cifar10/cifar10-real-world) use the data filter mechanism for HE, provisioning the HE context information (specs and keys) for both client and server of the federated job under the [CKKS](../../../nvflare/app_opt/he/model_encryptor.py) scheme. In this example, we would like to illustrate ModelController's capability in supporting customized needs beyond the existing HE functionalities (designed mainly for encrypting deep learning models):
+- Different content at different rounds of federated learning, where only specific payloads need to be encrypted
+- Flexibility in choosing what to encrypt (histograms) versus what to send in plain text (metadata)

With the ModelController API, such a "proof of concept" experiment becomes easy. In this example, the federated analysis pipeline includes 2 rounds without HE, or 3 rounds with HE.

For the federated analysis without HE, the detailed steps are as follows:
1. Server sends the simple start message without any payload.
2. Clients submit the local event histograms to the server. The server aggregates the histograms with varying lengths by adding event counts of the same slot together, and sends the aggregated histograms back to clients.
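The varying-length aggregation in step 2 amounts to a zero-padded element-wise sum (an illustrative sketch, not the server's actual aggregator code):

```python
def aggregate_histograms(histograms):
    """Element-wise sum of per-client histograms of varying lengths:
    shorter histograms are implicitly zero-padded to the longest one."""
    total = [0] * max(len(h) for h in histograms)
    for hist in histograms:
        for slot, count in enumerate(hist):
            total[slot] += count
    return total

# Three clients report event counts over different numbers of time slots
global_hist = aggregate_histograms([[1, 2], [0, 1, 3], [2]])
```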

-For the federated analysis with HE, we need to ensure proper HE aggregation using BFV, and the detailed steps are as follows:
+For the federated analysis with HE, we need to ensure proper HE aggregation using CKKS, and the detailed steps are as follows:
1. Server sends the simple start message without any payload.
2. Clients collect the information of the local maximum bin number (for event time) and send it to the server, where the server aggregates the information by selecting the maximum among all clients. The global maximum number is then distributed back to clients. This step is necessary because we would like to standardize the histograms generated by all clients, such that they will have the exact same length and can be encrypted as vectors of the same size, which will be addable.
3. Clients condense their local raw event lists into two histograms with the global length received, encrypt the histogram value vectors, and send them to the server. The server aggregates the received histograms by adding the encrypted vectors together, and sends the aggregated histograms back to clients.

After these rounds, the federated work is completed. Then at each client, the aggregated histograms will be decrypted and converted back to an event list, and Kaplan-Meier analysis can be performed on the global information.
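The final decrypt-and-convert step can be sketched as expanding per-slot counts back into (time, event) pairs (an illustrative helper; the name and histogram layout are hypothetical):

```python
def histograms_to_events(hist_obs, hist_cen):
    """Expand per-slot counts back into (time, event) pairs:
    event=1 for observed events, event=0 for censored subjects."""
    times, events = [], []
    for t, count in enumerate(hist_obs):
        times += [t] * count
        events += [1] * count
    for t, count in enumerate(hist_cen):
        times += [t] * count
        events += [0] * count
    return times, events

times, events = histograms_to_events([2, 0, 1], [0, 1, 0])
```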

+### HE Context and Data Management
+
+- **Simulation Mode**:
+  - Uses **CKKS scheme** (approximate arithmetic, compatible with production)
+  - HE context files are manually created via `prepare_he_context.py`:
+    - Client context: `/tmp/nvflare/he_context/he_context_client.txt`
+    - Server context: `/tmp/nvflare/he_context/he_context_server.txt`
+  - Data prepared at `/tmp/nvflare/dataset/km_data`
+  - Paths can be customized via `--he_context_path` (for client context) and `--data_root`
+- **Production Mode**:
+  - Uses **CKKS scheme**
+  - HE context is automatically provisioned into startup kits via `nvflare provision`
+  - Context files are resolved by NVFlare's SecurityContentService:
+    - Clients automatically use: `client_context.tenseal` (from their startup kit)
+    - Server automatically uses: `server_context.tenseal` (from its startup kit)
+  - The `--he_context_path` parameter is ignored in production mode
+  - **Reuses the same data** from simulation mode at `/tmp/nvflare/dataset/km_data` by default
+
+**Note:** The CKKS scheme provides strong encryption with approximate arithmetic, which works well for this Kaplan-Meier analysis. The histogram counts are encrypted as floating-point numbers and rounded back to integers after decryption. Both simulation and production modes use the same CKKS scheme for consistency and compatibility. Production mode can reuse the data prepared during simulation mode, eliminating redundant data preparation.
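The rounding behavior mentioned in the note can be illustrated without any real encryption — CKKS returns sums of integer counts as floats with tiny error, which are rounded back after decryption (a toy stand-in, no TenSEAL involved):

```python
# Toy stand-in for CKKS approximate arithmetic: encrypted integer
# counts come back from decryption as floats with tiny error.
client_a, client_b = 3 + 1e-9, 5 - 2e-9   # "decrypted" per-client slot counts
noisy_sum = client_a + client_b           # aggregate of one histogram slot
count = round(noisy_sum)                  # recover the exact integer count
```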
## Run the job
-First, we prepared data for a 5-client federated job. We split and generate the data files for each client with binning interval of 7 days.
+This example supports both **Simulation Mode** (for local testing) and **Production Mode** (for real-world deployment).
+
+| Feature | Simulation Mode | Production Mode |
+|---------|----------------|-----------------|
+| **Use Case** | Testing & Development | Real-world Deployment / Production Testing |
+| **HE Context** | Manual preparation via script | Auto-provisioned via startup kits |
+| **Security** | Single machine, no encryption between processes | Secure startup kits with certificates |
+| **Setup** | Quick & simple | Requires provisioning & starting all parties |
+| **Startup** | Single command | `start_all.sh` (local) or manual (distributed) |
+| **Participants** | All run locally in one process | Distributed servers/clients running separately |
+| **Data** | Prepared once, shared by all | Same data reused from simulation |
+
+### Simulation Mode
+
+For simulation mode (testing and development), we manually prepare the data and HE context:
+
+**Step 1: Prepare Data**
+
+Split and generate data files for each client with binning interval of 7 days:
```commandline
python utils/prepare_data.py --site_num 5 --bin_days 7 --out_path "/tmp/nvflare/dataset/km_data"
```

-Then we prepare HE context for clients and server, note that this step is done by secure provisioning for real-life applications, but in this study experimenting with BFV scheme, we use this step to distribute the HE context.
+**Step 2: Prepare HE Context (Simulation Only)**
+
+For simulation mode, manually prepare the HE context with CKKS scheme:
```commandline
+# Remove old HE context if it exists
+rm -rf /tmp/nvflare/he_context
+# Generate new CKKS HE context
python utils/prepare_he_context.py --out_path "/tmp/nvflare/he_context"
```

-Next, we run the federated training using NVFlare Simulator via [JobAPI](https://nvflare.readthedocs.io/en/main/programming_guide/fed_job_api.html), both without and with HE:
+This generates the HE context with CKKS scheme (poly_modulus_degree=8192, global_scale=2^40) compatible with production mode.
+
+**Step 3: Run the Job**
+
+Run the job without and with HE:
+```commandline
+python job.py
+python job.py --encryption
+```
+
+The script will execute the job in simulation mode and display the job status. Results (KM curves and analysis details) will be saved to each simulated client's workspace directory under `/tmp/nvflare/workspaces/`.
+
+### Production Mode
+
+For production deployments, the HE context is automatically provisioned through secure startup kits.
+
+**Quick Start for Local Testing:**
+If you want to quickly test production mode on a single machine:
+1. Run provisioning: `nvflare provision -p project.yml -w /tmp/nvflare/prod_workspaces`
+2. Start all parties: `./start_all.sh`
+3. Start admin console: `cd /tmp/nvflare/prod_workspaces/km_he_project/prod_00/admin@nvidia.com && ./startup/fl_admin.sh` (use username `admin@nvidia.com`)
+4. Submit job: `python job.py --encryption --startup_kit_location /tmp/nvflare/prod_workspaces/km_he_project/prod_00/admin@nvidia.com`
+5. Monitor job via admin console: `list_jobs`, `check_status client`, `download_job <job_id>`
+6. Shutdown: `shutdown all` in admin console
+
+For detailed steps and distributed deployment, continue below:
+
+**Step 1: Install NVFlare with HE Support**
+
+```commandline
+pip install nvflare[HE]
+```
+
+**Step 2: Provision Startup Kits with HE Context**
+
+The `project.yml` file in this directory is pre-configured with `HEBuilder` using the CKKS scheme. Run provisioning to output to `/tmp/nvflare/prod_workspaces`:
+
+```commandline
+nvflare provision -p project.yml -w /tmp/nvflare/prod_workspaces
+```
+
+This generates startup kits in `/tmp/nvflare/prod_workspaces/km_he_project/prod_00/`:
+- `localhost/` - Server startup kit with `server_context.tenseal`
+- `site-1/`, `site-2/`, etc. - Client startup kits, each with `client_context.tenseal`
+- `admin@nvidia.com/` - Admin console
+
+The HE context files are automatically included in each startup kit and do not need to be manually distributed.
+
+**Step 3: Distribute Startup Kits**
+
+Securely distribute the startup kits to each participant from `/tmp/nvflare/prod_workspaces/km_he_project/prod_00/`:
+- `localhost/` directory is the server (for local testing, no need to send)
+- Send `site-1/`, `site-2/`, etc. directories to each client host (for distributed deployment)
+- Keep `admin@nvidia.com/` directory for the admin user
+
+**Step 4: Start All Parties**
+
+**Option A: Quick Start (Local Testing)**
+
+For local testing where all parties run on the same machine, use the convenience script:
+
+```commandline
+./start_all.sh
+```
+
+This will start the server and all 5 clients in the background. Logs are saved to `/tmp/nvflare/logs/`.
+
+Then start the admin console:
+```commandline
+cd /tmp/nvflare/prod_workspaces/km_he_project/prod_00/admin@nvidia.com
+./startup/fl_admin.sh
+```
+
+**Important:** When prompted for "User Name:", enter `admin@nvidia.com` (this matches the admin defined in project.yml).
+
+Once connected, check the status of all participants:
+```
+> check_status server
+> check_status client
+```
+
+**Option B: Manual Start (Distributed Deployment)**
+
+For distributed deployment where parties run on different machines:
+
+**On the Server Host:**
+```commandline
+cd /tmp/nvflare/prod_workspaces/km_he_project/prod_00/localhost
+./startup/start.sh
+```
+
+Wait for the server to be ready (you should see "Server started" in the logs).
+
+**On Each Client Host:**
+```commandline
+# On site-1
+cd /tmp/nvflare/prod_workspaces/km_he_project/prod_00/site-1
+./startup/start.sh
+
+# On site-2
+cd /tmp/nvflare/prod_workspaces/km_he_project/prod_00/site-2
+./startup/start.sh
+
+# Repeat for site-3, site-4, and site-5
+```
+
+**On the Admin Machine:**
```commandline
-python km_job.py
-python km_job.py --encryption
+cd /tmp/nvflare/prod_workspaces/km_he_project/prod_00/admin@nvidia.com
+./startup/fl_admin.sh
+# When prompted, use username: admin@nvidia.com
```

-By default, this will generate a KM curve image `km_curve_fl.png` and `km_curve_fl_he.png` under each client's directory.
+**Step 5: Submit and Run the Job**
+
+With all parties running, submit the job using the Recipe API. The job will automatically use:
+- The provisioned HE context from each participant's startup kit
+- The data already prepared in simulation mode at `/tmp/nvflare/dataset/km_data`
+
+```commandline
+python job.py --encryption --startup_kit_location /tmp/nvflare/prod_workspaces/km_he_project/prod_00/admin@nvidia.com
+```
+
+The script will output the job status. Note the job ID from the output.
+
+**Monitoring Job Progress:**
+
+The job runs asynchronously on the FL system. Use the admin console to monitor progress:
+
+```commandline
+# In the admin console
+> list_jobs              # View all jobs
+> check_status server    # Check server status
+> check_status client    # Check all clients' status
+> download_job <job_id>  # Download results after completion
+```
+
+Results will be saved to each client's workspace directory after the job completes:
+- `/tmp/nvflare/prod_workspaces/km_he_project/prod_00/site-1/{JOB_ID}/`
+- Look for `km_curve_fl_he.png` and `km_global.json` in each client's job directory
+
+**Note:** In production mode with HE, the HE context paths are automatically configured to use the provisioned context files from each participant's startup kit:
+- Clients use: `client_context.tenseal`
+- Server uses: `server_context.tenseal`
+
+The `--he_context_path` parameter is only used for simulation mode and is ignored in production mode. No manual HE context distribution is needed in production.
+
+**Step 6: Shutdown All Parties**
+
+After the job completes, shut down all parties gracefully via the admin console:
+
+```
+> shutdown all
+```
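The context-file selection described in the note amounts to a simple rule (hypothetical helper for illustration only — the real lookup is performed by NVFlare's SecurityContentService inside the startup kit, not by user code):

```python
def resolve_he_context_path(mode, he_context_path=None):
    """Sketch of the selection rule: production uses the provisioned
    startup-kit file; simulation uses the user-supplied path."""
    if mode == "production":
        # --he_context_path is ignored; the startup kit supplies the file
        return "client_context.tenseal"
    return he_context_path or "/tmp/nvflare/he_context/he_context_client.txt"

sim_path = resolve_he_context_path("simulation")
prod_path = resolve_he_context_path("production", "/ignored/path")
```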

## Display Result

examples/advanced/kaplan-meier-he/src/kaplan_meier_train.py renamed to examples/advanced/kaplan-meier-he/client.py

Lines changed: 20 additions & 4 deletions
@@ -46,7 +46,15 @@ def details_save(kmf):
        "event_count": event_count.tolist(),
        "survival_rate": survival_rate.tolist(),
    }
-    file_path = os.path.join(os.getcwd(), "km_global.json")
+
+    # Save to job-specific directory
+    # The script is located at: site-X/{JOB_DIR}/app_site-X/custom/client.py (sim) or site-X/{JOB_ID}/app_site-X/custom/client.py (prod)
+    # We need to navigate up to the {JOB_DIR} directory
+    script_dir = os.path.dirname(os.path.abspath(__file__))
+    # Go up 2 levels: custom -> app_site-X -> {JOB_DIR}
+    job_dir = os.path.abspath(os.path.join(script_dir, "..", ".."))
+
+    file_path = os.path.join(job_dir, "km_global.json")
    print(f"save the details of KM analysis result to {file_path} \n")
    with open(file_path, "w") as json_file:
        json.dump(results, json_file, indent=4)
@@ -62,7 +70,15 @@ def plot_and_save(kmf):
    plt.xlabel("time")
    plt.legend("", frameon=False)
    plt.tight_layout()
-    file_path = os.path.join(os.getcwd(), "km_curve_fl.png")
+
+    # Save to job-specific directory
+    # The script is located at: site-X/{JOB_DIR}/app_site-X/custom/client.py (sim) or site-X/{JOB_ID}/app_site-X/custom/client.py (prod)
+    # We need to navigate up to the {JOB_DIR} directory
+    script_dir = os.path.dirname(os.path.abspath(__file__))
+    # Go up 2 levels: custom -> app_site-X -> {JOB_DIR}
+    job_dir = os.path.abspath(os.path.join(script_dir, "..", ".."))
+
+    file_path = os.path.join(job_dir, "km_curve_fl.png")
    print(f"save the curve plot to {file_path} \n")
    plt.savefig(file_path)

@@ -94,10 +110,10 @@ def main():
            # Empty payload from server, send local histogram
            # Convert local data to histogram
            event_table = survival_table_from_events(time_local, event_local)
-            hist_idx = event_table.index.values.astype(int)
            hist_obs = {}
            hist_cen = {}
-            for idx in range(max(hist_idx)):
+            max_hist_idx = max(event_table.index.values.astype(int))
+            for idx in range(max_hist_idx):
                hist_obs[idx] = 0
                hist_cen[idx] = 0
            # Assign values
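The zero-fill-then-count pattern in this hunk can be shown standalone (a sketch under simple assumptions; plain counting stands in for `lifelines`' `survival_table_from_events`, and `max_idx` plays the role of the globally agreed histogram length):

```python
def build_histograms(times, events, max_idx):
    """Zero-fill observed/censored histograms up to a shared global
    length, then count each (time, event) pair into its slot."""
    hist_obs = {idx: 0 for idx in range(max_idx)}
    hist_cen = {idx: 0 for idx in range(max_idx)}
    for t, e in zip(times, events):
        if e:
            hist_obs[t] += 1  # observed event
        else:
            hist_cen[t] += 1  # censored subject
    return hist_obs, hist_cen

obs, cen = build_histograms([0, 1, 1, 2], [1, 0, 1, 1], max_idx=4)
```

Zero-filling up to the shared length is what makes the per-client histograms encryptable as same-size vectors that can be added under HE.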
