Skip to content

Commit c213d39

Browse files
committed
saving work in progress on dl
1 parent 91694e0 commit c213d39

File tree

7 files changed

+487
-48
lines changed

7 files changed

+487
-48
lines changed

app/conf/calibration_form/common.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -386,21 +386,21 @@ components:
386386
collapsible: true
387387
children:
388388
- id: upload-summary-data-file-button
389-
label: Upload
389+
label: Upload statistics
390390
help: Upload summary statistics data from a csv file
391391
class_name: dash.dcc.Upload
392392
handler: file_upload
393393
kwargs:
394394
children: Load statistics data
395395
- id: upload-obs-data-file-button
396-
label: Upload
396+
label: Upload simulation
397397
help: Upload simulated root data from a csv file
398398
class_name: dash.dcc.Upload
399399
handler: file_upload
400400
kwargs:
401401
children: Load simulation data
402402
- id: upload-edge-data-file-button
403-
label: Upload
403+
label: Upload edges
404404
help: Upload edge root data from a csv file
405405
class_name: dash.dcc.Upload
406406
handler: file_upload
@@ -694,7 +694,7 @@ components:
694694
type: number
695695
min: 1
696696
step: 1
697-
value: 10
697+
value: 5
698698
persistence: true
699699
- id: draws-input
700700
param: pp_samples
@@ -716,7 +716,7 @@ components:
716716
type: number
717717
min: 1
718718
step: 1
719-
value: 50
719+
value: 5
720720
persistence: true
721721
- id: num-transforms-input
722722
param: nn_num_transforms

app/flows/run_snpe.py

Lines changed: 199 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# Imports
55
######################################
66

7+
# mypy: ignore-errors
8+
79
# isort: off
810

911
# This is for compatibility with Prefect.
@@ -13,15 +15,18 @@
1315
# isort: on
1416

1517
import os.path as osp
18+
from random import choices
1619
from typing import Callable
1720

1821
import mlflow
22+
import networkx as nx
1923
import numpy as np
2024
import pandas as pd
2125
import torch
2226
import torch.distributions as dist
2327
import torch.nn as nn
2428
import torch.nn.functional as F
29+
import torch_geometric.transforms as T
2530
from joblib import dump as calibrator_dump
2631
from matplotlib import pyplot as plt
2732
from prefect import flow, task
@@ -36,6 +41,8 @@
3641
prepare_for_sbi,
3742
simulate_for_sbi,
3843
)
44+
from torch_geometric.nn import SAGEConv
45+
from torch_geometric.nn.pool import global_max_pool
3946

4047
from deeprootgen.calibration import (
4148
SnpeModel,
@@ -46,6 +53,7 @@
4653
)
4754
from deeprootgen.data_model import RootCalibrationModel, SummaryStatisticsModel
4855
from deeprootgen.io import save_graph_to_db
56+
from deeprootgen.model import RootSystemGraph
4957
from deeprootgen.pipeline import (
5058
begin_experiment,
5159
get_datetime_now,
@@ -101,6 +109,7 @@ def prepare_task(input_parameters: RootCalibrationModel) -> tuple:
101109

102110
data_type = v["data_type"]
103111
if data_type == "discrete":
112+
lower_bound, upper_bound = int(lower_bound), int(upper_bound)
104113
replicates = np.floor(upper_bound - lower_bound).astype("int")
105114
probabilities = torch.tensor([1 / replicates])
106115
probabilities = probabilities.repeat(replicates)
@@ -122,6 +131,146 @@ def prepare_task(input_parameters: RootCalibrationModel) -> tuple:
122131
return names, priors, limits, statistics_list
123132

124133

134+
# @TODO this is a hack for compatibility with the sbi API,
135+
# and should be replaced with a GNN feature extractor surrogate.
136+
# i.e. we train a separate feature extractor to provide graph embeddings,
137+
# then use that feature extractor instead of this embedding net.
138+
class GraphFeatureExtractor(torch.nn.Module):
    """A graph feature extractor for density estimation.

    Simulated graphs are registered with ``add_graph`` and kept in ``G_list``
    as ``(node_features, edge_index)`` pairs. ``forward`` turns a batch of
    single-column rows into graph embeddings by encoding graphs drawn from
    that list, and passes wider inputs through untouched.
    """

    def __init__(self, organ_columns: list[str]) -> None:
        """GraphFeatureExtractor constructor.

        Args:
            organ_columns (list[str]):
                The list of organ columns for grouping organ features.
        """
        super().__init__()
        self.organ_columns = organ_columns

        # Feature normalisation is applied to the organ column attributes only.
        self.transform = T.Compose([T.NormalizeFeatures(organ_columns)])

        # Flatten the per-group feature names into a single list.
        # Presumably RootSystemGraph().organ_columns maps each organ column
        # group to its feature names - TODO confirm against the model class.
        root_graph = RootSystemGraph()
        organ_features = []
        for organ_column in organ_columns:
            organ_features.extend(root_graph.organ_columns[organ_column])
        self.organ_features = organ_features

        num_organ_features = len(organ_features)
        self.num_organ_features = num_organ_features

        # Layer widths: n -> 4n -> 2n, then a linear projection back to n.
        self.conv1 = SAGEConv(
            num_organ_features,
            num_organ_features * 4,
            aggr="mean",
            normalize=True,
            bias=True,
        )
        self.conv2 = SAGEConv(
            num_organ_features * 4,
            num_organ_features * 2,
            aggr="mean",
            normalize=True,
            bias=True,
        )

        self.fc = torch.nn.Linear(num_organ_features * 2, num_organ_features)
        self.pool = global_max_pool
        self.activation = F.elu

        # Registered graphs, stored as (node_features, edge_index) tuples.
        self.G_list: list = []

    def process_graph(self, G: "nx.Graph") -> tuple:
        """Process a new graph into node features and an edge index.

        NOTE(review): the annotation says NetworkX, but the body assigns
        ``G[column]``, passes ``G`` through the torch_geometric transform and
        reads ``.edge_index`` from the result, so this looks like a
        torch_geometric data object - confirm and update the annotation.

        The organ column attributes of ``G`` are overwritten in place.

        Args:
            G (nx.Graph):
                The graph to process.

        Returns:
            tuple:
                The node and edge features.
        """
        for column in self.organ_columns:
            G[column] = torch.Tensor(pd.DataFrame(G[column]).values).double()

        train_data = self.transform(G)
        organ_features = [train_data[column] for column in self.organ_columns]

        # Concatenate the per-group features column-wise into one node matrix.
        x = torch.Tensor(np.hstack(organ_features))
        edge_index = train_data.edge_index
        return x, edge_index

    def add_graph(self, G: "nx.Graph") -> int:
        """Process a graph and append it to the graph list.

        Args:
            G (nx.Graph):
                The graph to register (see the note on ``process_graph``).

        Returns:
            int:
                The list index of the newly added graph.
        """
        x, edge_index = self.process_graph(G)

        self.G_list.append((x, edge_index))
        return len(self.G_list) - 1

    def encode(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Construct a graph embedding from nodes and edges.

        Args:
            x (torch.Tensor):
                The node features.
            edge_index (torch.Tensor):
                The edge index.

        Returns:
            torch.Tensor:
                The flattened graph embedding.
        """
        # A single graph is encoded per call, so every node belongs to batch 0.
        # Built directly on the device of ``x`` to avoid a device mismatch in
        # the pooling step.
        batch_index = torch.zeros(x.shape[0], dtype=torch.int64, device=x.device)

        x = self.conv1(x, edge_index)
        x = self.activation(x)
        x = self.conv2(x, edge_index)
        x = self.activation(x)
        x = self.pool(x, batch_index)
        x = self.activation(x)
        x = self.fc(x)
        x = x.view(-1)

        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """The forward pass.

        Inputs with more than one column are returned unchanged. Otherwise one
        stored graph is drawn at random (with replacement) for each row and
        encoded.

        NOTE(review): the values in ``x`` are not used to select graphs, only
        the number of rows is - confirm this is the intended pairing.

        Args:
            x (torch.Tensor):
                The batch tensor.

        Returns:
            torch.Tensor:
                The graph embeddings, one row per input row.
        """
        if x.shape[1] > 1:
            return x

        batch_size = x.shape[0]
        sampled_graphs = choices(self.G_list, k=batch_size)
        embeddings = [
            self.encode(node_features, edge_index)
            for node_features, edge_index in sampled_graphs
        ]
        return torch.stack(embeddings)
272+
273+
125274
@task
126275
def perform_task(
127276
input_parameters: RootCalibrationModel,
@@ -145,23 +294,41 @@ def perform_task(
145294
tuple:
146295
The trained model and samples.
147296
"""
297+
use_summary_statistics: bool = (
298+
input_parameters.statistics_comparison.use_summary_statistics
299+
)
300+
if use_summary_statistics:
301+
embedding_net = nn.Identity()
302+
else:
303+
organ_columns = ["organ_coordinates", "organ_hierarchy", "organ_size"]
304+
embedding_net = GraphFeatureExtractor(organ_columns)
148305

149306
def simulator_func(theta: np.ndarray) -> np.ndarray:
150307
theta = theta.detach().cpu().numpy()
151308
parameter_specs = {}
152309
for i, name in enumerate(names):
153310
parameter_specs[name] = theta[i]
154311

155-
simulated, _ = calculate_summary_statistics(
156-
parameter_specs, input_parameters, statistics_list
157-
)
312+
if use_summary_statistics:
313+
simulated, _ = calculate_summary_statistics(
314+
parameter_specs, input_parameters, statistics_list
315+
)
316+
else:
317+
simulation, _ = run_calibration_simulation(
318+
parameter_specs, input_parameters
319+
)
320+
321+
G = simulation.G.as_torch(drop=True)
322+
indx = embedding_net.add_graph(G)
323+
simulated = np.array([indx]).astype("int")
158324

159325
return simulated
160326

161327
calibration_parameters = input_parameters.calibration_parameters
162328
simulator, prior = prepare_for_sbi(simulator_func, priors)
163329
neural_posterior = utils.posterior_nn(
164330
model="nsf",
331+
embedding_net=embedding_net,
165332
hidden_features=calibration_parameters["nn_num_hidden_features"],
166333
num_transforms=calibration_parameters["nn_num_transforms"],
167334
)
@@ -177,16 +344,34 @@ def simulator_func(theta: np.ndarray) -> np.ndarray:
177344
inference = inference.append_simulations(theta, x, data_device="cpu")
178345
density_estimator = inference.train()
179346
posterior = inference.build_posterior(density_estimator)
180-
181347
calibration_parameters = input_parameters.calibration_parameters
182348
n_draws = calibration_parameters["pp_samples"]
183-
observed_values = []
184-
for statistic in statistics_list:
185-
observed_values.append(statistic.statistic_value)
186-
posterior.set_default_x(observed_values)
187-
posterior_samples = posterior.sample((n_draws,), x=observed_values)
188349

189-
observed_values = [statistic.dict() for statistic in statistics_list]
350+
if use_summary_statistics:
351+
observed_values = []
352+
for statistic in statistics_list:
353+
observed_values.append(statistic.statistic_value)
354+
posterior.set_default_x(observed_values)
355+
posterior_samples = posterior.sample((n_draws,), x=observed_values)
356+
observed_values = [statistic.dict() for statistic in statistics_list]
357+
else:
358+
root_g = RootSystemGraph()
359+
observed_data_content = input_parameters.observed_data_content
360+
raw_edge_content = input_parameters.raw_edge_content
361+
node_df, edge_df = root_g.from_content_string(
362+
observed_data_content, raw_edge_content
363+
)
364+
G = root_g.as_torch(node_df, edge_df, drop=True)
365+
x, edge_index = embedding_net.process_graph(G)
366+
367+
with torch.no_grad():
368+
observed_values = embedding_net.encode(x, edge_index)
369+
370+
posterior.set_default_x(observed_values)
371+
posterior_samples = posterior.sample((n_draws,), x=observed_values)
372+
embedding_net.G_list = []
373+
observed_values = (node_df, edge_df)
374+
190375
return inference, simulator, prior, posterior, posterior_samples, observed_values
191376

192377

@@ -234,6 +419,9 @@ def log_task(
234419
tuple:
235420
The simulation and its parameters.
236421
"""
422+
# use_summary_statistics: bool = (
423+
# input_parameters.statistics_comparison.use_summary_statistics
424+
# )
237425
time_now = get_datetime_now()
238426
outdir = get_outdir()
239427

@@ -293,6 +481,7 @@ def log_task(
293481
description="# Simulation-based calibration metrics.",
294482
)
295483

484+
num_bins = None
296485
if sbc_draws <= 20: # type: ignore
297486
num_bins = sbc_draws
298487

0 commit comments

Comments
 (0)