Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed bug causing few-shot to use more than specified number of shots
- Fixed bug in cached_transformer.get() that prevented using override_weights_file arg
- Fixed the `load_weights` arg in cached_transformers.get() which was documented but not implemented
- Fixed support for distributed training with data parallelism (validation still runs on a single process, without parallelism)

## [v0.1.0](https://github.com/allenai/catwalk/releases/tag/v0.1.0) - 2022-06-10

Expand Down
13 changes: 11 additions & 2 deletions catwalk/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@ class Model(Registrable, DetHashWithVersion, ABC):
def predict(self, task: Task, instances: Sequence[Dict[str, Any]], **kwargs) -> Iterator[Dict[str, Any]]:
raise NotImplementedError()

def calculate_metrics(self, task: Task, predictions: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
def calculate_metrics(
self,
task: Task,
predictions: Sequence[Dict[str, Any]],
*,
disable_torchmetrics_distributed_sync: bool = False,
**kwargs
) -> Dict[str, torch.Tensor]:
# Annoyingly, torchmetrics only supports tensors as input, not raw values. So we have to convert raw values
# into tensors.
def tensor_args(args: Tuple[Any]) -> Tuple[Any, ...]:
Expand All @@ -40,7 +47,9 @@ def unsqueeze_args(args: Tuple[Any]) -> Tuple[Any, ...]:
fixed_args.append(arg)
return tuple(fixed_args)

metrics = task.make_metrics()
metrics = task.make_metrics(
disable_torchmetrics_distributed_sync=disable_torchmetrics_distributed_sync
)
for prediction in Tqdm.tqdm(predictions, desc="Calculating metrics"):
for metric_name, metric_args in prediction.items():
try:
Expand Down
4 changes: 2 additions & 2 deletions catwalk/models/eleuther.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def _run_greedy_until(

return results

def calculate_metrics(self, task: Task, predictions: Sequence[Dict[str, Any]]) -> Dict[str, float]:
def calculate_metrics(self, task: Task, predictions: Sequence[Dict[str, Any]], **kwargs) -> Dict[str, float]:
assert isinstance(task, EleutherTask), "We can only calculate metrics for EleutherTasks."
return {
key: fn([p[key] for p in predictions])
Expand Down Expand Up @@ -431,7 +431,7 @@ def _run_greedy_until(
) -> Sequence:
raise NotImplementedError

def calculate_metrics(self, task: Task, predictions: Sequence[Dict[str, Any]]) -> Dict[str, float]:
def calculate_metrics(self, task: Task, predictions: Sequence[Dict[str, Any]], **kwargs) -> Dict[str, float]:
assert isinstance(task, EleutherTask), "We can only calculate metrics for EleutherTasks."
return {
key: fn([p[key] for p in predictions])
Expand Down
7 changes: 5 additions & 2 deletions catwalk/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
DataLoader, TrainingEngine, TrainConfig,
)
from tango.integrations.torch.model import Model as TangoModel
from tango.integrations.torch.util import resolve_device
import torch

from catwalk.task import Task
Expand Down Expand Up @@ -136,6 +137,8 @@ def run(
distributed_port: int = 54761,
train_split: str = "train",
validation_split: Optional[str] = "validation",
validate_every: int = 1000,
checkpoint_every: int = 1000
) -> Model: # type: ignore
if isinstance(model, str):
model = MODELS[model]
Expand Down Expand Up @@ -169,8 +172,8 @@ def run(
validation_steps=validation_steps,
train_split="train",
validation_split=None if validation_split is None else "validation",
validate_every=1000,
checkpoint_every=1000,
validate_every=validate_every,
checkpoint_every=checkpoint_every,
grad_accum=grad_accum,
is_distributed=is_distributed,
world_size=num_workers,
Expand Down
10 changes: 7 additions & 3 deletions catwalk/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class Task(Registrable, ABC):
def __init__(self, *, version_override: Optional[str] = None):
if version_override is not None:
self.VERSION = version_override
self.metrics: Dict[str, Callable[[], torchmetrics.Metric]] = {}
self.metrics: Dict[str, Callable[..., torchmetrics.Metric]] = {}
self.instance_conversions: Dict[InstanceFormat, InstanceConversion] = {}

def det_hash_object(self) -> Any:
Expand Down Expand Up @@ -106,9 +106,13 @@ def fewshot_instances_split(self) -> str:
return split_name
raise ValueError("This task has no split to take fewshot instances from.")

def make_metrics(self) -> Dict[str, torchmetrics.Metric]:
def make_metrics(
self,
*,
disable_torchmetrics_distributed_sync: bool = False
) -> Dict[str, torchmetrics.Metric]:
return {
name: metric_fn()
name: metric_fn(sync_on_compute= not disable_torchmetrics_distributed_sync)
for name, metric_fn in self.metrics.items()
}

Expand Down
13 changes: 12 additions & 1 deletion catwalk/train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import argparse

from tango.integrations.torch.util import resolve_device

from tango import Workspace
from tango.common.logging import initialize_logging

Expand All @@ -15,6 +17,7 @@ def main():
parser.add_argument("--task", type=str, nargs="+")
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--grad_acc", type=int, default=1)
parser.add_argument("--device_count", type=int, default=1)
parser.add_argument(
"-d",
"-w",
Expand All @@ -41,8 +44,16 @@ def main():
except KeyError:
tasks.add(task)

model_step = FinetuneStep(model=args.model, tasks=tasks, batch_size=args.batch_size, grad_accum=args.grad_acc)
model_step = FinetuneStep(
model=args.model,
tasks=tasks,
batch_size=args.batch_size,
grad_accum=args.grad_acc,
device_count=args.device_count
)

# Resolve to single device, because distributed evaluation is not supported
model_step = model_step.result().to(resolve_device())
metric_task_dict = {}
for task in tasks:
predictions = PredictStep(model=model_step, task=task, batch_size=args.batch_size)
Expand Down
7 changes: 5 additions & 2 deletions catwalk/training_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,19 @@ def __init__(
def post_val_loop(
self, step: int, epoch: int, val_metric: float, best_val_metric: float
) -> None:
if not self.is_local_main_process:
return

model_was_training = self.model.training
self.model.eval()
try:
catwalk_model = cast(Model, self.model)
catwalk_model = cast(Model, self.model.module) if hasattr(self.model, "module") else cast(Model, self.model)
for task in self.tasks:
instances = task.get_split(self.eval_split)
if self.eval_limit is not None:
instances = instances[:self.eval_limit]
predictions = catwalk_model.predict(task, instances)
metrics = catwalk_model.calculate_metrics(task, list(predictions))
metrics = catwalk_model.calculate_metrics(task, list(predictions), disable_torchmetrics_distributed_sync=True)
metrics_string = []
for metric_name, metric_value in metrics.items():
if len(metric_value.shape) > 0:
Expand Down