Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions skfp/fingerprints/getaway.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,80 @@ def __init__(
)
self.clip_val = clip_val

def get_feature_names_out(self, input_features=None) -> np.ndarray: # noqa: ARG002
"""
Get fingerprint output feature names. They correspond to various
descriptors derived from weighted Molecular Influence Matrix (MIM).
Definitions are complex and long, see references given in main body
for explanations, particularly Todeschini & Consonni book.

Parameters
----------
input_features : array-like of str or None, default=None
Unused, kept for scikit-learn compatibility.

Returns
-------
feature_names_out : ndarray of str objects
GETAWAY feature names.
"""
weighting_variants = [
"unweighted",
"atomic mass",
"van der Waals volume",
"electronegativity",
"polarizability",
"ion polarity",
"IState",
]

feature_names = [
"total information content on the leverage equality (ITH)",
"standardized information content on the leverage equality (ISH)",
"mean information content on the leverage magnitude (HIC)",
"geometric mean of the leverage magnitude (HGM)",
]

for weighting in weighting_variants:
h_indices = [f"{weighting} H index radius {radius}" for radius in range(9)]
total_h_index = f"{weighting} total H index"
feature_names.extend(h_indices)
feature_names.append(total_h_index)

hats_indices = [
f"{weighting} H matrix autocorrelation (HATS) radius {radius}"
for radius in range(9)
]
total_hats_index = f"{weighting} total HATS index"
feature_names.extend(hats_indices)
feature_names.append(total_hats_index)

feature_names.extend(
[
"R-connectivity index (RCON)",
"average row sum of the influence/distance matrix (RARS",
"R-matrix leading eigenvalue (REIG)",
]
)

for weighting in weighting_variants:
r_indices = [
f"{weighting} R index radius {radius}" for radius in range(1, 9)
]
total_r_index = f"{weighting} R total index"
feature_names.extend(r_indices)
feature_names.append(total_r_index)

maximal_r_indices = [
f"{weighting} maximal R index (R+) radius {radius}"
for radius in range(1, 9)
]
total_maximal_r_index = f"{weighting} maximal R total index"
feature_names.extend(maximal_r_indices)
feature_names.append(total_maximal_r_index)

return np.asarray(feature_names, dtype=object)

def transform(
self, X: Sequence[str | Mol], copy: bool = False
) -> np.ndarray | csr_array:
Expand Down
39 changes: 39 additions & 0 deletions tests/fingerprints/getaway.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,42 @@ def test_getaway_sparse_fingerprint(mols_conformers_list):

assert X_skfp.shape == (len(mols_conformers_list), 273)
assert np.issubdtype(X_skfp.dtype, np.floating)


def test_getaway_feature_names():
getaway_fp = GETAWAYFingerprint()
feature_names = getaway_fp.get_feature_names_out()

assert len(feature_names) == getaway_fp.n_features_out
assert len(feature_names) == len(set(feature_names))

# for indices understanding, see:
# https://github.com/rdkit/rdkit/blob/53a01430e057d34cf7a7ca52eb2d8b069178114d/Code/GraphMol/Descriptors/GETAWAY.cpp#L1145
# and slide 28 from https://github.com/rdkit/UGM_2017/blob/master/Presentations/Godin_3D_Descriptors.pdf
# note that here we index from zero

assert "ITH" in feature_names[0]
assert "ISH" in feature_names[1]
assert "HIC" in feature_names[2]
assert "HGM" in feature_names[3]

assert all(
"total H index" in feature_names[idx] for idx in [13, 33, 53, 73, 93, 113, 133]
)
assert all(
"total HATS index" in feature_names[idx]
for idx in [23, 43, 63, 83, 103, 123, 143]
)

assert "RCON" in feature_names[144]
assert "RARS" in feature_names[145]
assert "REIG" in feature_names[146]

assert all(
"R total index" in feature_names[idx]
for idx in [155, 173, 191, 209, 227, 254, 263]
)
assert all(
"maximal R total index" in feature_names[idx]
for idx in [164, 182, 200, 218, 236, 254, 272]
)
Loading