Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/xpmir/learning/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,13 @@ def load_bestcheckpoint(self, max_epoch: int):

return False

@staticmethod
def get_checkpoint_path(checkpointspath: Path, epoch: int) -> Path:
    """Build the directory path of the checkpoint saved at a given epoch.

    :param checkpointspath: root directory holding all checkpoints
    :param epoch: epoch number of the checkpoint
    :return: the per-epoch checkpoint directory
    """
    # Checkpoint directories are named <PREFIX> followed by the epoch,
    # zero-padded to 8 digits (so lexicographic order == epoch order)
    checkpoint_name = f"{TrainerContext.PREFIX}{epoch:08d}"
    return checkpointspath / checkpoint_name

def save_checkpoint(self):
# Serialize
path = self.path / f"{TrainerContext.PREFIX}{self.epoch:08d}"
path = TrainerContext.get_checkpoint_path(self.path, self.epoch)
if self.state.path is not None:
# No need to save twice
return
Expand Down
17 changes: 14 additions & 3 deletions src/xpmir/learning/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import torch
from pathlib import Path
from typing import Dict, Iterator, List, NamedTuple, Any
from typing import Dict, Iterator, List, NamedTuple, Any, Optional
from experimaestro import (
Task,
Config,
Expand Down Expand Up @@ -36,6 +36,11 @@

logger = easylog()

class CheckpointModuleLoader(ModuleLoader):
    """Module loader targeting the checkpoint of one specific training epoch

    Extends ``ModuleLoader`` with the epoch number, so that intermediate
    checkpoints (not only the final model) can be loaded.
    """

    epoch: Param[Optional[int]] = None
    """The epoch of the checkpoint (NOTE(review): semantics of None — presumably the last/default checkpoint — to confirm)"""

class LearnerListenerStatus(Enum):
NO_DECISION = 0
Expand Down Expand Up @@ -78,12 +83,15 @@ def init_task(self, learner: "Learner", dep):
class LearnerOutput(NamedTuple):
    """The data structure for the output of a learner.

    It contains a dictionary where the key is the name of the listener and
    the value is the output of that listener. It also gives access to the
    learned model and to the checkpoints saved during the training.
    """

    # Outputs of the learner listeners, keyed by listener name
    listeners: Dict[str, Any]

    # Loader for the learned (last/best) model
    learned_model: ModuleLoader

    # Loaders for intermediate checkpoints, keyed by checkpoint epoch
    # NOTE(review): declared Dict[str, Any] but the producer appears to use
    # int epoch keys — confirm the intended key type
    checkpoints: Dict[str, Any]


class Learner(Task, EasyLogger):
"""Model Learner
Expand Down Expand Up @@ -157,11 +165,14 @@ def task_outputs(self, dep) -> LearnerOutput:
for listener in self.listeners
},
learned_model=dep(
ModuleLoader.C(
CheckpointModuleLoader.C(
value=self.model,
path=self.last_checkpoint_path / TrainState.MODEL_PATH,
)
),
checkpoints={
interval: dep(CheckpointModuleLoader.C(value=self.model, path=TrainerContext.get_checkpoint_path(self.checkpointspath, interval) / TrainState.MODEL_PATH, epoch=interval)) for interval in range(0, self.max_epochs, self.checkpoint_interval)
},
)

@property
Expand Down
45 changes: 41 additions & 4 deletions src/xpmir/neural/cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@ class CrossScorer(LearnableScorer, DistributableModel):
AKA Cross-Encoder
"""

max_length: Param[int]
"""Maximum length (in tokens) for the query-document string"""

encoder: Param[TextEncoderBase[Tuple[str, str], torch.Tensor]]
"""an encoder for encoding the concatenated query-document tokens which
doesn't contains the final linear layer"""
Expand All @@ -44,7 +41,6 @@ def __initialize__(self, options):
super().__initialize__(options)
self.encoder.initialize(options)
self.classifier = torch.nn.Linear(self.encoder.dimension, 1)
self.tokenizer_options = TokenizerOptions(max_length=self.max_length)

def forward(self, inputs: BaseRecords, info: TrainerContext = None):
# Encode queries and documents
Expand Down Expand Up @@ -111,3 +107,44 @@ def getRetriever(
device=device,
top_k=top_k,
)


class MiniLMCrossScorer(LearnableScorer, DistributableModel):
    """Query-Document Representation Classifier

    Based on a query-document representation (e.g. BERT [CLS] token).
    AKA Cross-Encoder. Unlike ``CrossScorer``, no maximum length is
    configured here; the encoder/tokenizer limit applies.
    """

    encoder: Param[TextEncoderBase[Tuple[str, str], torch.Tensor]]
    """an encoder for encoding the concatenated query-document tokens which
    does not contain the final linear layer"""

    def __validate__(self):
        # A static (frozen) encoder cannot be fine-tuned, which defeats the
        # purpose of a learnable cross-encoder
        super().__validate__()
        assert not self.encoder.static(), "The vocabulary should be learnable"

    def __initialize__(self, options):
        """Initialize the encoder and the classification head"""
        super().__initialize__(options)
        self.encoder.initialize(options)

        # Classification head: dropout followed by a linear layer projecting
        # the encoder representation to a single relevance score
        self.dropout = torch.nn.Dropout(0.1, inplace=False)
        self.classifier = torch.nn.Linear(self.encoder.dimension, 1)

    def forward(self, inputs: BaseRecords, info: TrainerContext = None):
        """Score each (topic, document) pair in the batch

        Returns a tensor of relevance scores of shape (batch_size,)
        """
        # Encode the concatenated query-document pairs
        pairs = self.encoder(
            [
                (tr[TextItem].text, dr[TextItem].text)
                for tr, dr in zip(inputs.topics, inputs.documents)
            ],
        )  # shape (batch_size, dimension)

        # Classification head: dropout then linear projection to one score
        output = self.dropout(pairs.value)
        return self.classifier(output).squeeze(1)

    def distribute_models(self, update):
        # Only the encoder needs to be wrapped for distributed execution
        self.encoder = update(self.encoder)
11 changes: 8 additions & 3 deletions src/xpmir/neural/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,14 @@ def forward(self, inputs: BaseRecords, info: TrainerContext = None):
tokenized = self.batch_tokenize(inputs, maxlen=self.max_length, mask=True)
# strange that some existing models on the huggingface don't use the token_type
with torch.set_grad_enabled(torch.is_grad_enabled()):
result = self.model(
tokenized.ids, token_type_ids=tokenized.token_type_ids.to(self.device), attention_mask=tokenized.mask.to(self.device)
).logits # Tensor[float] of length records size
if tokenized.token_type_ids is None:
result = self.model(
tokenized.ids, attention_mask=tokenized.mask.to(self.device)
).logits
else:
result = self.model(
tokenized.ids, token_type_ids=tokenized.token_type_ids.to(self.device), attention_mask=tokenized.mask.to(self.device)
).logits # Tensor[float] of length records size
return result

def distribute_models(self, update):
Expand Down
10 changes: 5 additions & 5 deletions src/xpmir/papers/helpers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def msmarco_v1_docpairs_sampler(
:param sample_rate: Sample rate for the triplets (default 1)
"""
topics = prepare_dataset("irds.msmarco-passage.train.queries")
train_triples = prepare_dataset("irds.msmarco-passage.train.docpairs")
triplets = ShuffledTrainingTripletsLines.C(
train_triples = prepare_dataset("irds.msmarco-passage.train.triples-v2.docpairs")
triplets = ShuffledTrainingTripletsLines(
seed=123,
data=StoreTrainingTripletTopicAdapter.C(data=train_triples, store=topics),
sample_rate=sample_rate,
Expand Down Expand Up @@ -91,7 +91,7 @@ def msmarco_v1_docpairs_efficient_sampler(
:param sample_rate: Sample rate for the triplets (default 1)
"""
topics = prepare_dataset("irds.msmarco-passage.train.queries")
train_triples = prepare_dataset("irds.msmarco-passage.train.docpairs")
train_triples = prepare_dataset("irds.msmarco-passage.train.triples-v2.docpairs")
triplets = ShuffledTrainingTripletsLines.C(
seed=seed,
data=StoreTrainingTripletTopicAdapter.C(data=train_triples, store=topics),
Expand All @@ -102,8 +102,8 @@ def msmarco_v1_docpairs_efficient_sampler(
).submit(launcher=launcher)

# Builds the sampler by hydrating documents
sampler = TripletBasedSampler(source=triplets)
hydrator = SampleHydrator(
sampler = TripletBasedSampler.C(source=triplets)
hydrator = SampleHydrator.C(
documentstore=prepare_collection("irds.msmarco-passage.documents")
)

Expand Down
17 changes: 3 additions & 14 deletions src/xpmir/text/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,6 @@ def static(self) -> bool:
"""
return True

def maxtokens(self) -> Optional[int]:
"""Maximum number of tokens that can be processed"""
return None


LegacyEncoderInput = Union[List[str], List[Tuple[str, str]], List[Tuple[str, str, str]]]

Expand All @@ -127,10 +123,6 @@ def dimension(self) -> int:
"""Returns the dimension of the output space"""
raise NotImplementedError()

def max_tokens(self):
"""Returns the maximum number of tokens this encoder can process"""
return sys.maxsize


class TextEncoder(TextEncoderBase[str, torch.Tensor]):
"""Encodes a text into a vector
Expand Down Expand Up @@ -217,11 +209,6 @@ class TokenizedEncoder(Encoder, Generic[EncoderOutput, TokenizerOutput]):
def forward(self, inputs: TokenizerOutput) -> EncoderOutput:
pass

@property
def max_length(self):
"""Returns the maximum length that the model can process"""
return sys.maxsize


class TokenizedTextEncoderBase(TextEncoderBase[InputType, EncoderOutput]):
@abstractmethod
Expand Down Expand Up @@ -255,6 +242,8 @@ def forward(
self, inputs: List[InputType], *args, options: Optional[TokenizerOptions] = None
) -> EncoderOutput:
assert len(args) == 0, "Unhandled extra arguments"
options = options or TokenizerOptions()
options.max_length = min(self.tokenizer.max_length, options.max_length if options.max_length else sys.maxsize)
tokenized = self.tokenizer.tokenize(inputs, options)
return self.forward_tokenized(tokenized, *args)

Expand All @@ -265,7 +254,7 @@ def tokenize(
self, inputs: List[InputType], options: Optional[TokenizerOptions] = None
):
options = options or TokenizerOptions()
options.max_length = min(self.encoder.max_length, options.max_length or None)
options.max_length = min(self.tokenizer.max_length, options.max_length if options.max_length else sys.maxsize)
return self.tokenizer.tokenize(inputs, options)

def static(self):
Expand Down
4 changes: 2 additions & 2 deletions src/xpmir/text/huggingface/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def from_pretrained_id(cls, model_id: str, **kwargs):
:param kwargs: keyword arguments passed to the model constructor
:return: A hugging-fasce based encoder
"""
return cls(model=HFModel.from_pretrained_id(model_id), **kwargs)
return cls.C(model=HFModel.from_pretrained_id(model_id), **kwargs)

def __initialize__(self, options):
super().__initialize__(options)
Expand All @@ -42,7 +42,7 @@ def dimension(self):
@property
def max_length(self):
    """Maximum number of tokens the underlying HF model can process

    Reads ``max_position_embeddings`` from the Hugging Face model
    configuration instead of returning an unbounded ``sys.maxsize``,
    so that tokenization can be truncated to what the model supports.
    """
    return self.model.hf_config.max_position_embeddings


class HFTokensEncoder(
Expand Down
13 changes: 8 additions & 5 deletions src/xpmir/text/huggingface/tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)

try:
from transformers import AutoTokenizer
from transformers import AutoTokenizer, AutoConfig
except Exception:
logging.error("Install huggingface transformers to use these configurations")
raise
Expand All @@ -33,7 +33,7 @@ class HFTokenizer(Config, Initializable):
"""The tokenizer hugginface ID"""

max_length: Param[int] = 4096
"""Maximum length for the tokenizer (can be overridden by the model)"""
"""Maximum length for the tokenizer (can be overridden at inference)"""

DEFAULT_OPTIONS = TokenizerOptions()

Expand All @@ -53,8 +53,11 @@ def __initialize__(self, options: ModuleInitOptions):
"Could not find saved tokenizer in %s, using HF loading", path
)

# Load config to read `max_position_embeddings` as proxy for `max_length`
self.config = AutoConfig.from_pretrained(model_id_or_path)

self.tokenizer = AutoTokenizer.from_pretrained(
model_id_or_path, model_max_length=self.max_length
model_id_or_path, model_max_length=min(self.max_length, self.config.max_position_embeddings)
)

self.cls = self.tokenizer.cls_token
Expand All @@ -69,7 +72,7 @@ def tokenize(
options = options or HFTokenizer.DEFAULT_OPTIONS
max_length = options.max_length
if max_length is None:
max_length = self.tokenizer.model_max_length
max_length = self.maxtokens()
else:
max_length = min(max_length, self.maxtokens())

Expand Down Expand Up @@ -130,7 +133,7 @@ def __initialize__(self, options: ModuleInitOptions):

@classmethod
def from_pretrained_id(cls, hf_id: str, **kwargs):
    """Create a tokenizer configuration from a Hugging Face model identifier

    :param hf_id: the Hugging Face tokenizer/model ID
    :param kwargs: extra parameters forwarded to the configuration constructor
    """
    # Use the experimaestro configuration constructor (``.C``) rather than
    # direct instantiation, consistently with the rest of the code base
    return cls.C(tokenizer=HFTokenizer.C(model_id=hf_id), **kwargs)

def vocabulary_size(self):
    """Return the number of tokens in the underlying tokenizer's vocabulary"""
    # Delegates to the wrapped HF tokenizer; NOTE(review): in transformers,
    # ``vocab_size`` excludes added tokens — confirm this is the intent
    return self.tokenizer.vocab_size
Expand Down
6 changes: 6 additions & 0 deletions src/xpmir/text/tokenizers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
import sys
from typing import List, NamedTuple, Optional, TypeVar, Generic
from attr import define
import re
Expand Down Expand Up @@ -198,3 +199,8 @@ def id2tok(self, idx: int) -> str:
Converts an integer id to a token
"""
...

@property
def max_length(self) -> int:
    """Returns the maximum number of tokens this tokenizer can process

    Defaults to an effectively unlimited value (``sys.maxsize``); concrete
    tokenizers are expected to override this with the real model limit.
    """
    return sys.maxsize
Loading