diff --git a/NAG2G/data/graphormer_dataset.py b/NAG2G/data/graphormer_dataset.py index 8eb6382..11cde19 100644 --- a/NAG2G/data/graphormer_dataset.py +++ b/NAG2G/data/graphormer_dataset.py @@ -99,6 +99,20 @@ def __init__(self, product_dataset, reactant_dataset, align_base="product"): self.reactant_dataset = reactant_dataset self.align_base = align_base self.set_epoch(None) + + def get_list(atoms_map_product, atoms_map_reactant): + atoms_map_reactant_dict = { + atoms_map_reactant[i]: i for i in range(len(atoms_map_reactant)) + } + + tmp = [atoms_map_reactant_dict[i] for i in atoms_map_product if i in atoms_map_reactant_dict] + + all_indices = set(range(len(atoms_map_reactant))) + missing_indices = list(all_indices - set(tmp)) + + list_reactant = tmp + missing_indices + + return list_reactant def get_list(self, atoms_map_product, atoms_map_reactant): if self.align_base == "reactant": @@ -111,13 +125,25 @@ def get_list(self, atoms_map_product, atoms_map_reactant): list_product = [atoms_map_product_dict[i] for i in atoms_map_reactant[mask]] elif self.align_base == "product": list_product = None + # atoms_map_reactant_dict = { + # atoms_map_reactant[i]: i for i in range(len(atoms_map_reactant)) + # } + # tmp = [atoms_map_reactant_dict[i] for i in atoms_map_product] + # orders = np.array([i for i in range(len(atoms_map_reactant))]) + # mask = atoms_map_reactant != 0 + # list_reactant = np.concatenate([tmp, orders[~mask]], 0) + atoms_map_reactant_dict = { - atoms_map_reactant[i]: i for i in range(len(atoms_map_reactant)) + atoms_map_reactant[i]: i for i in range(len(atoms_map_reactant)) } - tmp = [atoms_map_reactant_dict[i] for i in atoms_map_product] - orders = np.array([i for i in range(len(atoms_map_reactant))]) - mask = atoms_map_reactant != 0 - list_reactant = np.concatenate([tmp, orders[~mask]], 0) + + tmp = [atoms_map_reactant_dict[i] for i in atoms_map_product if i in atoms_map_reactant_dict] + + all_indices = set(range(len(atoms_map_reactant))) + missing_indices = list(all_indices - set(tmp)) + + list_reactant = tmp + missing_indices + else: raise return list_product, list_reactant diff --git a/NAG2G/data/random_smiles_dataset.py b/NAG2G/data/random_smiles_dataset.py index 5d66ef2..6bf356f 100755 --- a/NAG2G/data/random_smiles_dataset.py +++ b/NAG2G/data/random_smiles_dataset.py @@ -61,14 +61,27 @@ def get_map(self, smi): c_id_list = [atom.GetAtomMapNum() for atom in c_mol.GetAtoms()] return c_id_list, c_mol + # def get_list(self, atoms_map_product, atoms_map_reactant): + # atoms_map_reactant_dict = { + # atoms_map_reactant[i]: i for i in range(len(atoms_map_reactant)) + # } + # tmp = np.array([atoms_map_reactant_dict[i] for i in atoms_map_product]) + # orders = np.array([i for i in range(len(atoms_map_reactant))]) + # mask = np.array(atoms_map_reactant) != 0 + # list_reactant = np.concatenate([tmp, orders[~mask]], 0).tolist() + # return list_reactant def get_list(self, atoms_map_product, atoms_map_reactant): atoms_map_reactant_dict = { atoms_map_reactant[i]: i for i in range(len(atoms_map_reactant)) } - tmp = np.array([atoms_map_reactant_dict[i] for i in atoms_map_product]) - orders = np.array([i for i in range(len(atoms_map_reactant))]) - mask = np.array(atoms_map_reactant) != 0 - list_reactant = np.concatenate([tmp, orders[~mask]], 0).tolist() + + tmp = [atoms_map_reactant_dict[i] for i in atoms_map_product if i in atoms_map_reactant_dict] + + all_indices = set(range(len(atoms_map_reactant))) + missing_indices = list(all_indices - set(tmp)) + + list_reactant = tmp + missing_indices + return list_reactant def __getitem__(self, index: int): diff --git a/NAG2G/utils/graph_process.py b/NAG2G/utils/graph_process.py index d935c25..7eef48b 100644 --- a/NAG2G/utils/graph_process.py +++ b/NAG2G/utils/graph_process.py @@ -474,16 +474,37 @@ def get_graph(mol): return x, edge_index, edge_attr +# def shuffle_graph_process(result, list_=None): +# if list_ is None: +# list_ = [i for i in range(result["atoms"].shape[0])] +# random.shuffle(list_) +# result["atoms"] = result["atoms"][list_] +# result["atoms_map"] = result["atoms_map"][list_] +# result["node_attr"] = result["node_attr"][list_] + +# list_reverse = {i: idx for idx, i in enumerate(list_)} +# for i in range(result["edge_index"].shape[0]): +# for j in range(result["edge_index"].shape[1]): +# result["edge_index"][i, j] = list_reverse[result["edge_index"][i, j]] +# return result + def shuffle_graph_process(result, list_=None): if list_ is None: list_ = [i for i in range(result["atoms"].shape[0])] random.shuffle(list_) + result["atoms"] = result["atoms"][list_] result["atoms_map"] = result["atoms_map"][list_] result["node_attr"] = result["node_attr"][list_] - + list_reverse = {i: idx for idx, i in enumerate(list_)} + for i in range(result["edge_index"].shape[0]): for j in range(result["edge_index"].shape[1]): - result["edge_index"][i, j] = list_reverse[result["edge_index"][i, j]] + if result["edge_index"][i, j] in list_reverse: + result["edge_index"][i, j] = list_reverse[result["edge_index"][i, j]] + else: + print(f"Warning: Index {result['edge_index'][i, j]} not found in list_reverse") + result["edge_index"][i, j] = -1 + return result diff --git a/data_preprocess/lmdb_preprocess.py b/data_preprocess/lmdb_preprocess.py index 49e311e..2544ec4 100644 --- a/data_preprocess/lmdb_preprocess.py +++ b/data_preprocess/lmdb_preprocess.py @@ -73,11 +73,14 @@ def make_lmdb(path_smi, outputfilename): assert tmp_result_2d["atoms"] == result["target_atoms"] result["target_coordinates"] = tmp_result_2d["coordinates"].copy() else: - assert tmp_result_2d["atoms"] == result["target_atoms"] - assert ( - result["target_coordinates"][0].shape - == tmp_result_2d["coordinates"][0].shape - ) + if "target_coordinates" not in result: + print(raw_string) + continue + # assert tmp_result_2d["atoms"] == result["target_atoms"] + # assert ( + # result["target_coordinates"][0].shape + # == tmp_result_2d["coordinates"][0].shape + # ) result["target_coordinates"] += tmp_result_2d["coordinates"].copy() result["target_coordinates"] = np.array(result["target_coordinates"]) diff --git a/valid.sh b/valid.sh index 8d66ccd..2990e94 100755 --- a/valid.sh +++ b/valid.sh @@ -92,7 +92,8 @@ torchrun \ --infer_save_name ${infer_save_name} \ --batch-size $batch_size \ --data-buffer-size ${batch_size} --fixed-validation-seed 11 --batch-size-valid ${batch_size} \ - --config_file $config_file + --config_file $config_file \ + --no_reactant cd NAG2G new_filename=$(echo "$infer_save_name" | sed 's/.txt/_{}.txt/')