
Commit 4e2cd06

authored
Add files via upload
1 parent 08c82f6 commit 4e2cd06

File tree

1 file changed: +214 −0 lines changed


run_immudef.py

Lines changed: 214 additions & 0 deletions
@@ -0,0 +1,214 @@
# compute_immune_score.py
import argparse
import os
import pickle
from typing import Optional, Union

import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import MinMaxScaler

# BetaVAE is kept in scope so torch.load can resolve the pickled model class
from QImmuDef_VAE import BetaVAE
from ensemble_learning import ensemable_classcification, immune_scror_calculate
from feature_selection_methods.normalizmethods import Normalizer

def _load_external_df(maybe_path: Optional[Union[str, pd.DataFrame]]) -> Optional[pd.DataFrame]:
    """
    Load an external DataFrame. If the input is already a DataFrame, return a copy.
    If it is a path string, try common loaders in order (pickle, parquet, csv,
    then a generic pickle load). Return None for None input; raise ValueError
    on failure.
    """
    if maybe_path is None:
        return None
    if isinstance(maybe_path, pd.DataFrame):
        return maybe_path.copy()
    if not isinstance(maybe_path, str):
        raise ValueError("extra input must be a pandas DataFrame or a path string.")
    if not os.path.exists(maybe_path):
        raise FileNotFoundError(f"Extra input path not found: {maybe_path}")
    # try common formats in order: pd.read_pickle first (for .pkl)
    try:
        return pd.read_pickle(maybe_path)
    except Exception:
        pass
    # parquet
    try:
        return pd.read_parquet(maybe_path)
    except Exception:
        pass
    # csv (first column is taken as the sample index)
    try:
        return pd.read_csv(maybe_path, index_col=0)
    except Exception:
        pass

    # last resort: generic pickle load, accepted only if it yields a DataFrame
    try:
        with open(maybe_path, 'rb') as f:
            obj = pickle.load(f)
        if isinstance(obj, pd.DataFrame):
            return obj
    except Exception:
        pass

    raise ValueError(f"Unable to load extra DataFrame from path: {maybe_path}")

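# Hypothetical usage of the loader's fallback chain (file names are examples only;
# existing_df stands for any DataFrame already in memory):
#   df = _load_external_df("extra_cohort.pkl")      # tried as a pandas pickle first
#   df = _load_external_df("extra_cohort.parquet")  # then parquet
#   df = _load_external_df("extra_cohort.csv")      # then csv with index_col=0
#   df = _load_external_df(existing_df)             # DataFrames come back as a copy
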
def compute_latent_immune_score(
    data_pkl: str = "dis_data_df.pkl",
    model_path: str = "anneal_betaFalse_anneal_steps20000enc512_128_32_lr0.0001_bs16_beta0.1.pth",
    features_npy: str = "infection_fea_619.npy",
    generated_h5: str = "1e6_sampling_data_is_axis.h5",
    h5_key: str = "df",
    prob_threshold: float = 0.90,
    latent_dim: int = 2,
    control_group_name: str = "Control",
    healthy_class_name: str = "Healthy Control",
    device: Optional[str] = None,
    extra_input: Optional[Union[str, pd.DataFrame]] = None,
    extra_key: str = "Extra"
) -> pd.DataFrame:
    """
    One-call computation of the latent-space immune score, referenced against
    high-confidence generated Controls.

    Optionally accepts one extra DataFrame (or a file path to one). The extra data
    is first standardized via Normalizer (the scaler is fitted on the extra data
    itself) and added to data_dict under extra_key, so it flows through the rest
    of the pipeline.

    Returns: vis_df_final (with columns x, y, Disease Group, immune_score).
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device(device)

    # 1. Load the feature list and the trained VAE
    features = np.load(features_npy)
    model = torch.load(model_path, map_location=device, weights_only=False).eval()

    # 2. Load the real (disease) data
    with open(data_pkl, 'rb') as f:
        data_dict = pickle.load(f)  # expected dict[str, pd.DataFrame]

    # 2.1 If extra data was provided, load and standardize it, then add it to data_dict
    if extra_input is not None:
        extra_df = _load_external_df(extra_input)
        if extra_df is None:
            raise ValueError("Failed to load extra input DataFrame.")
        # Make sure extra_df contains at least some of the required feature columns
        missing_features = [f for f in features if f not in extra_df.columns]
        if missing_features:
            # Some features are missing: fall back to the intersection;
            # if the intersection is empty, give up.
            inter = [f for f in features if f in extra_df.columns]
            if len(inter) == 0:
                raise ValueError(f"Extra DataFrame does not contain any of the required features. Required features length: {len(features)}")
            # Note: only the intersecting features are used
            extra_df = extra_df[inter].copy()
        else:
            # Keep only the required features, ordered as in `features` for consistency
            extra_df = extra_df[list(features)].copy()

        # Standardize the extra data with the project's Normalizer (fit on extra_df);
        # the argument spellings ues_fitted / 'Standar' follow that module's API
        normalized_extra = Normalizer(extra_df, ues_fitted=False, methods='Standar')
        # Normalizer may return a DataFrame or a dict; make sure we end up with a DataFrame
        if isinstance(normalized_extra, dict):
            # If a dict is returned, take its first value
            normalized_extra = list(normalized_extra.values())[0]

        # Add the standardized DataFrame to data_dict
        data_dict[extra_key] = normalized_extra

    # 3. Convert each group's features to a tensor for VAE encoding
    tensor_data = {}
    for k, v in data_dict.items():
        # v may not be a DataFrame, so check defensively
        if not isinstance(v, pd.DataFrame):
            raise TypeError(f"data_dict[{k}] is not a pandas DataFrame.")
        # Make sure v contains the features (or a subset), taken in `features` order
        inter_feats = [f for f in features if f in v.columns]
        if len(inter_feats) == 0:
            raise ValueError(f"data_dict[{k}] does not contain any of the required features.")
        # If some features are missing, only the available columns are fed in.
        # The VAE was trained on the full feature set; on a shape mismatch,
        # model.encode will raise, so the caller must ensure feature alignment.
        arr = v[inter_feats].values
        tensor_data[k] = torch.tensor(arr, dtype=torch.float32).to(device)

    # 4. Project into the VAE latent space (no gradients needed at inference)
    latent_z = {}
    with torch.no_grad():
        for k, x in tensor_data.items():
            mu, logvar = model.encode(x)
            latent_z[k] = model.reparameterize(mu, logvar)
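    # Note: reparameterize presumably draws z = mu + exp(0.5 * logvar) * eps with
    # eps ~ N(0, I) (the usual VAE trick), so this projection is stochastic.
    # Encoding with mu alone would give a deterministic embedding instead.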
    # 5. Merge into vis_df (x, y)
    dfs = []
    for k, z in latent_z.items():
        df = pd.DataFrame(z.detach().cpu().numpy(),
                          columns=[f'z{i}' for i in range(latent_dim)],
                          index=data_dict[k].index)
        df['Disease Group'] = k
        dfs.append(df)
    vis_df = pd.concat(dfs).rename(columns={'z0': 'x', 'z1': 'y'} if latent_dim >= 2 else {'z0': 'x'})

    # 6. Build the high-confidence generated Control reference
    gen_df = pd.read_hdf(generated_h5, key=h5_key)
    probs = ensemable_classcification(gen_df[features], rt_type='mean_Prob')
    gen_df['Healthy Control Prob'] = probs[healthy_class_name]

    high_conf_ctrl = gen_df[gen_df['Healthy Control Prob'] > prob_threshold].copy()
    real_ctrl = vis_df[vis_df['Disease Group'] == control_group_name]

    # Coordinate columns: only z0/z1 are ever renamed above, so cap at two
    coord_cols = ['x', 'y'][:latent_dim]
    for col in coord_cols:
        if col not in high_conf_ctrl.columns:
            high_conf_ctrl[col] = high_conf_ctrl[f'z{coord_cols.index(col)}']

    # Clip to the bounding box of the real Controls (fails if real_ctrl is empty)
    if real_ctrl.empty:
        raise ValueError(f"No real control samples found for control_group_name='{control_group_name}' in vis_df.")

    mask = pd.Series(True, index=high_conf_ctrl.index)
    for col in coord_cols:
        mn, mx = real_ctrl[col].min(), real_ctrl[col].max()
        mask &= high_conf_ctrl[col].between(mn, mx)
    gen_ctrl_filtered = high_conf_ctrl[mask].copy()

    # 7. Shared normalization (scaled to the range of the real data)
    scaler = MinMaxScaler().fit(vis_df[coord_cols])
    vis_df[coord_cols] = scaler.transform(vis_df[coord_cols])
    gen_ctrl_filtered[coord_cols] = scaler.transform(gen_ctrl_filtered[coord_cols])

    # 8. Compute the immune score (reference point = median of the generated Controls)
    refer_point = gen_ctrl_filtered[coord_cols].median()
    vis_df_final = immune_scror_calculate(
        input_data_dict=vis_df,
        features=coord_cols,
        refer=refer_point
    )
    return vis_df_final

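# Minimal programmatic usage sketch (the cohort file and key name are hypothetical;
# the keyword defaults above point at the shipped artifacts):
#
#   extra = pd.read_csv("new_cohort.csv", index_col=0)
#   scores = compute_latent_immune_score(extra_input=extra, extra_key="NewCohort")
#   print(scores.groupby("Disease Group")["immune_score"].median())
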
# --------------------- Command-line entry point ---------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Compute latent-space immune score using high-confidence generated controls")
    parser.add_argument("--data", default="dis_data_df.pkl", help="Path to real disease data pickle")
    parser.add_argument("--model", default="anneal_betaFalse_anneal_steps20000enc512_128_32_lr0.0001_bs16_beta0.1.pth")
    parser.add_argument("--features", default="infection_fea_619.npy", help="Feature list numpy file")
    parser.add_argument("--generated", default="1e6_sampling_data_is_axis.h5", help="Generated samples HDF5")
    parser.add_argument("--threshold", type=float, default=0.90, help="Healthy control probability threshold")
    parser.add_argument("--output", default="vis_df_with_immune_score.parquet", help="Output file path")
    parser.add_argument("--extra", default=None, help="Optional extra DataFrame path (pkl/parquet/csv); omit to skip")
    parser.add_argument("--extra-key", default="Extra", help="Key name for the extra DataFrame inside data_dict")

    args = parser.parse_args()

    result_df = compute_latent_immune_score(
        data_pkl=args.data,
        model_path=args.model,
        features_npy=args.features,
        generated_h5=args.generated,
        prob_threshold=args.threshold,
        extra_input=args.extra,
        extra_key=args.extra_key
    )

    result_df.to_parquet(args.output)
    print(f"Success: Immune score computed! Saved to {args.output}")
    print(result_df['immune_score'].describe())
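# Example invocation (the extra cohort file name is hypothetical; other values
# match the defaults above):
#   python run_immudef.py --extra new_cohort.csv --extra-key NewCohort \
#       --output vis_df_with_immune_score.parquet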
