Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions catwalk/models/rank_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def __init__(
self.likelihood_averaging = likelihood_averaging
self.model_kwargs = model_kwargs

@classmethod
def _make_model(cls, pretrained_model_name_or_path: str, *, make_copy: bool = False, **kwargs) -> _Model:
def _make_model(self, *, make_copy: bool = False, **kwargs) -> _Model:
raise NotImplementedError

def predict( # type: ignore
Expand All @@ -61,12 +60,12 @@ def predict( # type: ignore
) -> Iterator[Dict[str, Any]]:
device = resolve_device()
try:
model = self._make_model(self.pretrained_model_name_or_path, **self.model_kwargs).to(device).eval()
model = self._make_model(**self.model_kwargs).to(device).eval()
except RuntimeError as e:
if not str(e).startswith('CUDA out of memory.'):
raise e
self.model_kwargs['device_map'] = "auto"
model = self._make_model(self.pretrained_model_name_or_path, **self.model_kwargs).eval()
model = self._make_model(**self.model_kwargs).eval()
tokenizer = cached_transformers.get_tokenizer(AutoTokenizer, self.pretrained_model_name_or_path)

for instance_chunk in more_itertools.chunked(instances, max_instances_in_memory):
Expand Down Expand Up @@ -133,7 +132,7 @@ def _run_loglikelihood(

def trainable_copy(self, **kwargs) -> TrainableModel:
return TrainableRankClassificationModel(
self._make_model(self.pretrained_model_name_or_path, make_copy=True, **self.model_kwargs),
self._make_model(make_copy=True, **self.model_kwargs),
cached_transformers.get_tokenizer(AutoTokenizer, self.pretrained_model_name_or_path),
self.predict_chunk)

Expand Down Expand Up @@ -224,9 +223,8 @@ def collate_for_training(self, instances: Sequence[Tuple[Task, Instance]]) -> An
class EncoderDecoderRCModel(RankClassificationModel):
VERSION = RankClassificationModel.VERSION + "002spt"

@classmethod
def _make_model(cls, pretrained_model_name_or_path: str, *, make_copy: bool = False, **kwargs) -> T5ForConditionalGeneration:
return cached_transformers.get(AutoModelForSeq2SeqLM, pretrained_model_name_or_path, make_copy=make_copy, **kwargs)
def _make_model(self, *, make_copy: bool = False, **kwargs) -> T5ForConditionalGeneration:
return cached_transformers.get(AutoModelForSeq2SeqLM, self.pretrained_model_name_or_path, make_copy=make_copy, **kwargs)

def _run_loglikelihood(
self,
Expand Down Expand Up @@ -289,9 +287,8 @@ def _run_loglikelihood(

@Model.register("rc::decoder_only")
class DecoderOnlyRCModel(RankClassificationModel):
@classmethod
def _make_model(cls, pretrained_model_name_or_path: str, *, make_copy: bool = False, **kwargs) -> GPT2LMHeadModel:
return cached_transformers.get(AutoModelForCausalLM, pretrained_model_name_or_path, make_copy=make_copy, **kwargs)
def _make_model(self, *, make_copy: bool = False, **kwargs) -> GPT2LMHeadModel:
return cached_transformers.get(AutoModelForCausalLM, self.pretrained_model_name_or_path, make_copy=make_copy, **kwargs)

@staticmethod
def _prefix_with_space(s: str) -> str:
Expand Down
10 changes: 7 additions & 3 deletions catwalk/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def massage_kwargs(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:

def run(
self,
model: Union[str, Model],
model: Union[str, Lazy[Model], Model],
task: Union[str, Task],
split: Optional[str] = None,
limit: Optional[int] = None,
Expand All @@ -60,6 +60,8 @@ def run(
) -> Sequence[Any]:
if isinstance(model, str):
model = MODELS[model]
elif isinstance(model, Lazy):
model = model.construct()
if isinstance(task, str):
task = TASKS[task]
if split is None:
Expand Down Expand Up @@ -89,12 +91,14 @@ def massage_kwargs(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:

def run(
self,
model: Union[str, Model],
model: Union[str, Lazy[Model], Model],
task: Union[str, Task],
predictions: Sequence[Any]
) -> Dict[str, float]:
if isinstance(model, str):
model = MODELS[model]
elif isinstance(model, Lazy):
model = model.construct()
if isinstance(task, str):
task = TASKS[task]

Expand Down Expand Up @@ -316,4 +320,4 @@ def run(self, metrics: Dict[str, Dict[str, float]], format: str = "text") -> Ite
raise NotImplementedError()
else:
raise AttributeError("At the moment, only the 'text' format is supported.")