# compute_immune_score.py
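"""
Compute a latent-space immune score for real samples against high-confidence
generated Control samples.

Pipeline overview (see compute_latent_immune_score below):
  1. Load the feature list, the trained Beta-VAE and the real disease data.
  2. Optionally standardize and append one extra cohort DataFrame.
  3. Encode every group into the VAE latent space.
  4. Classify generated samples, keep those with a high Healthy Control
     probability and restrict them to the bounding box of the real Controls.
  5. Min-max scale the latent coordinates and score each sample against the
     median of the filtered generated Controls.
"""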
import torch
import pandas as pd
import numpy as np
import pickle
import os
from sklearn.preprocessing import MinMaxScaler
from QImmuDef_VAE import BetaVAE  # noqa: F401 -- class definition must be importable for torch.load to unpickle the full model
from ensemble_learning import ensemable_classcification, immune_scror_calculate
from typing import Optional, Union
from feature_selection_methods.normalizmethods import Normalizer
import argparse


def _load_external_df(maybe_path: Union[str, pd.DataFrame]) -> Optional[pd.DataFrame]:
| 16 | + """ |
| 17 | + Load an external DataFrame. If input is already a DataFrame, return it. |
| 18 | + If it's a string, try common loaders (pickle, parquet, csv). |
| 19 | + Raise ValueError on failure. |
| 20 | + """ |
    if maybe_path is None:
        return None
    if isinstance(maybe_path, pd.DataFrame):
        return maybe_path.copy()
    if not isinstance(maybe_path, str):
        raise ValueError("extra input must be a pandas DataFrame or a path string.")
    # try common formats in order
    if not os.path.exists(maybe_path):
        raise FileNotFoundError(f"Extra input path not found: {maybe_path}")
    # try pd.read_pickle first (for .pkl)
    try:
        return pd.read_pickle(maybe_path)
    except Exception:
        pass
    # parquet
    try:
        return pd.read_parquet(maybe_path)
    except Exception:
        pass
    # csv
    try:
        return pd.read_csv(maybe_path, index_col=0)
    except Exception:
        pass

    # last resort: try generic pickle load
    try:
        with open(maybe_path, 'rb') as f:
            obj = pickle.load(f)
        if isinstance(obj, pd.DataFrame):
            return obj
    except Exception:
        pass

    raise ValueError(f"Unable to load extra DataFrame from path: {maybe_path}")
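
# Note on CSV input: read_csv is called with index_col=0 above, so a CSV extra cohort
# is expected to carry its sample IDs in the first column. Passing an in-memory
# DataFrame instead of a path skips all loaders and simply returns a copy of it.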


def compute_latent_immune_score(
    data_pkl: str = "dis_data_df.pkl",
    model_path: str = "anneal_betaFalse_anneal_steps20000enc512_128_32_lr0.0001_bs16_beta0.1.pth",
    features_npy: str = "infection_fea_619.npy",
    generated_h5: str = "1e6_sampling_data_is_axis.h5",
    h5_key: str = "df",
    prob_threshold: float = 0.90,
    latent_dim: int = 2,
    control_group_name: str = "Control",
    healthy_class_name: str = "Healthy Control",
    device: Optional[str] = None,
    extra_input: Optional[Union[str, pd.DataFrame]] = None,
    extra_key: str = "Extra"
) -> pd.DataFrame:
| 72 | + """ |
| 73 | + 一键计算基于生成高置信Control的潜空间免疫评分 |
| 74 | + 支持传入额外一个 DataFrame(或其文件路径),该 DataFrame 会先通过 Normalizer 做 |
| 75 | + 标准化(在该额外数据上 fit 标准化器),并被加入到 data_dict[extra_key] 中,参与后续流程。 |
| 76 | + 返回: vis_df_final (含 x, y, Disease Group, immune_score) |
| 77 | + """ |
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device(device)

    # 1. Load the feature list and the trained VAE model
    features = np.load(features_npy)
    model = torch.load(model_path, map_location=device, weights_only=False).eval()

    # 2. Load the real data
    with open(data_pkl, 'rb') as f:
        data_dict = pickle.load(f)  # expected dict[str, pd.DataFrame]

    # 2.1 If extra data is provided, load it, standardize it and add it to data_dict
    if extra_input is not None:
        extra_df = _load_external_df(extra_input)
        if extra_df is None:
            raise ValueError("Failed to load extra input DataFrame.")
        # Make sure extra_df contains at least a subset of the required features
        missing_features = [f for f in features if f not in extra_df.columns]
        if missing_features:
            # If some features are missing, fall back to the intersection;
            # if there is no overlap at all, raise an error
            inter = [f for f in features if f in extra_df.columns]
            if len(inter) == 0:
                raise ValueError(f"Extra DataFrame does not contain any of the required features. Required features length: {len(features)}")
            # Note: only the intersecting features are used
            extra_df = extra_df[inter].copy()
        else:
            # Keep only the required features, ordered like `features` (for downstream consistency)
            extra_df = extra_df[list(features)].copy()

        # Standardize the extra data with Normalizer (default 'Standar' scaling, fitted on extra_df)
        normalized_extra = Normalizer(extra_df, ues_fitted=False, methods='Standar')
        # Normalizer may return a DataFrame or a dict; make sure we end up with a DataFrame
        if isinstance(normalized_extra, dict):
            # If a dict is returned, take its first value
            normalized_extra = list(normalized_extra.values())[0]

        # Add the standardized DataFrame to data_dict
        data_dict[extra_key] = normalized_extra

    # 3. Convert each group's features to a tensor and feed it through the VAE encoder
    tensor_data = {}
    for k, v in data_dict.items():
        # v might not be a DataFrame (check defensively)
        if not isinstance(v, pd.DataFrame):
            raise TypeError(f"data_dict[{k}] is not a pandas DataFrame.")
        # Make sure v contains the features (or a subset), taken in `features` order
        inter_feats = [f for f in features if f in v.columns]
        if len(inter_feats) == 0:
            raise ValueError(f"data_dict[{k}] does not contain any of the required features.")
        # If some features are missing, we will only feed the available columns.
        # The VAE model was trained on the full feature set; if a shape mismatch happens,
        # model.encode will raise an error -- the user should ensure feature alignment.
        arr = v[inter_feats].values
        tensor_data[k] = torch.tensor(arr, dtype=torch.float32).to(device)

    # 4. Project every group into the VAE latent space
    latent_z = {}
    for k, x in tensor_data.items():
        mu, logvar = model.encode(x)
        latent_z[k] = model.reparameterize(mu, logvar)
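    # Note: `reparameterize` is assumed to follow the standard VAE sampling
    # z = mu + exp(0.5 * logvar) * eps with eps ~ N(0, I), so the projection is
    # stochastic; encoding the same data twice gives slightly different coordinates
    # (using `mu` directly would make it deterministic).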

    # 5. Merge the latent codes into vis_df with (x, y) coordinates
    dfs = []
    for k, z in latent_z.items():
        df = pd.DataFrame(z.detach().cpu().numpy(),
                          columns=[f'z{i}' for i in range(latent_dim)],
                          index=data_dict[k].index)
        df['Disease Group'] = k
        dfs.append(df)
    vis_df = pd.concat(dfs).rename(columns={'z0': 'x', 'z1': 'y'} if latent_dim >= 2 else {'z0': 'x'})

    # 6. Build the high-confidence generated Control reference
    gen_df = pd.read_hdf(generated_h5, key=h5_key)
    probs = ensemable_classcification(gen_df[features], rt_type='mean_Prob')
    gen_df['Healthy Control Prob'] = probs[healthy_class_name]

    high_conf_ctrl = gen_df[gen_df['Healthy Control Prob'] > prob_threshold].copy()
    real_ctrl = vis_df[vis_df['Disease Group'] == control_group_name]

    # Resolve coordinate column names dynamically (z0/z1 vs x/y)
    coord_cols = ['x', 'y'][:latent_dim]
    for col in coord_cols:
        if col not in high_conf_ctrl.columns:
            high_conf_ctrl[col] = high_conf_ctrl[f'z{coord_cols.index(col)}']

    # Restrict generated controls to the bounding box of the real Controls
    if real_ctrl.empty:
        raise ValueError(f"No real control samples found for control_group_name='{control_group_name}' in vis_df.")

    mask = pd.Series(True, index=high_conf_ctrl.index)
    for col in coord_cols:
        mn, mx = real_ctrl[col].min(), real_ctrl[col].max()
        mask &= high_conf_ctrl[col].between(mn, mx)
    gen_ctrl_filtered = high_conf_ctrl[mask].copy()

    # 7. Joint min-max normalization, fitted on the real data range so that the
    #    generated controls are mapped into the same normalized coordinate space
    scaler = MinMaxScaler().fit(vis_df[coord_cols])
    vis_df[coord_cols] = scaler.transform(vis_df[coord_cols])
    gen_ctrl_filtered[coord_cols] = scaler.transform(gen_ctrl_filtered[coord_cols])

    # 8. Compute the immune score, using the median of the generated Controls as the reference point
    if gen_ctrl_filtered.empty:
        raise ValueError("No generated control samples passed the probability threshold and "
                         "bounding-box filter; consider lowering prob_threshold.")
    refer_point = gen_ctrl_filtered[coord_cols].median()
    vis_df_final = immune_scror_calculate(
        input_data_dict=vis_df,
        features=coord_cols,
        refer=refer_point
    )
    return vis_df_final

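
# Example of programmatic use (illustrative only; "new_cohort.csv" and the
# "NewCohort" key are hypothetical, the other arguments are the script defaults):
#
#   scores = compute_latent_immune_score(
#       data_pkl="dis_data_df.pkl",
#       extra_input="new_cohort.csv",
#       extra_key="NewCohort",
#   )
#   scores.to_parquet("vis_df_with_immune_score.parquet")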
# --------------------- CLI entry point ---------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Compute latent-space immune score using high-confidence generated controls")
    parser.add_argument("--data", default="dis_data_df.pkl", help="Path to real disease data pickle")
    parser.add_argument("--model", default="anneal_betaFalse_anneal_steps20000enc512_128_32_lr0.0001_bs16_beta0.1.pth",
                        help="Path to the trained Beta-VAE checkpoint (.pth)")
    parser.add_argument("--features", default="infection_fea_619.npy", help="Feature list numpy file")
    parser.add_argument("--generated", default="1e6_sampling_data_is_axis.h5", help="Generated samples HDF5")
    parser.add_argument("--threshold", type=float, default=0.90, help="Healthy control probability threshold")
    parser.add_argument("--output", default="vis_df_with_immune_score.parquet", help="Output file path")
    parser.add_argument("--extra", default=None, help="Optional extra DataFrame path (pkl/parquet/csv); omit to skip")
    parser.add_argument("--extra-key", default="Extra", help="Key name to use for the extra DataFrame inside data_dict")

    args = parser.parse_args()

    result_df = compute_latent_immune_score(
        data_pkl=args.data,
        model_path=args.model,
        features_npy=args.features,
        generated_h5=args.generated,
        prob_threshold=args.threshold,
        extra_input=args.extra,
        extra_key=args.extra_key
    )

    result_df.to_parquet(args.output)
    print(f"Success: Immune score computed! Saved to {args.output}")
    print(result_df['immune_score'].describe())
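
# Example invocation (file names are the defaults above; new_cohort.csv is hypothetical):
#   python compute_immune_score.py --data dis_data_df.pkl --extra new_cohort.csv \
#       --extra-key NewCohort --output vis_df_with_immune_score.parquet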