Skip to content

Commit d10859e

Browse files
committed
API refactored.
Bugs fixed.
1 parent e0fdf30 commit d10859e

File tree

6 files changed

+24
-121
lines changed

6 files changed

+24
-121
lines changed

chytorch/nn/activation.py

Lines changed: 0 additions & 40 deletions
This file was deleted.

chytorch/nn/transformer/attention/graphormer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def __init__(self, embed_dim, num_heads, dropout: float = .1, bias: bool = True,
100100
self._register_load_state_dict_pre_hook(_update_packed)
101101
self.o_proj = Linear(embed_dim, embed_dim, bias=bias)
102102

103-
def forward(self, x: Tensor, attn_mask: Optional[Tensor], pad_mask: Optional[Tensor] = None, *,
103+
def forward(self, x: Tensor, attn_mask: Tensor, *,
104104
cache: Optional[Tuple[Tensor, Tensor]] = None,
105105
need_weights: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
106106
if self.separate_proj:
@@ -126,9 +126,7 @@ def forward(self, x: Tensor, attn_mask: Optional[Tensor], pad_mask: Optional[Ten
126126
v = v.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
127127

128128
# BxHxTxE @ BxHxExS > BxHxTxS
129-
a = (q @ k) * self._scale
130-
if attn_mask is not None:
131-
a = a + attn_mask
129+
a = (q @ k) * self._scale + attn_mask
132130
a = softmax(a, dim=-1)
133131
if self.training and self.dropout:
134132
a = dropout(a, self.dropout)

chytorch/nn/transformer/encoder.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,23 +68,23 @@ class EncoderLayer(Module):
6868
"""
6969
def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1, activation=GELU, layer_norm_eps=1e-5,
7070
norm_first: bool = False, attention: Type[Module] = GraphormerAttention, mlp: Type[Module] = MLP,
71-
projection_bias: bool = True, ff_bias: bool = True):
71+
norm_layer: Type[Module] = LayerNorm, projection_bias: bool = True, ff_bias: bool = True):
7272
super().__init__()
7373
self.self_attn = attention(d_model, nhead, dropout, projection_bias)
7474
self.mlp = mlp(d_model, dim_feedforward, dropout, activation, ff_bias)
7575

76-
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps)
77-
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps)
76+
self.norm1 = norm_layer(d_model, eps=layer_norm_eps)
77+
self.norm2 = norm_layer(d_model, eps=layer_norm_eps)
7878
self.dropout1 = Dropout(dropout)
7979
self.dropout2 = Dropout(dropout)
8080
self.norm_first = norm_first
8181
self._register_load_state_dict_pre_hook(_update)
8282

83-
def forward(self, x: Tensor, attn_mask: Optional[Tensor], pad_mask: Optional[Tensor] = None, *,
84-
cache: Optional[Tuple[Tensor, Tensor]] = None,
85-
need_embedding: bool = True, need_weights: bool = False) -> Tuple[Optional[Tensor], Optional[Tensor]]:
83+
def forward(self, x: Tensor, attn_mask: Optional[Tensor], *,
84+
need_embedding: bool = True, need_weights: bool = False,
85+
**kwargs) -> Tuple[Optional[Tensor], Optional[Tensor]]:
8686
nx = self.norm1(x) if self.norm_first else x # pre-norm or post-norm
87-
e, a = self.self_attn(nx, attn_mask, pad_mask, cache=cache, need_weights=need_weights)
87+
e, a = self.self_attn(nx, attn_mask, need_weights=need_weights, **kwargs)
8888

8989
if need_embedding:
9090
x = x + self.dropout1(e)
@@ -96,4 +96,4 @@ def forward(self, x: Tensor, attn_mask: Optional[Tensor], pad_mask: Optional[Ten
9696
return None, a
9797

9898

99-
__all__ = ['EncoderLayer']
99+
__all__ = ['EncoderLayer', 'MLP']

chytorch/utils/data/molecule/_unpack.pyx

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,6 @@ DTYPE = np.int32
3131
ctypedef cnp.int32_t DTYPE_t
3232

3333

34-
cdef extern from "Python.h":
35-
dict _PyDict_NewPresized(Py_ssize_t minused)
36-
37-
3834
# Format specification::
3935
#
4036
# Big endian bytes order
@@ -68,15 +64,15 @@ cdef extern from "Python.h":
6864
@cython.cdivision(True)
6965
@cython.wraparound(False)
7066
def unpack(const unsigned char[::1] data not None, unsigned short add_cls, unsigned short symmetric_attention,
71-
unsigned short components_attention, DTYPE_t max_neighbors, DTYPE_t max_distance, DTYPE_t padding):
67+
unsigned short components_attention, DTYPE_t max_neighbors, DTYPE_t max_distance):
7268
"""
7369
Optimized chython pack to graph tensor converter.
7470
Ignores charge, radicals, isotope, coordinates, bond order, and stereo info
7571
"""
7672
cdef unsigned char a, b, c, hydrogens, neighbors_count
7773
cdef unsigned char *connections
7874

79-
cdef unsigned short atoms_count, bonds_count = 0, order_count = 0, cis_trans_count, padded_count
75+
cdef unsigned short atoms_count, bonds_count = 0, order_count = 0, cis_trans_count
8076
cdef unsigned short i, j, k, n, m
8177
cdef unsigned short[4096] mapping
8278
cdef unsigned int size, shift = 4
@@ -85,9 +81,6 @@ def unpack(const unsigned char[::1] data not None, unsigned short add_cls, unsig
8581
cdef cnp.ndarray[DTYPE_t, ndim=2] distance
8682
cdef DTYPE_t d, attention
8783

88-
cdef object py_n
89-
cdef dict py_mapping
90-
9184
# read header
9285
if data[0] != 2:
9386
raise ValueError('invalid pack version')
@@ -98,12 +91,9 @@ def unpack(const unsigned char[::1] data not None, unsigned short add_cls, unsig
9891
atoms_count = (a << 4| b >> 4) + add_cls
9992
cis_trans_count = (b & 0x0f) << 8 | c
10093

101-
py_mapping = _PyDict_NewPresized(atoms_count)
102-
103-
padded_count = atoms_count + padding
104-
atoms = np.empty(padded_count, dtype=DTYPE)
105-
neighbors = np.zeros(padded_count, dtype=DTYPE)
106-
distance = np.full((padded_count, padded_count), 9999, dtype=DTYPE) # fill with unreachable value
94+
atoms = np.empty(atoms_count, dtype=DTYPE)
95+
neighbors = np.zeros(atoms_count, dtype=DTYPE)
96+
distance = np.full((atoms_count, atoms_count), 9999, dtype=DTYPE) # fill with unreachable value
10797

10898
# allocate memory
10999
connections = <unsigned char*> PyMem_Malloc(atoms_count * sizeof(unsigned char))
@@ -126,7 +116,6 @@ def unpack(const unsigned char[::1] data not None, unsigned short add_cls, unsig
126116
a, b = data[shift], data[shift + 1]
127117
n = a << 4 | b >> 4
128118
mapping[n] = i
129-
py_mapping[n] = i
130119
connections[i] = neighbors_count = b & 0x0f
131120
bonds_count += neighbors_count
132121

@@ -187,13 +176,6 @@ def unpack(const unsigned char[::1] data not None, unsigned short add_cls, unsig
187176
else:
188177
distance[i, j] = distance[j, i] = d + 2
189178

190-
# disable attention on padding
191-
for i in range(atoms_count, padded_count):
192-
atoms[i] = 2 # set explicit hydrogen
193-
for j in range(padded_count):
194-
distance[i, j] = distance[j, i] = 0
195-
distance[i, i] = 1 # self-attention of padding
196-
197179
size = shift + order_count + 4 * cis_trans_count
198180
PyMem_Free(connections)
199-
return atoms, neighbors, distance, size, py_mapping
181+
return atoms, neighbors, distance, size

chytorch/utils/data/molecule/encoder.py

Lines changed: 7 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from torch.nn.utils.rnn import pad_sequence
2828
from torch.utils.data import Dataset
2929
from torchtyping import TensorType
30-
from typing import Sequence, Union, NamedTuple, Optional, Tuple
30+
from typing import Sequence, Union, NamedTuple
3131
from zlib import decompress
3232
from .._abc import default_collate_fn_map
3333

@@ -89,8 +89,7 @@ def collate_molecules(batch, *, padding_left: bool = False, collate_fn_map=None)
8989

9090
class MoleculeDataset(Dataset):
9191
def __init__(self, molecules: Sequence[Union[MoleculeContainer, bytes]], *,
92-
hydrogens: Optional[Sequence[Sequence[Tuple[int, ...]]]] = None, cls_token: int = 1,
93-
max_distance: int = 10, add_cls: bool = True, max_neighbors: int = 14,
92+
cls_token: int = 1, max_distance: int = 10, add_cls: bool = True, max_neighbors: int = 14,
9493
symmetric_attention: bool = True, components_attention: bool = True,
9594
unpack: bool = False, compressed: bool = True, distance_cutoff=None):
9695
"""
@@ -106,7 +105,6 @@ def __init__(self, molecules: Sequence[Union[MoleculeContainer, bytes]], *,
106105
that code unreachable atoms (e.g. salts).
107106
108107
:param molecules: molecules collection
109-
:param hydrogens: shared hydrogen mapping. First element is hydrogen donor, other are acceptors
110108
:param max_distance: set distances greater than cutoff to cutoff value
111109
:param add_cls: add special token at first position
112110
:param max_neighbors: set neighbors count greater than cutoff to cutoff value
@@ -116,10 +114,7 @@ def __init__(self, molecules: Sequence[Union[MoleculeContainer, bytes]], *,
116114
:param compressed: packed molecules are compressed
117115
:param cls_token: idx of cls token
118116
"""
119-
assert hydrogens is None or len(hydrogens) == len(molecules), 'hydrogens and molecules must have the same size'
120-
121117
self.molecules = molecules
122-
self.hydrogens = hydrogens
123118
# distance_cutoff is deprecated
124119
self.max_distance = distance_cutoff if distance_cutoff is not None else max_distance
125120
self.add_cls = add_cls
@@ -132,14 +127,6 @@ def __init__(self, molecules: Sequence[Union[MoleculeContainer, bytes]], *,
132127

133128
def __getitem__(self, item: int) -> MoleculeDataPoint:
134129
mol = self.molecules[item]
135-
136-
if self.hydrogens is not None:
137-
hmap = self.hydrogens[item]
138-
pad = len(hmap)
139-
else:
140-
hmap = None
141-
pad = 0
142-
143130
if self.unpack:
144131
try:
145132
from ._unpack import unpack
@@ -148,22 +135,15 @@ def __getitem__(self, item: int) -> MoleculeDataPoint:
148135
else:
149136
if self.compressed:
150137
mol = decompress(mol)
151-
atoms, neighbors, distances, _, mapping = unpack(mol, self.add_cls, self.symmetric_attention,
152-
self.components_attention, self.max_neighbors,
153-
self.max_distance, pad)
154-
if pad:
155-
for n, da in enumerate(hmap, -pad):
156-
neighbors[mapping[da[0]]] -= 1
157-
for m in da:
158-
m = mapping[m]
159-
distances[n, m] = distances[m, n] = 1
138+
atoms, neighbors, distances, _ = unpack(mol, self.add_cls, self.symmetric_attention,
139+
self.components_attention, self.max_neighbors,
140+
self.max_distance)
160141
if self.add_cls and self.cls_token != 1:
161142
atoms[0] = self.cls_token
162143
return MoleculeDataPoint(IntTensor(atoms), IntTensor(neighbors), IntTensor(distances))
163144

164145
nc = self.max_neighbors
165-
lp = len(mol) + pad
166-
mapping = {}
146+
lp = len(mol)
167147

168148
if self.add_cls:
169149
lp += 1
@@ -176,7 +156,6 @@ def __getitem__(self, item: int) -> MoleculeDataPoint:
176156
ngb = mol._bonds # noqa speedup
177157
hgs = mol._hydrogens # noqa
178158
for i, (n, a) in enumerate(mol.atoms(), self.add_cls):
179-
mapping[n] = i
180159
atoms[i] = a.atomic_number + 2
181160
nb = len(ngb[n]) + (hgs[n] or 0) # treat bad valence as 0-hydrogen
182161
if nb > nc:
@@ -188,23 +167,7 @@ def __getitem__(self, item: int) -> MoleculeDataPoint:
188167
minimum(distances, self.max_distance + 2, out=distances)
189168
distances = IntTensor(distances)
190169

191-
if pad:
192-
atoms[-pad:] = 2 # set explicit hydrogens
193-
tmp = eye(lp, dtype=int32)
194-
if self.add_cls:
195-
tmp[0] = 1 # enable CLS to atom attention
196-
tmp[1:, 0] = 1 if self.symmetric_attention else 0 # enable or disable atom to CLS attention
197-
tmp[1:-pad, 1:-pad] = distances
198-
else:
199-
tmp[:-pad, :-pad] = distances
200-
distances = tmp
201-
202-
for n, da in enumerate(hmap, -pad):
203-
neighbors[mapping[da[0]]] -= 1
204-
for m in da:
205-
m = mapping[m]
206-
distances[n, m] = distances[m, n] = 1
207-
elif self.add_cls:
170+
if self.add_cls:
208171
tmp = ones((lp, lp), dtype=int32)
209172
if not self.symmetric_attention:
210173
tmp[1:, 0] = 0 # disable atom to CLS attention

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = 'chytorch'
3-
version = '1.64'
3+
version = '1.65'
44
description = 'Library for modeling molecules and reactions in torch way'
55
authors = ['Ramil Nugmanov <nougmanoff@protonmail.com>']
66
license = 'MIT'

0 commit comments

Comments (0)