44# Imports
55######################################
66
7+ # mypy: ignore-errors
8+
79# isort: off
810
911# This is for compatibility with Prefect.
1315# isort: on
1416
1517import os .path as osp
18+ from random import choices
1619from typing import Callable
1720
1821import mlflow
22+ import networkx as nx
1923import numpy as np
2024import pandas as pd
2125import torch
2226import torch .distributions as dist
2327import torch .nn as nn
2428import torch .nn .functional as F
29+ import torch_geometric .transforms as T
2530from joblib import dump as calibrator_dump
2631from matplotlib import pyplot as plt
2732from prefect import flow , task
3641 prepare_for_sbi ,
3742 simulate_for_sbi ,
3843)
44+ from torch_geometric .nn import SAGEConv
45+ from torch_geometric .nn .pool import global_max_pool
3946
4047from deeprootgen .calibration import (
4148 SnpeModel ,
4653)
4754from deeprootgen .data_model import RootCalibrationModel , SummaryStatisticsModel
4855from deeprootgen .io import save_graph_to_db
56+ from deeprootgen .model import RootSystemGraph
4957from deeprootgen .pipeline import (
5058 begin_experiment ,
5159 get_datetime_now ,
@@ -101,6 +109,7 @@ def prepare_task(input_parameters: RootCalibrationModel) -> tuple:
101109
102110 data_type = v ["data_type" ]
103111 if data_type == "discrete" :
112+ lower_bound , upper_bound = int (lower_bound ), int (upper_bound )
104113 replicates = np .floor (upper_bound - lower_bound ).astype ("int" )
105114 probabilities = torch .tensor ([1 / replicates ])
106115 probabilities = probabilities .repeat (replicates )
@@ -122,6 +131,146 @@ def prepare_task(input_parameters: RootCalibrationModel) -> tuple:
122131 return names , priors , limits , statistics_list
123132
124133
134+ # @TODO this is a hack for compatibility with the sbi API,
135+ # and should be replaced with a GNN feature extractor surrogate.
136+ # i.e. we train a separate feature extractor to provide graph embeddings,
137+ # then use that feature extractor instead of this embedding net.
class GraphFeatureExtractor(torch.nn.Module):
    """A graph feature extractor for density estimation.

    The module keeps a list of processed graphs (``G_list``). Callers register
    graphs with ``add_graph`` and receive the list index back; ``forward`` then
    produces one embedding per batch row from graphs held in that list.

    NOTE: ``forward`` does not look up graphs by the index values it receives.
    It samples graphs at random (with replacement) from ``G_list``. This is the
    compatibility hack described in the TODO above the class.
    """

    def __init__(self, organ_columns: list[str]) -> None:
        """GraphFeatureExtractor constructor.

        Args:
            organ_columns (list[str]):
                The list of organ columns for grouping organ features.
        """
        super().__init__()
        self.organ_columns = organ_columns

        self.transform = T.Compose([T.NormalizeFeatures(organ_columns)])

        # Flatten the per-column feature names declared by RootSystemGraph so
        # the network input width matches the concatenated node features.
        G = RootSystemGraph()
        organ_features = [
            feature
            for organ_column in organ_columns
            for feature in G.organ_columns[organ_column]
        ]
        self.organ_features = organ_features

        num_organ_features = len(organ_features)
        self.num_organ_features = num_organ_features
        self.conv1 = SAGEConv(
            num_organ_features,
            num_organ_features * 4,
            aggr="mean",
            normalize=True,
            bias=True,
        )
        self.conv2 = SAGEConv(
            num_organ_features * 4,
            num_organ_features * 2,
            aggr="mean",
            normalize=True,
            bias=True,
        )

        self.fc = torch.nn.Linear(num_organ_features * 2, num_organ_features)
        self.pool = global_max_pool
        self.activation = F.elu

        # Processed (node features, edge index) pairs registered via add_graph.
        self.G_list: list = []

    def process_graph(self, G: nx.Graph) -> tuple:
        """Process a new graph into node features and an edge index.

        The graph is modified in place: each organ column is replaced by a
        double-precision tensor before the normalisation transform is applied.

        Args:
            G (nx.Graph):
                The graph to process. It must support item access by organ
                column name and expose ``edge_index`` after the transform.
                NOTE(review): callers in this file pass the result of
                ``as_torch``, so this is presumably a torch_geometric data
                object rather than a NetworkX graph - confirm and update the
                annotation.

        Returns:
            tuple:
                The node and edge features.
        """
        for column in self.organ_columns:
            G[column] = torch.Tensor(pd.DataFrame(G[column]).values).double()

        train_data = self.transform(G)
        organ_features = [train_data[column] for column in self.organ_columns]

        # Concatenate column-wise directly in torch (no NumPy round trip) and
        # cast to the default floating dtype, as torch.Tensor(...) did before.
        x = torch.hstack(organ_features).to(torch.get_default_dtype())
        edge_index = train_data.edge_index
        return x, edge_index

    def add_graph(self, G: nx.Graph) -> int:
        """Add a graph to the graph list.

        Args:
            G (nx.Graph):
                The graph to process and store (see ``process_graph``).

        Returns:
            int:
                The list index.
        """
        x, edge_index = self.process_graph(G)

        self.G_list.append((x, edge_index))
        return len(self.G_list) - 1

    def encode(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Construct a graph embedding from nodes and edges.

        Args:
            x (torch.Tensor):
                The node features.
            edge_index (torch.Tensor):
                The edge index.

        Returns:
            torch.Tensor:
                The flattened graph embedding.
        """
        # Every node belongs to the same (single) graph, so the pooling batch
        # index is all zeros. Build it on the same device as the node features.
        batch_index = torch.zeros(x.shape[0], dtype=torch.int64, device=x.device)

        x = self.conv1(x, edge_index)
        x = self.activation(x)
        x = self.conv2(x, edge_index)
        x = self.activation(x)
        x = self.pool(x, batch_index)
        x = self.activation(x)
        x = self.fc(x)
        x = x.view(-1)

        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """The forward pass.

        Args:
            x (torch.Tensor):
                The batch tensor. If it has more than one column it is
                returned unchanged. Otherwise only its batch size is used.

        Returns:
            torch.Tensor:
                The graph embedding, one row per batch element.

        Raises:
            IndexError:
                If no graphs have been added to ``G_list``.
        """
        # NOTE(review): multi-column input is presumably an already-computed
        # embedding (e.g. the observed graph) - confirm against the caller.
        if x.shape[1] > 1:
            return x

        batch_size = x.shape[0]

        # Graphs are drawn at random with replacement; the values stored in
        # ``x`` are not used to select them (see the class-level note).
        sampled_graphs = choices(self.G_list, k=batch_size)
        embeddings = [
            self.encode(node_features, edge_index)
            for node_features, edge_index in sampled_graphs
        ]
        return torch.stack(embeddings)
273+
125274@task
126275def perform_task (
127276 input_parameters : RootCalibrationModel ,
@@ -145,23 +294,41 @@ def perform_task(
145294 tuple:
146295 The trained model and samples.
147296 """
297+ use_summary_statistics : bool = (
298+ input_parameters .statistics_comparison .use_summary_statistics
299+ )
300+ if use_summary_statistics :
301+ embedding_net = nn .Identity ()
302+ else :
303+ organ_columns = ["organ_coordinates" , "organ_hierarchy" , "organ_size" ]
304+ embedding_net = GraphFeatureExtractor (organ_columns )
148305
149306 def simulator_func (theta : np .ndarray ) -> np .ndarray :
150307 theta = theta .detach ().cpu ().numpy ()
151308 parameter_specs = {}
152309 for i , name in enumerate (names ):
153310 parameter_specs [name ] = theta [i ]
154311
155- simulated , _ = calculate_summary_statistics (
156- parameter_specs , input_parameters , statistics_list
157- )
312+ if use_summary_statistics :
313+ simulated , _ = calculate_summary_statistics (
314+ parameter_specs , input_parameters , statistics_list
315+ )
316+ else :
317+ simulation , _ = run_calibration_simulation (
318+ parameter_specs , input_parameters
319+ )
320+
321+ G = simulation .G .as_torch (drop = True )
322+ indx = embedding_net .add_graph (G )
323+ simulated = np .array ([indx ]).astype ("int" )
158324
159325 return simulated
160326
161327 calibration_parameters = input_parameters .calibration_parameters
162328 simulator , prior = prepare_for_sbi (simulator_func , priors )
163329 neural_posterior = utils .posterior_nn (
164330 model = "nsf" ,
331+ embedding_net = embedding_net ,
165332 hidden_features = calibration_parameters ["nn_num_hidden_features" ],
166333 num_transforms = calibration_parameters ["nn_num_transforms" ],
167334 )
@@ -177,16 +344,34 @@ def simulator_func(theta: np.ndarray) -> np.ndarray:
177344 inference = inference .append_simulations (theta , x , data_device = "cpu" )
178345 density_estimator = inference .train ()
179346 posterior = inference .build_posterior (density_estimator )
180-
181347 calibration_parameters = input_parameters .calibration_parameters
182348 n_draws = calibration_parameters ["pp_samples" ]
183- observed_values = []
184- for statistic in statistics_list :
185- observed_values .append (statistic .statistic_value )
186- posterior .set_default_x (observed_values )
187- posterior_samples = posterior .sample ((n_draws ,), x = observed_values )
188349
189- observed_values = [statistic .dict () for statistic in statistics_list ]
350+ if use_summary_statistics :
351+ observed_values = []
352+ for statistic in statistics_list :
353+ observed_values .append (statistic .statistic_value )
354+ posterior .set_default_x (observed_values )
355+ posterior_samples = posterior .sample ((n_draws ,), x = observed_values )
356+ observed_values = [statistic .dict () for statistic in statistics_list ]
357+ else :
358+ root_g = RootSystemGraph ()
359+ observed_data_content = input_parameters .observed_data_content
360+ raw_edge_content = input_parameters .raw_edge_content
361+ node_df , edge_df = root_g .from_content_string (
362+ observed_data_content , raw_edge_content
363+ )
364+ G = root_g .as_torch (node_df , edge_df , drop = True )
365+ x , edge_index = embedding_net .process_graph (G )
366+
367+ with torch .no_grad ():
368+ observed_values = embedding_net .encode (x , edge_index )
369+
370+ posterior .set_default_x (observed_values )
371+ posterior_samples = posterior .sample ((n_draws ,), x = observed_values )
372+ embedding_net .G_list = []
373+ observed_values = (node_df , edge_df )
374+
190375 return inference , simulator , prior , posterior , posterior_samples , observed_values
191376
192377
@@ -234,6 +419,9 @@ def log_task(
234419 tuple:
235420 The simulation and its parameters.
236421 """
422+ # use_summary_statistics: bool = (
423+ # input_parameters.statistics_comparison.use_summary_statistics
424+ # )
237425 time_now = get_datetime_now ()
238426 outdir = get_outdir ()
239427
@@ -293,6 +481,7 @@ def log_task(
293481 description = "# Simulation-based calibration metrics." ,
294482 )
295483
484+ num_bins = None
296485 if sbc_draws <= 20 : # type: ignore
297486 num_bins = sbc_draws
298487
0 commit comments