Skip to content

Commit 18ea83f

Browse files
authored
Create reference_genera.py
1 parent 9c438d2 commit 18ea83f

File tree

1 file changed

+148
-0
lines changed

1 file changed

+148
-0
lines changed

DImmunScore/reference_genera.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
import copy
import os
import pickle

import joblib
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn
import torch
from sklearn.preprocessing import MinMaxScaler

from vae_model import BetaVAE

def reference_sample_genarate(file="dis_data_df.pkl",
                              model_file='anneal_betaFalse_anneal_steps20000enc512_128_32_lr0.0001_bs16_beta0.6.pth',
                              featurefile='infection_fea_619.npy',
                              dataaugmentation_file='1e6_sampling_data_is_axis.h5',
                              labelencoder_file='label_encoder.pkl',
                              if_calculate=True,
                              scaler=True):
    """Project disease cohorts into the VAE latent space and build the reference frame.

    Loads a dict of per-disease DataFrames, encodes their immune-feature
    columns with a pre-trained BetaVAE, and assembles a 2-D visualisation
    DataFrame. Also runs the ensemble classifier over the augmented sampling
    grid and attaches a mean "Healthy Control" probability to each position.

    Parameters
    ----------
    file : str
        Pickled dict mapping disease-group name -> DataFrame of samples.
    model_file : str
        Serialized BetaVAE checkpoint (loaded as a full model object).
    featurefile : str
        .npy array of immune feature (column) names fed to the encoder.
    dataaugmentation_file : str
        HDF5 file (key 'df') holding the augmented latent-grid samples.
    labelencoder_file : str
        Joblib-dumped LabelEncoder used to decode classifier outputs.
    if_calculate : bool
        If True, return the min-max-scaled (x, y) frame plus the fitted
        scaler; otherwise return the raw visualisation frame.
    scaler : bool
        Unused; retained for backward compatibility with existing callers.

    Returns
    -------
    (DataFrame, MinMaxScaler) when ``if_calculate`` is True, else DataFrame.

    Raises
    ------
    FileNotFoundError
        If `file` does not exist. (The original only printed a message using
        an undefined ``file_name`` and then crashed later on the unloaded
        ``data_dict``; failing fast here is the fix.)
    """
    if not os.path.exists(file):
        raise FileNotFoundError(
            f"File '{file}' does not exist in the current directory.")

    # SECURITY NOTE: pickle.load and torch.load(weights_only=False) execute
    # arbitrary code on load -- only use with trusted local files.
    with open(file, 'rb') as f:
        data_dict = pickle.load(f)
    print("Dictionary successfully loaded")
    # Keep an untouched copy of the raw frames; data_dict is re-bound to
    # tensors below.
    dis_data_df = copy.deepcopy(data_dict)

    immune_features = np.load(featurefile)

    # Convert each group's immune-feature matrix to a tensor on GPU/CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data_dict = {k: torch.tensor(v[immune_features].values,
                                 dtype=torch.float32).to(device)
                 for k, v in data_dict.items()}
    print("Data successfully moved to:", device)

    model = torch.load(model_file,
                       map_location=device,
                       weights_only=False)
    model.eval()
    # model(v) is assumed to return (recon, mu, logvar); indices 1 and 2 feed
    # the reparameterization -- TODO confirm against BetaVAE.forward.
    vae_results = {k: model(v) for k, v in data_dict.items()}
    latent_var = {k: model.reparameterize(v[1], v[2])
                  for k, v in vae_results.items()}

    # Move latent codes to CPU and convert to NumPy for pandas assembly.
    latent_var = {k: v.detach().cpu().numpy() for k, v in latent_var.items()}
    print("All tensors have been moved to CPU and converted to NumPy arrays.")
    output_is = pd.read_hdf(dataaugmentation_file, key='df')

    # One row per sample: 2-D latent coords joined back onto the raw frame.
    vis_df = pd.concat(
        [
            pd.DataFrame(v, columns=['x', 'y'], index=dis_data_df[k].index)
            .join(dis_data_df[k])
            .assign(Disease_Group=k)
            for k, v in latent_var.items()
        ],
        axis=0)

    vis_df.rename(columns={'Disease_Group': 'Disease Group'}, inplace=True)

    # Normalise cohort labels to the display names used downstream.
    vis_df['Disease Group'] = vis_df['Disease Group'].replace({
        'COVID': 'COVID-19',
        'Control': 'Healthy Control',
        'TB': 'Tuberculosis',
        'Candida': 'Fungus',
        'HIV': 'AIDS'
    })

    # Last three columns of the augmented grid are positional metadata; the
    # rest are classifier inputs -- presumably; verify against the HDF5 schema.
    output_is_for_cls = output_is[output_is.columns[:-3]]
    svm_result, log_reg_result, xgb_result = ensemable_classcification(
        output_is_for_cls, rt_type='Class')

    models = ['SVM', 'Logistic', 'XGBoost']
    results = [svm_result, log_reg_result, xgb_result]
    label_encoder = joblib.load(labelencoder_file)

    position_labeled_df = output_is[output_is.columns[-3:]].assign(
        **{f"{m} Result": label_encoder.inverse_transform(r)
           for m, r in zip(models, results)}
    )
    # NOTE(review): the original also computed an agreement filter
    # (rows where all three "* Result" columns matched) into an unused local;
    # that dead computation has been removed.

    # Mean predicted probability of "Healthy Control" across the ensemble.
    position_labeled_df['mean_result'] = ensemable_classcification(
        output_is_for_cls, rt_type='mean_Prob')['Healthy Control']

    mn_scaler = MinMaxScaler()
    # .copy() prevents chained-assignment on a slice of vis_df.
    vis_df_2 = vis_df[['x', 'y', 'Disease Group']].copy()
    vis_df_2[['x', 'y']] = mn_scaler.fit_transform(vis_df_2[['x', 'y']])

    if if_calculate:
        return vis_df_2, mn_scaler
    else:
        return vis_df
def generate_control_samples(
        latentspace_df,
        trainingdata_df,
        nm_scaler,
        control_threshold=0.95,
        disease_group_col="Disease Group",
        control_group="Healthy Control",
        coord_cols=None):
    """Select and rescale standardized control-group samples.

    Filters `latentspace_df` to rows whose `control_group` probability exceeds
    `control_threshold`, restricts them to the (x, y) bounding box spanned by
    the control group in `trainingdata_df`, and min-max scales the surviving
    coordinates.

    Parameters
    ----------
    latentspace_df : DataFrame
        Candidate samples; must contain the `control_group` probability
        column(s) plus the coordinate columns.
    trainingdata_df : DataFrame
        Visualisation frame with `disease_group_col`, 'x' and 'y' columns.
    nm_scaler : object
        Scaler exposing ``fit_transform`` (e.g. MinMaxScaler).
        NOTE(review): the scaler is re-fit on the filtered subset here, which
        clobbers any fit the caller performed -- confirm ``transform`` was
        not intended instead.
    control_threshold : float
        Minimum probability for a row to count as control.
    disease_group_col : str
        Column holding the disease-group label in `trainingdata_df`.
    control_group : str
        Label / column name identifying the control group.
    coord_cols : list of str, optional
        Coordinate column names; defaults to ['x', 'y'].

    Returns
    -------
    DataFrame
        Filtered, rescaled control samples; empty DataFrame on error.
    """
    # Avoid a mutable default argument while preserving the original default.
    if coord_cols is None:
        coord_cols = ['x', 'y']
    scaler = nm_scaler

    try:
        # Probability filter. `latentspace_df[control_group]` is a Series for
        # a single column, and Series.gt(...).any(axis=1) raised ValueError
        # in the original (silently swallowed below, so the function always
        # returned an empty frame). Handle both Series and DataFrame.
        scores = latentspace_df[control_group]
        if isinstance(scores, pd.DataFrame):
            control_mask = scores.gt(control_threshold).any(axis=1)
        else:
            control_mask = scores.gt(control_threshold)
        control_df = latentspace_df.loc[control_mask].copy()

        # Bounding box of the control group in the training visualisation.
        control_vis = trainingdata_df[
            trainingdata_df[disease_group_col] == control_group]
        ranges = {
            'x': (control_vis['x'].min(), control_vis['x'].max()),
            'y': (control_vis['y'].min(), control_vis['y'].max())
        }

        # Keep only candidates inside the control bounding box.
        spatial_filter = (
            control_df['x'].between(*ranges['x']) &
            control_df['y'].between(*ranges['y'])
        )
        # .copy() so the .loc assignment below never targets a slice view.
        filtered_df = control_df.loc[spatial_filter].copy()

        # Rescale coordinates of the survivors (skipped on an empty frame,
        # since fit_transform raises on zero samples).
        if not filtered_df.empty:
            filtered_df.loc[:, coord_cols] = scaler.fit_transform(
                filtered_df[coord_cols])

        return filtered_df

    except KeyError as e:
        # Missing-column error (message kept verbatim for log compatibility).
        print(f"列不存在错误: {str(e)}")
        return pd.DataFrame()
    except Exception as e:
        # Best-effort catch-all (message kept verbatim for log compatibility).
        print(f"处理失败: {str(e)}")
        return pd.DataFrame()

0 commit comments

Comments
 (0)