11"""Compute embeddings and evaluate downstream prediction."""
22import copy
3+ import multiprocessing as mp
34import os
45import pickle as pkl
56import tempfile
7+ from collections import defaultdict
68from contextlib import contextmanager
79
810import hydra
1113import pytorch_lightning as pl
1214import torch
1315from omegaconf import OmegaConf
16+ from sklearn .metrics import make_scorer
1417
15- try :
16- from embeddings_validation import ReportCollect
17- from embeddings_validation .config import Config
18- except ImportError :
19- raise ImportError ("Please, install embeddings_validation or hotpp-benchmark[downstream]" )
2018from .common import get_trainer , dump_report
2119from .embed import distributed_predict , extract_embeddings , embeddings_to_pandas
2220
@@ -73,7 +71,12 @@ def targets_to_pandas(id_field, target_names, by_split):
7371 return pd .DataFrame (columns ).set_index (id_field )
7472
7573
76- def eval_embeddings (conf ):
74+ def eval_embeddings_ptls (conf ):
75+ try :
76+ from embeddings_validation import ReportCollect
77+ from embeddings_validation .config import Config
78+ except ImportError :
79+ raise ImportError ("Please, install embeddings_validation or hotpp-benchmark[downstream]" )
7780 OmegaConf .set_struct (conf , False )
7881 conf ["workers" ] = conf .get ("workers" , 1 )
7982 conf ["total_cpu_count" ] = conf .get ("total_cpu_count" , conf .workers )
@@ -86,6 +89,86 @@ def eval_embeddings(conf):
8689 local_scheduler = conf .get ("local_scheduler" , True ),
8790 log_level = conf .get ("log_level" , "INFO" ))
8891
92+ scores = {split : {f"ptls-{ k } " : v for k , v in results .items ()}
93+ for split , results in parse_result (conf .report_file ).items ()}
94+ return scores # split -> metric -> value.
95+
96+
def _make_scorers(metrics_conf):
    """Build sklearn scorers from the ``metrics`` section of the config.

    Each entry may contain ``enabled`` (bool, default True), ``score_func``
    (dotted path resolved via hydra) and ``scorer_params`` (kwargs for
    ``make_scorer``). Raises ValueError on unknown keys.
    """
    metrics = {}
    for name, spec in metrics_conf.items():
        spec = dict(spec)
        if not spec.pop("enabled", True):
            continue
        scorer_fn = hydra.utils.get_method(spec.pop("score_func"))
        scorer_params = dict(spec.pop("scorer_params", {}))
        if spec:
            raise ValueError(f"Unknown metric parameters: {spec.keys()}")
        metrics[name] = make_scorer(scorer_fn, **scorer_params)
    return metrics


def eval_embeddings_impl(conf):
    """Train and score a single downstream model on precomputed embeddings.

    Reads pickled embeddings, CSV targets and split-index files from the paths
    in *conf*, fits the single enabled model (with its optional preprocessing
    chain) on the train split, and evaluates the configured metrics.

    Returns:
        Mapping: split name ("val"/"test") -> metric name -> score.

    Raises:
        RuntimeError: if there is not exactly one ID field or enabled model.
        ValueError: if a metric spec contains unknown parameters.
    """
    id_field = conf.target.cols_id
    if len(id_field) != 1:
        raise RuntimeError("Multiple ID fields.")
    id_field = id_field[0]
    with open(conf.features.embeddings.read_params.file_name, "rb") as fp:
        data = pkl.load(fp).set_index(id_field)  # (id, ...features...).
    targets = pd.read_csv(conf.target.file_name)[[id_field, conf.target.col_target]].set_index(id_field)
    train_index = pd.read_csv(conf.split.train_id.file_name)[[id_field]].set_index(id_field)
    val_index = pd.read_csv(conf.split.val_id.file_name)[[id_field]].set_index(id_field) if "val_id" in conf.split else None
    test_index = pd.read_csv(conf.split.test_id.file_name)[[id_field]].set_index(id_field)

    models = {k: v for k, v in conf.models.items() if v.enabled}
    if len(models) != 1:
        raise RuntimeError("Multiple models.")
    model = next(iter(models.values()))

    # Fit the (optional) preprocessing chain on train rows only.
    preprocessors = []
    if "preprocessing" in model:
        # BUGFIX: select train rows by label (`data[train_index]` attempted
        # column selection with a DataFrame) and keep every fitted
        # preprocessor so the transform loop below actually applies them.
        data_train = data.loc[train_index.index]
        for p in model.preprocessing:
            preprocessor = hydra.utils.instantiate(p)
            data_train = preprocessor.fit_transform(data_train)
            preprocessors.append(preprocessor)
    model = hydra.utils.instantiate(model.model)
    metrics = _make_scorers(conf.metrics)

    # Restrict targets to the ids present in each split.
    train_targets = targets.join(train_index, how="inner")
    val_targets = targets.join(val_index, how="inner") if val_index is not None else None
    test_targets = targets.join(test_index, how="inner")

    train_data = data.loc[train_targets.index]
    val_data = data.loc[val_targets.index] if val_targets is not None else None
    test_data = data.loc[test_targets.index]

    # Apply the preprocessors fitted on the train split to every split.
    for preprocessor in preprocessors:
        train_data = preprocessor.transform(train_data)
        val_data = preprocessor.transform(val_data) if val_data is not None else None
        test_data = preprocessor.transform(test_data)

    model.fit(train_data, train_targets)

    scores = defaultdict(dict)
    for name, metric in metrics.items():
        if val_data is not None:
            scores["val"][f"downstream-{name}"] = metric(model, val_data, val_targets)
        scores["test"][f"downstream-{name}"] = metric(model, test_data, test_targets)
    return scores
152+
153+
def eval_embeddings_worker(conf, pipe):
    """Child-process entry point: evaluate embeddings and ship scores over *pipe*.

    The pipe is always closed on exit so the parent sees EOF if evaluation
    fails before a result is sent.
    """
    try:
        pipe.send(eval_embeddings_impl(conf))
    finally:
        pipe.close()
160+
161+
def eval_embeddings(conf):
    """Run downstream evaluation in a subprocess and return its scores.

    Evaluation is isolated in a child process; results come back over a pipe.

    Returns:
        Mapping: split name -> metric name -> value.

    Raises:
        RuntimeError: if the worker process fails or exits without a result.
    """
    parent, child = mp.Pipe(duplex=False)
    p = mp.Process(target=eval_embeddings_worker, args=(conf, child))
    p.start()
    # BUGFIX: close the parent's copy of the write end; otherwise recv()
    # would block forever instead of raising EOFError if the worker dies
    # before sending.
    child.close()
    try:
        scores = parent.recv()
    except EOFError:
        # Worker crashed before sending; fall through to the exitcode check.
        scores = None
    finally:
        p.join()
    if p.exitcode != 0:
        raise RuntimeError(f"Evaluation failed with exit code {p.exitcode}")
    if scores is None:
        raise RuntimeError("Evaluation worker exited without sending scores")
    return scores  # split -> metric -> value.
171+
89172
90173def parse_result (path ):
91174 scores = {}
@@ -103,7 +186,8 @@ def parse_result(path):
103186 tokens = line .strip ().split ()
104187 mean = float (tokens [2 ])
105188 std = float (tokens [6 ])
106- scores [split ] = (mean , std )
189+ scores [split ] = {"downstream" : mean ,
190+ "downstream-std" : std }
107191 split = None
108192 return scores
109193
@@ -126,7 +210,10 @@ def eval_downstream(downstream_config, trainer, datamodule, model,
126210 embeddings = extract_embeddings (trainer , datamodule , model , splits = splits )
127211 embeddings = embeddings_to_pandas (datamodule .id_field , embeddings )
128212 if len (embeddings .index .unique ()) != len (embeddings ):
129- raise ValueError ("Duplicate ids" )
213+ from collections import Counter
214+ duplicates = Counter (embeddings .index .to_list ())
215+ duplicates = {k : v for k , v in duplicates .items () if v > 1 }
216+ raise ValueError (f"Duplicate ids { duplicates } " )
130217 if precomputed_targets is not None :
131218 if isinstance (precomputed_targets , str ):
132219 targets = pd .read_parquet (precomputed_targets ).set_index ("id" )
@@ -152,14 +239,20 @@ def eval_downstream(downstream_config, trainer, datamodule, model,
152239 downstream_config .environment .work_dir = root
153240 downstream_config .features .embeddings .read_params .file_name = embeddings_path
154241 downstream_config .target .file_name = targets_path
155- downstream_config .split .train_id . file_name = targets_path
242+ downstream_config .split .train_id = OmegaConf . create ({ "file_name" : targets_path })
156243 downstream_config .report_file = os .path .join (root , "downstream_report.txt" )
157244
245+ val_targets = targets [targets ["split" ] == "val" ]
246+ if len (val_targets ) > 0 :
247+ val_ids_path = os .path .join (root , "val_ids.csv" )
248+ val_targets [[]].to_csv (val_ids_path ) # Index only.
249+ downstream_config .split .val_id = OmegaConf .create ({"file_name" : val_ids_path })
250+
158251 test_targets = targets [targets ["split" ] == "test" ]
159252 if len (test_targets ) > 0 :
160253 test_ids_path = os .path .join (root , "test_ids.csv" )
161254 test_targets [[]].to_csv (test_ids_path ) # Index only.
162- downstream_config .split .test_id . file_name = test_ids_path
255+ downstream_config .split .test_id = OmegaConf . create ({ "file_name" : test_ids_path })
163256
164257 if os .path .exists (downstream_config .report_file ):
165258 os .remove (downstream_config .report_file )
@@ -170,9 +263,11 @@ def eval_downstream(downstream_config, trainer, datamodule, model,
170263 is_main_process = True
171264 if not is_main_process :
172265 return None
173- eval_embeddings (downstream_config )
174-
175- scores = parse_result (downstream_config .report_file )
266+ scores = eval_embeddings (downstream_config )
267+ if downstream_config .get ("eval_ptls" , False ):
268+ ptls_scores = eval_embeddings_ptls (downstream_config )
269+ for split , result in ptls_scores .items ():
270+ scores [split ].update (result )
176271 return scores
177272
178273
@@ -192,9 +287,9 @@ def main(conf):
192287 if scores is not None :
193288 # The main process.
194289 result = {}
195- for split , ( mean , std ) in scores .items ():
196- result [ f" { split } / { conf . downstream . target . col_target } (mean)" ] = mean
197- result [f"{ split } /{ conf . downstream . target . col_target } (std) " ] = std
290+ for split , metrics in scores .items ():
291+ for metric , value in metrics . items ():
292+ result [f"{ split } /{ metric } " ] = value
198293 with open (downstream_report , "w" ) as fp :
199294 dump_report (result , fp )
200295
0 commit comments