Skip to content

Commit 4a34864

Browse files
author
sfluegel
committed
unify spelling of Collator
1 parent 96e9c42 commit 4a34864

File tree

6 files changed

+24
-22
lines changed

6 files changed

+24
-22
lines changed

chebai/preprocessing/collate.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from chebai.preprocessing.structures import XYData
66

77

8-
class Collater:
8+
class Collator:
99
"""Base class for collating data samples into a batch."""
1010

1111
def __init__(self, **kwargs):
@@ -23,8 +23,8 @@ def __call__(self, data: List[Dict]) -> XYData:
2323
raise NotImplementedError
2424

2525

26-
class DefaultCollater(Collater):
27-
"""Default collater that extracts features and labels."""
26+
class DefaultCollator(Collator):
27+
"""Default collator that extracts features and labels."""
2828

2929
def __call__(self, data: List[Dict]) -> XYData:
3030
"""Collate data samples by extracting features and labels.
@@ -39,11 +39,12 @@ def __call__(self, data: List[Dict]) -> XYData:
3939
return XYData(x, y)
4040

4141

42-
class RaggedCollater(Collater):
43-
"""Collater for handling ragged data samples."""
42+
class RaggedCollator(Collator):
43+
"""Collator for handling ragged data samples."""
4444

4545
def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData:
46-
"""Collate ragged data samples into a batch.
46+
"""Collate ragged data samples (i.e., samples of unequal size such as string representations of molecules) into
47+
a batch.
4748
4849
Args:
4950
data (List[Union[Dict, Tuple]]): List of ragged data samples.

chebai/preprocessing/datasets/pubchem.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -901,7 +901,7 @@ def dataloader(self, kind: str, **kwargs) -> DataLoader:
901901
unlabeled_data = unlabeled_data[: self.data_limit]
902902
return DataLoader(
903903
labeled_data + unlabeled_data,
904-
collate_fn=self.reader.collater,
904+
collate_fn=self.reader.collator,
905905
batch_size=self.batch_size,
906906
**kwargs,
907907
)

chebai/preprocessing/reader.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import os
2-
from typing import Any, Dict, List, Optional, Union
2+
from typing import Any, Dict, List, Optional
33

44
from pysmiles.read_smiles import _tokenize
55
from transformers import RobertaTokenizerFast
66
import deepsmiles
77
import selfies as sf
88

9-
from chebai.preprocessing.collate import DefaultCollater, RaggedCollater
9+
from chebai.preprocessing.collate import DefaultCollator, RaggedCollator
1010

1111
EMBEDDING_OFFSET = 10
1212
PADDING_TOKEN_INDEX = 0
@@ -16,15 +16,16 @@
1616

1717
class DataReader:
1818
"""
19-
Base class for reading and preprocessing data.
19+
Base class for reading and preprocessing data. Turns the raw input data (e.g., a SMILES string) into the model
20+
input format (e.g., a list of tokens).
2021
2122
Args:
2223
collator_kwargs: Optional dictionary of keyword arguments for the collator.
2324
token_path: Optional path for the token file.
24-
kwargs: Additional keyword arguments.
25+
kwargs: Additional keyword arguments (not used).
2526
"""
2627

27-
COLLATER = DefaultCollater
28+
COLLATOR = DefaultCollator
2829

2930
def __init__(
3031
self,
@@ -34,7 +35,7 @@ def __init__(
3435
):
3536
if collator_kwargs is None:
3637
collator_kwargs = dict()
37-
self.collater = self.COLLATER(**collator_kwargs)
38+
self.collator = self.COLLATOR(**collator_kwargs)
3839
self.dirname = os.path.dirname(__file__)
3940
self._token_path = token_path
4041

@@ -126,7 +127,7 @@ class ChemDataReader(DataReader):
126127
kwargs: Additional keyword arguments.
127128
"""
128129

129-
COLLATER = RaggedCollater
130+
COLLATOR = RaggedCollator
130131

131132
@classmethod
132133
def name(cls) -> str:
@@ -201,7 +202,7 @@ class ChemDataUnlabeledReader(ChemDataReader):
201202
kwargs: Additional keyword arguments.
202203
"""
203204

204-
COLLATER = RaggedCollater
205+
COLLATOR = RaggedCollator
205206

206207
@classmethod
207208
def name(cls) -> str:
@@ -220,13 +221,13 @@ class ChemBPEReader(DataReader):
220221
Args:
221222
data_path: Path for the pretrained BPE tokenizer.
222223
max_len: Maximum length of the tokenized sequence.
223-
vsize: Vocabulary size for the tokenizer.
224+
vsize: Vocabulary size for the tokenizer (not used).
224225
collator_kwargs: Optional dictionary of keyword arguments for the collator.
225226
token_path: Optional path for the token file.
226227
kwargs: Additional keyword arguments.
227228
"""
228229

229-
COLLATER = RaggedCollater
230+
COLLATOR = RaggedCollator
230231

231232
@classmethod
232233
def name(cls) -> str:
@@ -264,7 +265,7 @@ class SelfiesReader(ChemDataReader):
264265
kwargs: Additional keyword arguments.
265266
"""
266267

267-
COLLATER = RaggedCollater
268+
COLLATOR = RaggedCollator
268269

269270
def __init__(
270271
self,
@@ -309,7 +310,7 @@ class OrdReader(DataReader):
309310
kwargs: Additional keyword arguments.
310311
"""
311312

312-
COLLATER = RaggedCollater
313+
COLLATOR = RaggedCollator
313314

314315
@classmethod
315316
def name(cls) -> str:

chebai/result/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _process_row(self, row):
4545

4646
def _generate_predictions(self, data_path, raw=False, **kwargs):
4747
self._model.eval()
48-
collate = self._reader.COLLATER()
48+
collate = self._reader.COLLATOR()
4949
if raw:
5050
data_tuples = [
5151
(x["features"], x["ident"], self._reader.to_data(self._process_row(x)))

chebai/result/pretraining.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def evaluate_model(logs_base_path, model_filename, data_module):
3131
)
3232
)
3333
assert isinstance(model, electra.ElectraPre)
34-
collate = data_module.reader.COLLATER()
34+
collate = data_module.reader.COLLATOR()
3535
test_file = "test.pt"
3636
data_path = os.path.join(data_module.processed_dir, test_file)
3737
data_list = torch.load(data_path)

chebai/result/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def evaluate_model(
7373
Tensors with predictions and labels.
7474
"""
7575
model.eval()
76-
collate = data_module.reader.COLLATER()
76+
collate = data_module.reader.COLLATOR()
7777

7878
data_list = data_module.load_processed_data("test", filename)
7979
data_list = data_list[: data_module.data_limit]

0 commit comments

Comments
 (0)