Skip to content

Commit 5df4a5c

Browse files
committed
Refactorings
1 parent ca42822 commit 5df4a5c

File tree

5 files changed

+39
-137
lines changed

5 files changed

+39
-137
lines changed

.gitignore

Lines changed: 5 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
1-
# Byte-compiled / optimized / DLL files
21
__pycache__/
32
*.py[cod]
43
*$py.class
5-
6-
# C extensions
74
*.so
8-
9-
# Distribution / packaging
105
.Python
116
build/
127
develop-eggs/
@@ -24,19 +19,12 @@ share/python-wheels/
2419
*.egg-info/
2520
.installed.cfg
2621
*.egg
22+
space-invaders.py
2723
MANIFEST
28-
29-
# PyInstaller
30-
# Usually these files are written by a python script from a template
31-
# before PyInstaller builds the exe, so as to inject date/other infos into it.
3224
*.manifest
3325
*.spec
34-
35-
# Installer logs
3626
pip-log.txt
3727
pip-delete-this-directory.txt
38-
39-
# Unit test / coverage reports
4028
htmlcov/
4129
.tox/
4230
.nox/
@@ -50,115 +38,31 @@ coverage.xml
5038
.hypothesis/
5139
.pytest_cache/
5240
cover/
53-
54-
# Translations
55-
*.mo
56-
*.pot
57-
58-
# Django stuff:
59-
*.log
60-
local_settings.py
61-
db.sqlite3
62-
db.sqlite3-journal
63-
64-
# Flask stuff:
65-
instance/
66-
.webassets-cache
67-
68-
# Scrapy stuff:
69-
.scrapy
70-
71-
# Sphinx documentation
72-
docs/_build/
73-
74-
# PyBuilder
7541
.pybuilder/
7642
target/
77-
78-
# Jupyter Notebook
7943
.ipynb_checkpoints
80-
81-
# IPython
8244
profile_default/
8345
ipython_config.py
84-
85-
# pyenv
86-
# For a library or package, you might want to ignore these files since the code is
87-
# intended to run in multiple environments; otherwise, check them in:
88-
# .python-version
89-
90-
# pipenv
91-
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92-
# However, in case of collaboration, if having platform-specific dependencies or dependencies
93-
# having no cross-platform support, pipenv may install dependencies that don't work, or not
94-
# install all needed dependencies.
95-
#Pipfile.lock
96-
97-
# poetry
98-
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99-
# This is especially recommended for binary packages to ensure reproducibility, and is more
100-
# commonly ignored for libraries.
101-
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102-
#poetry.lock
103-
104-
# pdm
105-
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106-
#pdm.lock
107-
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108-
# in version control.
109-
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
46+
.python-version
47+
poetry.lock
11048
.pdm.toml
11149
.pdm-python
11250
.pdm-build/
113-
114-
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
11551
__pypackages__/
116-
117-
# Celery stuff
118-
celerybeat-schedule
119-
celerybeat.pid
120-
121-
# SageMath parsed files
122-
*.sage.py
123-
124-
# Environments
12552
.env
12653
.venv
12754
env/
12855
venv/
12956
ENV/
13057
env.bak/
13158
venv.bak/
132-
133-
# Spyder project settings
134-
.spyderproject
135-
.spyproject
136-
137-
# Rope project settings
138-
.ropeproject
139-
140-
# mkdocs documentation
14159
/site
142-
143-
# mypy
14460
.mypy_cache/
14561
.dmypy.json
14662
dmypy.json
147-
148-
# Pyre type checker
149-
.pyre/
150-
151-
# pytype static type analyzer
15263
.pytype/
153-
154-
# Cython debug symbols
15564
cython_debug/
156-
157-
# PyCharm
158-
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159-
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160-
# and can be added to the global gitignore or merged into this file. For a more nuclear
161-
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
16265
.idea/
16366
.vscode/
164-
.DS_Store
67+
.DS_Store
68+
.ruff_cache

tests/test_ranker.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,5 @@ def test_ranker_bestlayer(small_language_models, trec):
4747
ranker = TransformerRanker(dataset=trec, dataset_downsample=0.05)
4848
result = ranker.run(small_language_models, layer_aggregator='bestlayer')
4949
assert len(str(result).split("\n")) >= 2
50+
51+
assert isinstance(result.layerwise_scores, dict) # see if layer scores are there

transformer_ranker/datacleaner.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,12 @@ class DatasetCleaner:
4040
def prepare_dataset(
4141
self, dataset: Union[str, Dataset, DatasetDict]
4242
) -> tuple[Union[list[str], list[list[str]]], torch.Tensor, TaskCategory]:
43-
"""Prepare texts and labels, assign task category.
43+
"""Prepares texts, labels, and assigns the task category.
4444
45-
Downsample dataset, find text and label columns, create label map,
46-
preprocess labels, pre-tokenize, clean rows, merge text pair columns.
45+
Downsamples dataset, finds text and label columns, cleans empty/noisy rows,
46+
pre-tokenizes texts, merges text pair columns, creates label map for classification.
4747
Returns: (processed texts, label tensor, task category)
4848
"""
49-
50-
# Verify dataset type
5149
if not isinstance(dataset, (Dataset, DatasetDict)):
5250
raise ValueError(f"Unsupported dataset type: {type(dataset)}")
5351

@@ -85,14 +83,12 @@ def prepare_dataset(
8583
if task_category == TaskCategory.TOKEN_CLASSIFICATION and self.remove_bio_encoding:
8684
dataset, label_map = self._remove_bio_encoding(dataset, label_column, label_map)
8785

88-
# Prepare all texts and labels as tensors
8986
texts = dataset[text_column]
9087
labels = dataset[label_column]
9188
if task_category == TaskCategory.TOKEN_CLASSIFICATION:
9289
labels = [word_label for labels in dataset[label_column] for word_label in labels]
9390
labels = torch.tensor(labels)
9491

95-
# Log dataset info
9692
self._log_dataset_info(
9793
text_column,
9894
label_column,

transformer_ranker/ranker.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@ def __init__(
2323
**kwargs: Any,
2424
):
2525
"""
26-
Prepares a dataset and compiles metrics to assess transferability.
26+
Prepares huggingface text dataset (downsamples and finds the task category).
27+
Sets up transferability metrics.
2728
2829
:param dataset: a dataset from huggingface with texts and labels.
29-
:param dataset_downsample: a fraction to which the dataset should be reduced.
30+
:param dataset_downsample: a fraction to downsample the dataset to.
3031
:param text_column: the name of the column containing texts.
3132
:param label_column: the name of the column containing labels.
32-
:param kwargs: additional dataset-specific parameters for data cleaning.
33+
:param kwargs: additional parameters for dataset preprocessing.
3334
"""
34-
# Preprocess and down-sample a dataset
3535
datacleaner = DatasetCleaner(
3636
dataset_downsample=dataset_downsample,
3737
text_column=text_column,
@@ -46,6 +46,7 @@ def __init__(
4646
"logme": LogME,
4747
"hscore": HScore,
4848
"knn": NearestNeighbors,
49+
# add more
4950
}
5051

5152
def run(
@@ -57,23 +58,22 @@ def run(
5758
**kwargs: Any,
5859
):
5960
"""
60-
Loads models, collects embeddings, and scores them.
61+
Loads language models, collects embedding from each, and scores them.
6162
6263
:param models: A list of model names
6364
:param estimator: Transferability metric ('hscore', 'logme', 'knn').
64-
:param layer_aggregator: Method to aggregate layers ('lastlayer', 'layermean', 'bestlayer').
65+
:param layer_aggregator: Method to aggregate layers ('layermean', 'lastlayer', 'bestlayer').
6566
:param batch_size: Number of samples per batch, defaults to 32.
6667
:param device: Device for embedding ('cpu', 'cuda', 'cuda:2').
67-
:param gpu_estimation: Store and score embeddings on GPU for speedup.
68+
:param gpu_estimation: Boolean if to compute transferability estimation using gpu.
6869
:param kwargs: Additional parameters for embedder class.
6970
:return: Returns sorted dictionary of model names and their scores
7071
"""
7172
self._confirm_ranker_setup(estimator=estimator, layer_aggregator=layer_aggregator)
7273

73-
# Set device for models and the metric
7474
device = kwargs.pop("device", None)
75-
gpu_estimation = kwargs.get("gpu_estimation", True)
76-
if gpu_estimation:
75+
estimation_using_gpu = kwargs.get("gpu_estimation", True)
76+
if estimation_using_gpu:
7777
self.labels = self.labels.to(device)
7878

7979
# Download models to hf cache
@@ -91,7 +91,6 @@ def run(
9191
else kwargs.get("sentence_pooling", "mean")
9292
)
9393

94-
# Setup the embedder
9594
embedder = Embedder(
9695
model=model,
9796
layer_ids="0" if layer_aggregator == "lastlayer" else "all",
@@ -101,26 +100,27 @@ def run(
101100
**kwargs,
102101
)
103102

104-
# Collect embeddings
103+
# Collect language model embeddings
105104
embeddings = embedder.embed(
106-
self.texts, batch_size=batch_size, unpack_to_cpu=not gpu_estimation, show_progress=True,
105+
self.texts, batch_size=batch_size, unpack_to_cpu=not estimation_using_gpu, show_progress=True,
107106
) # fmt: skip
108107

109-
# Flatten them for ner tasks
110108
if self.task_category == TaskCategory.TOKEN_CLASSIFICATION:
111109
embeddings = [word for sentence in embeddings for word in sentence]
112110

113111
model_name = embedder.name
114-
del embedder # remove from memory
112+
del embedder # remove language model from memory
115113
torch.cuda.empty_cache()
116114

117115
# Compute transferability
118116
score = self._transferability_score(embeddings, metric, layer_aggregator)
119117

120-
# Store and log results
121118
result.add_score(model_name, score)
122119
logger.info(f"{model_name} {result.metric}: {score:.4f}")
123120

121+
logger.info(f"Results ▲\n{result}")
122+
logger.info(f"Done!")
123+
124124
return result
125125

126126
def _transferability_score(self, embeddings, metric, layer_aggregator, show_progress=True) -> float:
@@ -132,7 +132,7 @@ def _transferability_score(self, embeddings, metric, layer_aggregator, show_prog
132132
range(num_layers), desc="Transferability score", bar_format=tqdm_bar_format, disable=not show_progress
133133
)
134134

135-
# Score each layer separately
135+
# Score each hidden layer separately
136136
for layer_id in transferability_progress:
137137
layer_embeddings = torch.stack([emb[layer_id] for emb in embeddings])
138138
score = metric.fit(embeddings=layer_embeddings, labels=self.labels)

transformer_ranker/utils.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,12 @@ def prepare_popular_models(model_size="base") -> list[str]:
5757

5858
def configure_logger(name: str, level: int = logging.INFO, log_to_console: bool = True) -> logging.Logger:
5959
"""
60-
Configure the package's logger.
60+
Configure logger for this transferability NLP framework
6161
62-
:param name: The name of the logger.
62+
:param name: The name of the logger "transformer-ranker".
6363
:param level: The logging level (default: logging.INFO).
6464
:param log_to_console: Whether to log to console (default: True)
65-
:return: Configured TransformerRanker logger
65+
:return: Configured logger
6666
"""
6767
logger = logging.getLogger(name)
6868
logger.setLevel(level)
@@ -73,12 +73,12 @@ def configure_logger(name: str, level: int = logging.INFO, log_to_console: bool
7373
console_handler.setFormatter(logging.Formatter("transformer_ranker:%(message)s"))
7474
logger.addHandler(console_handler)
7575

76-
# Suppress future and user warnings
76+
# Suppress future and user warnings
7777
warnings.simplefilter("ignore", category=FutureWarning)
7878
warnings.simplefilter("ignore", category=UserWarning)
7979
transformers_logging.set_verbosity_error()
8080

81-
# Suppress unused weights messages when loading models
81+
# Suppress unused weights messages when loading models (transformers)
8282
logger.addFilter(
8383
lambda record: not (
8484
"Some weights of BertModel were not initialized" in record.getMessage()
@@ -96,21 +96,21 @@ class Result:
9696
def __init__(self, metric: str):
9797
self.metric = metric
9898
self.scores = {}
99-
self.layer_scores = {}
99+
self.layerwise_scores = {}
100100

101-
def add_score(self, model_name: str, score: float, layer_scores: list[float] = None) -> None:
102-
"""Add score for a model."""
101+
def add_score(self, model_name: str, score: float) -> None:
103102
self.scores[model_name] = score
104103

105-
if layer_scores is not None: # only used for the bestlayer option
106-
self.layer_scores[model_name] = layer_scores
104+
def add_layerwise_scores(self, model_name: str, scores: dict) -> None:
105+
"""Only used for the bestlayer option."""
106+
self.layerwise_scores[model_name] = scores
107107

108108
def append(self, other: "Result") -> None:
109109
"""Append scores from multiple runs."""
110110
if self.metric != other.metric:
111111
raise ValueError(f"Metrics do not match ({self.metric} vs {other.metric}).")
112112
self.scores.update(other.scores)
113-
self.layer_scores.update(other.layer_scores)
113+
self.layerwise_scores.update(other.layerwise_scores)
114114

115115
def best_model(self) -> str:
116116
"""Show the model with the highest score."""
@@ -126,6 +126,6 @@ def __str__(self) -> str:
126126
sorted_scores = sorted(self.scores.items(), key=lambda item: item[1], reverse=True)
127127
model_rank = [f"Rank {i+1}. {model}: {score:.4f}" for i, (model, score) in enumerate(sorted_scores)]
128128
return "\n".join(model_rank)
129-
129+
130130
def __repr__(self) -> str:
131131
return self.__str__()

0 commit comments

Comments (0)