-
Notifications
You must be signed in to change notification settings - Fork 223
[Example] Add Chem Suzuki-Miyaura 交叉偶联反应产率预测 #1175
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from 14 commits
edc3512
e2ab6ec
6356c91
e8ab68c
6ee370c
66c0957
fa853a2
9a36d81
865689a
eaf1054
9b67577
d33f8bf
62d18cf
a4f882a
41e9f8b
7f8ddce
aa7d727
0f3d15c
6c098d9
e53498f
2b0e527
13af56a
38b1a8b
74da604
1e2555e
cf45ac5
d30b1d8
06b4c44
0482f70
04ae02d
0759af8
a0a3995
3b16af5
2743241
8126b2a
150d12c
daac3d2
16aae51
349bf3b
904a25c
3370e1d
4a79ffd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,142 @@ | ||||||||||||||||
# Suzuki-Miyaura 交叉偶联反应产率预测 | ||||||||||||||||
|
||||||||||||||||
!!! note | ||||||||||||||||
|
||||||||||||||||
1. 开始训练、评估前,数据文件data_set.xlsx的存在,并对应修改 yaml 配置文件中的 `data_dir` 为数据文件路径。 | ||||||||||||||||
2. 如果需要使用预训练模型进行评估,请先下载预训练模型[chem_model.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/TADF/Est/Est_pretrained.pdparams), 并对应修改 yaml 配置文件中的 `load_model_path` 为模型参数路径。 | ||||||||||||||||
3. 开始训练、评估前,请安装 `rdkit` 等,相关依赖请执行`pip install -r requirements.txt`安装。 | ||||||||||||||||
|
||||||||||||||||
=== "模型训练命令" | ||||||||||||||||
|
||||||||||||||||
``` sh | ||||||||||||||||
# 训练: | ||||||||||||||||
python chem.py mode=train | ||||||||||||||||
``` | ||||||||||||||||
|
||||||||||||||||
=== "模型评估命令" | ||||||||||||||||
|
||||||||||||||||
``` sh | ||||||||||||||||
# 评估: | ||||||||||||||||
python chem.py mode=eval | ||||||||||||||||
``` | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||
|
||||||||||||||||
## 1. 背景简介 | ||||||||||||||||
|
||||||||||||||||
Suzuki-Miyaura 交叉偶联反应表达式如下所示。 | ||||||||||||||||
|
||||||||||||||||
$$ | ||||||||||||||||
\mathrm{Ar{-}X} + \mathrm{Ar'{-}B(OH)_2} \xrightarrow[\text{Base}]{\mathrm{Pd}^0} \mathrm{Ar{-}Ar'} + \mathrm{HX} | ||||||||||||||||
$$ | ||||||||||||||||
|
||||||||||||||||
在零价钯配合物催化下,芳基或烯基硼酸或硼酸酯与氯、溴、碘代芳烃或烯烃发生交叉偶联。该反应具有反应条件温和、转化率高的优点,在材料合成、药物研发等领域具有重要作用,但存在开发周期长,试错成本高的问题。本研究通过使用高通量实验数据分析反应底物(包括亲电试剂和亲核试剂),催化配体,碱基,溶剂对偶联反应产率的影响,从而建立预测模型。 | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
## 2. Suzuki-Miyaura 交叉偶联反应产率预测模型的实现 | ||||||||||||||||
|
||||||||||||||||
本节将讲解如何基于PaddleScience代码,实现对于 Suzuki-Miyaura 交叉偶联反应产率预测模型的构建、训练、测试和评估。案例的目录结构如下。 | ||||||||||||||||
``` log | ||||||||||||||||
chem/ | ||||||||||||||||
├──config/ | ||||||||||||||||
│ └── chem.yaml | ||||||||||||||||
├── chem.py | ||||||||||||||||
├── data_set.xlsx | ||||||||||||||||
└── requirements.txt | ||||||||||||||||
``` | ||||||||||||||||
|
||||||||||||||||
### 2.1 数据集构建和载入 | ||||||||||||||||
|
||||||||||||||||
本样例使用的数据来自参考文献[1]提供的开源数据,仅考虑试剂本身对于实验结果的影响,从中筛选了各分量均有试剂参与的部分反应数据,保存在文件 `./data_set.xlsx` 中。该工作开发了一套基于流动化学(flow chemistry)的自动化平台,该平台在氩气保护的手套箱中组装,使用改良的高效液相色谱(HPLC)系统,结合自动化取样装置,从192个储液瓶中按设定程序吸取反应组分(亲电试剂、亲核试剂、催化剂、配体和碱),并注入流动载液中。每个反应段在温控反应盘管中以设定的流速、压力、时间进行反应,反应液通过UPLC-MS进行实时检测。通过调控亲电试剂、亲核试剂、11种配体、7种碱和4种溶剂的组合,最终实现了5760个反应条件的系统性筛选。接下来以其中一条数据为例结合代码说明数据集的构建与载入流程。 | ||||||||||||||||
|
||||||||||||||||
``` | ||||||||||||||||
ClC=1C=C2C=CC=NC2=CC1 | CC=1C(=C2C=NN(C2=CC1)C1OCCCC1)B(O)O | C(C)(C)(C)P(C(C)(C)C)C(C)(C)C | [OH-].[Na+] | C(C)#N | 4.76 | ||||||||||||||||
``` | ||||||||||||||||
其中用SMILES依次表示亲电试剂、亲核试剂、催化配体、碱、溶剂和实验产率。 | ||||||||||||||||
|
||||||||||||||||
首先从表格文件中将实验材料信息和反应产率进行导入,并划分训练集和测试集, | ||||||||||||||||
|
||||||||||||||||
``` py linenums="27" title="examples/chem/chem.py" | ||||||||||||||||
--8<-- | ||||||||||||||||
examples/chem/chem.py:27:35 | ||||||||||||||||
--8<-- | ||||||||||||||||
``` | ||||||||||||||||
|
||||||||||||||||
应用 `rdkit.Chem.rdFingerprintGenerator` 将亲电试剂、亲核试剂、催化配体、碱和溶剂的SMILES描述转换为 Morgan 指纹。Morgan指纹是一种分子结构的向量化描述,通过局部拓扑被编码为 hash 值,映射到2048位指纹位上。用 PaddleScience 代码表示如下 | ||||||||||||||||
|
||||||||||||||||
``` py linenums="38" title="examples/chem/chem.py" | ||||||||||||||||
--8<-- | ||||||||||||||||
examples/chem/chem.py:38:66 | ||||||||||||||||
--8<-- | ||||||||||||||||
``` | ||||||||||||||||
|
||||||||||||||||
### 2.2 约束构建 | ||||||||||||||||
|
||||||||||||||||
本案例采用监督学习,按照 PaddleScience 的API结构说明,采用内置的 `SupervisedConstraint` 构建监督约束。用 PaddleScience 代码表示如下 | ||||||||||||||||
|
||||||||||||||||
``` py linenums="73" title="examples/chem/chem.py" | ||||||||||||||||
--8<-- | ||||||||||||||||
examples/chem/chem.py:73:89 | ||||||||||||||||
--8<-- | ||||||||||||||||
``` | ||||||||||||||||
`SupervisedConstraint` 的第二个参数表示采用均方误差 `MSELoss` 作为损失函数,第三个参数表示约束条件的名字,方便后续对其索引。 | ||||||||||||||||
|
||||||||||||||||
### 2.3 模型构建 | ||||||||||||||||
|
||||||||||||||||
本案例设计了五条独立的子网络(全连接层+ReLU激活),每条子网络分别提取对应化学物质的特征。随后,这五个特征向量通过可训练的权重参数进行加权平均,实现不同化学成分对反应产率预测影响的自适应学习。最后,将融合后的特征输入到一个全连接层进行进一步映射,输出反应产率预测值。整个网络结构体现了对反应中各组成成分信息的独立提取与有权重的融合,符合反应机理特性。用 PaddleScience 代码表示如下 | ||||||||||||||||
|
||||||||||||||||
``` py linenums="7" title="ppsci/arch/chem.py" | ||||||||||||||||
--8<-- | ||||||||||||||||
ppsci/arch/chem.py:7:107 | ||||||||||||||||
--8<-- | ||||||||||||||||
``` | ||||||||||||||||
|
||||||||||||||||
模型依据配置文件信息进行实例化 | ||||||||||||||||
|
||||||||||||||||
``` py linenums="91" title="examples/chem/chem.py" | ||||||||||||||||
--8<-- | ||||||||||||||||
examples/chem/chem.py:91:91 | ||||||||||||||||
--8<-- | ||||||||||||||||
``` | ||||||||||||||||
|
||||||||||||||||
参数通过配置文件进行设置如下 | ||||||||||||||||
|
||||||||||||||||
``` py linenums="35" title="examples/chem/config/chem.yaml" | ||||||||||||||||
--8<-- | ||||||||||||||||
examples/chem/config/chem.yaml:35:41 | ||||||||||||||||
--8<-- | ||||||||||||||||
``` | ||||||||||||||||
|
||||||||||||||||
### 2.4 优化器构建 | ||||||||||||||||
|
||||||||||||||||
训练器采用Adam优化器,学习率设置由配置文件给出。用 PaddleScience 代码表示如下 | ||||||||||||||||
|
||||||||||||||||
``` py linenums="93" title="examples/chem/chem.py" | ||||||||||||||||
--8<-- | ||||||||||||||||
examples/chem/chem.py:93:93 | ||||||||||||||||
--8<-- | ||||||||||||||||
``` | ||||||||||||||||
|
||||||||||||||||
### 2.5 模型训练 | ||||||||||||||||
|
||||||||||||||||
完成上述设置之后,只需要将上述实例化的对象按顺序传递给`ppsci.solver.Solver`,然后启动训练即可。用PaddleScience 代码表示如下 | ||||||||||||||||
|
||||||||||||||||
``` py linenums="95" title="examples/chem/chem.py" | ||||||||||||||||
--8<-- | ||||||||||||||||
examples/chem/chem.py:95:105 | ||||||||||||||||
--8<-- | ||||||||||||||||
``` | ||||||||||||||||
|
||||||||||||||||
## 3. 完整代码 | ||||||||||||||||
|
||||||||||||||||
``` py linenums="1" title="examples/chem/chem.py" | ||||||||||||||||
--8<-- | ||||||||||||||||
examples/chem/chem.py | ||||||||||||||||
--8<-- | ||||||||||||||||
``` | ||||||||||||||||
|
||||||||||||||||
## 4. 结果展示 | ||||||||||||||||
|
||||||||||||||||
下图展示对 Suzuki-Miyaura 交叉偶联反应产率的模型预测结果。 | ||||||||||||||||
|
||||||||||||||||
## 5. 参考文献 | ||||||||||||||||
|
||||||||||||||||
[1] Perera D, Tucker J W, Brahmbhatt S, et al. A platform for automated nanomole-scale reaction screening and micromole-scale synthesis in flow[J]. Science, 2018, 359(6374): 429-434. |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,157 @@ | ||||||
import os | ||||||
|
||||||
import hydra | ||||||
import matplotlib.pyplot as plt | ||||||
import numpy as np | ||||||
import paddle | ||||||
import pandas as pd | ||||||
import rdkit.Chem as Chem | ||||||
from omegaconf import DictConfig | ||||||
from rdkit.Chem import rdFingerprintGenerator | ||||||
from sklearn.metrics import r2_score | ||||||
from sklearn.model_selection import train_test_split | ||||||
|
||||||
import ppsci | ||||||
|
||||||
os.environ["HYDRA_FULL_ERROR"] = "1" | ||||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "True" | ||||||
plt.rcParams["axes.unicode_minus"] = False | ||||||
plt.rcParams["font.sans-serif"] = ["DejaVu Sans"] | ||||||
|
||||||
x_train = None | ||||||
x_test = None | ||||||
y_train = None | ||||||
y_test = None | ||||||
|
||||||
|
||||||
def load_data(cfg: DictConfig): | ||||||
data_dir = cfg.data_dir | ||||||
dataset = pd.read_excel(data_dir, skiprows=1) | ||||||
x = dataset.iloc[:, 1:6] | ||||||
y = dataset.iloc[:, 6] | ||||||
x_train, x_test, y_train, y_test = train_test_split( | ||||||
x, y, test_size=0.2, random_state=42 | ||||||
) | ||||||
return x_train, x_test, y_train, y_test | ||||||
|
||||||
|
||||||
def data_processed(x, y): | ||||||
x = build_dataset(x) | ||||||
y = paddle.to_tensor(y.to_numpy(dtype=np.float32)) | ||||||
y = paddle.unsqueeze(y, axis=1) | ||||||
return x, y | ||||||
|
||||||
|
||||||
def build_dataset(data): | ||||||
r1 = paddle.to_tensor(np.array(cal_print(data.iloc[:, 0])), dtype=paddle.float32) | ||||||
r2 = paddle.to_tensor(np.array(cal_print(data.iloc[:, 1])), dtype=paddle.float32) | ||||||
ligand = paddle.to_tensor( | ||||||
np.array(cal_print(data.iloc[:, 2])), dtype=paddle.float32 | ||||||
) | ||||||
base = paddle.to_tensor(np.array(cal_print(data.iloc[:, 3])), dtype=paddle.float32) | ||||||
solvent = paddle.to_tensor( | ||||||
np.array(cal_print(data.iloc[:, 4])), dtype=paddle.float32 | ||||||
) | ||||||
return paddle.concat([r1, r2, ligand, base, solvent], axis=1) | ||||||
|
||||||
|
||||||
def cal_print(smiles): | ||||||
vectors = [] | ||||||
for smi in smiles: | ||||||
mol = Chem.MolFromSmiles(smi) | ||||||
generator = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048) | ||||||
fp = generator.GetFingerprint(mol) | ||||||
_input = np.array(list(map(float, fp.ToBitString()))) | ||||||
vectors.append(_input) | ||||||
return vectors | ||||||
|
||||||
|
||||||
def train(cfg: DictConfig): | ||||||
global x_train, y_train | ||||||
x_train, y_train = data_processed(x_train, y_train) | ||||||
|
||||||
# build supervised constraint | ||||||
sup = ppsci.constraint.SupervisedConstraint( | ||||||
dataloader_cfg={ | ||||||
"dataset": { | ||||||
"input": {"v": x_train}, | ||||||
"label": {"u": y_train}, | ||||||
# "weight": {"W": param}, | ||||||
"name": "IterableNamedArrayDataset", | ||||||
}, | ||||||
"batch_size": cfg.TRAIN.batch_size, | ||||||
}, | ||||||
loss=ppsci.loss.MSELoss("mean"), | ||||||
name="sup", | ||||||
) | ||||||
constraint = { | ||||||
"sup": sup, | ||||||
} | ||||||
|
||||||
model = ppsci.arch.ChemMultimodalMLP(**cfg.MODEL) | ||||||
|
||||||
optimizer = ppsci.optimizer.optimizer.Adam(cfg.TRAIN.learning_rate)(model) | ||||||
|
||||||
# Build solver | ||||||
solver = ppsci.solver.Solver( | ||||||
model, | ||||||
constraint=constraint, | ||||||
optimizer=optimizer, | ||||||
epochs=cfg.TRAIN.epochs, | ||||||
eval_during_train=False, | ||||||
iters_per_epoch=cfg.TRAIN.iters_per_epoch, | ||||||
cfg=cfg, | ||||||
) | ||||||
solver.train() | ||||||
|
||||||
|
||||||
def eval(cfg: DictConfig): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
global x_test, y_test | ||||||
x_test, y_test = data_processed(x_test, y_test) | ||||||
# Reformat data for evaluation | ||||||
x_test = {"v": x_test} | ||||||
y_test = {"u": y_test} | ||||||
model = ppsci.arch.ChemMultimodalMLP(**cfg.MODEL) | ||||||
model.set_state_dict(paddle.load(cfg.EVAL.load_model_path)) | ||||||
ypred = model(x_test) | ||||||
|
||||||
# Calculate evaluation metrics | ||||||
loss = ppsci.metric.MAE() | ||||||
MAE = loss(ypred, y_test).get("u").numpy() | ||||||
loss = ppsci.metric.RMSE() | ||||||
RMSE = loss(ypred, y_test).get("u").numpy() | ||||||
ypred = ypred.get("u").numpy() | ||||||
ytest = y_test.get("u").numpy() | ||||||
R2 = r2_score(ytest, ypred) | ||||||
print("MAE", MAE) | ||||||
print("RMSE", RMSE) | ||||||
print("R2", R2) | ||||||
|
||||||
# Visualization | ||||||
plt.scatter(ytest, ypred, s=15, color="royalblue", marker="s", linewidth=1) | ||||||
plt.plot([ytest.min(), ytest.max()], [ytest.min(), ytest.max()], "r-", lw=1) | ||||||
plt.legend(title="R²={:.3f}\n\nMAE={:.3f}".format(R2, MAE)) | ||||||
plt.xlabel("Test Yield(%)") | ||||||
plt.ylabel("Predicted Yield(%)") | ||||||
save_path = "chem.png" | ||||||
plt.savefig(save_path) | ||||||
print(f"Iamge saved to: {save_path}") | ||||||
plt.show() | ||||||
|
||||||
|
||||||
@hydra.main(version_base=None, config_path="./config", config_name="chem.yaml") | ||||||
def main(cfg: DictConfig): | ||||||
global x_train, x_test, y_train, y_test | ||||||
|
||||||
x_train, x_test, y_train, y_test = load_data(cfg) | ||||||
|
||||||
if cfg.mode == "train": | ||||||
train(cfg) | ||||||
elif cfg.mode == "eval": | ||||||
eval(cfg) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 函数名已修改 |
||||||
else: | ||||||
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") | ||||||
|
||||||
|
||||||
if __name__ == "__main__": | ||||||
main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
defaults: | ||
- ppsci_default | ||
- TRAIN: train_default | ||
- TRAIN/ema: ema_default | ||
- TRAIN/swa: swa_default | ||
- EVAL: eval_default | ||
- INFER: infer_default | ||
- hydra/job/config/override_dirname/exclude_keys: exclude_keys_default | ||
- _self_ | ||
|
||
hydra: | ||
run: | ||
# dynamic output directory according to running time and override name | ||
dir: ./outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} # | ||
job: | ||
name: ${mode} # name of logfile | ||
chdir: false # keep current working directory unchanged | ||
callbacks: | ||
init_callback: # | ||
_target_: ppsci.utils.callbacks.InitCallback # | ||
sweep: | ||
# output directory for multirun | ||
dir: ${hydra.run.dir} | ||
subdir: ./ | ||
|
||
# general settings | ||
mode: train # running mode: train/eval # | ||
seed: 42 # | ||
output_dir: ${hydra:run.dir} # | ||
log_freq: 20 # | ||
use_tbd: false # | ||
data_dir: "./data_set.xlsx" # | ||
|
||
# model settings | ||
MODEL: # | ||
input_dim : 2048 # Assuming x_train is your DataFrame | ||
output_dim : 1 | ||
hidden_dim : 512 | ||
hidden_dim2 : 1024 | ||
hidden_dim3 : 2048 | ||
hidden_dim4 : 1024 | ||
|
||
# training settings | ||
TRAIN: # | ||
epochs: 1500 # | ||
iters_per_epoch: 20 # | ||
# save_freq: 100 # | ||
# eval_during_train: False # | ||
batch_size: 8 # | ||
learning_rate: 0.0001 | ||
save_model_path: './chem_model.pdparams' | ||
# weight_decay: 1e-5 | ||
# pretrained_model_path: null # | ||
# checkpoint_path: null # | ||
# k: 9 | ||
# i: 2 | ||
|
||
# evaluation settings | ||
EVAL: | ||
test_size: 0.1 | ||
load_model_path: './chem_model.pdparams' | ||
seed: 20 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
openpyxl | ||
rdkit | ||
scikit-learn |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.