39 changes: 39 additions & 0 deletions configs/_base_/datasets/mmlu_fs.py
@@ -0,0 +1,39 @@
from datasets import load_dataset
from mmchat.datasets import process_hf_dataset
from mmengine.dataset import DefaultSampler


data_root = 'data/mmlu/'

mmlu_fs_dataset = dict(
type=load_dataset,
path='json',
data_files=dict(
val=data_root + 'five_shot_mmlu_val.json',
test=data_root + 'five_shot_mmlu_test.json'))

val_mmlu_fs = dict(
type=process_hf_dataset,
dataset=mmlu_fs_dataset,
mode='val')
val_dataloader = dict(
batch_size=1,
num_workers=1,
dataset=val_mmlu_fs,
sampler=dict(type=DefaultSampler, shuffle=False))

test_mmlu_fs = dict(
type=process_hf_dataset,
dataset=mmlu_fs_dataset,
mode='test')
test_dataloader = dict(
batch_size=1,
num_workers=1,
dataset=test_mmlu_fs,
sampler=dict(type=DefaultSampler, shuffle=False))

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

val_evaluator = dict(type='MMLUMetric', tokenizer=None, prefix='mmlu_fs_val')
test_evaluator = dict(type='MMLUMetric', tokenizer=None, prefix='mmlu_fs_test')
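For orientation, the two dataset dicts above are lazy-build configs around HuggingFace `datasets.load_dataset`; building them amounts roughly to the following (a minimal sketch for illustration only, assuming the JSON splits already exist under `data/mmlu/`):

```python
from datasets import load_dataset

# Approximate hand-built equivalent of mmlu_fs_dataset above; the real
# config is built lazily by the framework at runtime.
data_root = 'data/mmlu/'
mmlu_fs = load_dataset(
    'json',
    data_files={
        'val': data_root + 'five_shot_mmlu_val.json',
        'test': data_root + 'five_shot_mmlu_test.json',
    })
print(mmlu_fs['val'][0])  # inspect one five-shot validation record
```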
39 changes: 39 additions & 0 deletions configs/_base_/datasets/mmlu_zs.py
@@ -0,0 +1,39 @@
from datasets import load_dataset
from mmchat.datasets import process_hf_dataset
from mmengine.dataset import DefaultSampler


data_root = 'data/mmlu/'

mmlu_zs_dataset = dict(
type=load_dataset,
path='json',
data_files=dict(
val=data_root + 'zero_shot_mmlu_val.json',
test=data_root + 'zero_shot_mmlu_test.json'))

val_mmlu_zs = dict(
type=process_hf_dataset,
dataset=mmlu_zs_dataset,
mode='val')
val_dataloader = dict(
batch_size=1,
num_workers=1,
dataset=val_mmlu_zs,
sampler=dict(type=DefaultSampler, shuffle=False))

test_mmlu_zs = dict(
type=process_hf_dataset,
dataset=mmlu_zs_dataset,
mode='test')
test_dataloader = dict(
batch_size=1,
num_workers=1,
dataset=test_mmlu_zs,
sampler=dict(type=DefaultSampler, shuffle=False))

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

val_evaluator = dict(type='MMLUMetric', tokenizer=None, prefix='mmlu_zs_val')
test_evaluator = dict(type='MMLUMetric', tokenizer=None, prefix='mmlu_zs_test')
18 changes: 15 additions & 3 deletions configs/guanaco/gunaco_llama_7B.py
@@ -7,6 +7,7 @@
import torch
with read_base():
from .._base_.datasets.oasst1 import *
from .._base_.datasets.mmlu_fs import *
from .._base_.schedules.guanaco import *
from .._base_.default_runtime import *

@@ -20,21 +21,21 @@
use_fast = False,
padding_side="right",
),
source_max_len = 16,
source_max_len = 2048,
target_max_len = 512,
train_on_source = False,
predict_with_generate = False,
),
llm = dict(
type=AutoModelForCausalLM.from_pretrained,
pretrained_model_name_or_path = '/nvme/share_data/llama-7b',
torch_dtype = torch.float32,
torch_dtype = torch.float16,
quantization_config=dict(
type = BitsAndBytesConfig,
load_in_4bit=True,
load_in_8bit=False,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch.float32,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type = 'nf4'
)
@@ -50,3 +51,14 @@

)

val_evaluator['tokenizer'] = dict(
type=AutoTokenizer.from_pretrained,
pretrained_model_name_or_path='/nvme/share_data/llama-7b',
use_fast=False,
padding_side="right")

test_evaluator['tokenizer'] = dict(
type=AutoTokenizer.from_pretrained,
pretrained_model_name_or_path='/nvme/share_data/llama-7b',
use_fast=False,
padding_side="right")
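For readers unfamiliar with the config syntax, the `llm` block above corresponds roughly to this direct `transformers` call (a sketch only; the checkpoint path is a placeholder for the local LLaMA-7B directory referenced in the config):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Rough hand-written equivalent of the llm config above: load LLaMA-7B with
# 4-bit NF4 quantization and fp16 compute (QLoRA-style setup).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    load_in_8bit=False,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
)
model = AutoModelForCausalLM.from_pretrained(
    'path/to/llama-7b',  # placeholder for the local checkpoint path
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
)
```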
7 changes: 4 additions & 3 deletions mmchat/datasets/huggingface.py
@@ -33,9 +33,10 @@ def _prompt_format(example):
dataset = dataset.rename_column(old, new)

# Remove unused columns.
dataset = dataset.remove_columns(
[col for col in dataset.column_names['train'] if col not in ['input', 'output']]
)
if 'train' in dataset.column_names:
dataset = dataset.remove_columns(
[col for col in dataset.column_names['train'] if col not in ['input', 'output']]
)
return dataset[mode]


1 change: 1 addition & 0 deletions mmchat/evaluation/__init__.py
@@ -0,0 +1 @@
from .metrics import *
3 changes: 3 additions & 0 deletions mmchat/evaluation/metrics/__init__.py
@@ -0,0 +1,3 @@
from .mmlu_metric import MMLUMetric

__all__ = ['MMLUMetric']
202 changes: 202 additions & 0 deletions mmchat/evaluation/metrics/mmlu_metric.py
@@ -0,0 +1,202 @@
from typing import Any, Sequence

import numpy as np
import torch
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger
from rich.console import Console
from rich.table import Table

from mmchat.registry import METRICS, TOKENIZER


@METRICS.register_module()
class MMLUMetric(BaseMetric):
METAINFO = {
'subcategories': {
"abstract_algebra": ["math"],
"anatomy": ["health"],
"astronomy": ["physics"],
"business_ethics": ["business"],
"clinical_knowledge": ["health"],
"college_biology": ["biology"],
"college_chemistry": ["chemistry"],
"college_computer_science": ["computer science"],
"college_mathematics": ["math"],
"college_medicine": ["health"],
"college_physics": ["physics"],
"computer_security": ["computer science"],
"conceptual_physics": ["physics"],
"econometrics": ["economics"],
"electrical_engineering": ["engineering"],
"elementary_mathematics": ["math"],
"formal_logic": ["philosophy"],
"global_facts": ["other"],
"high_school_biology": ["biology"],
"high_school_chemistry": ["chemistry"],
"high_school_computer_science": ["computer science"],
"high_school_european_history": ["history"],
"high_school_geography": ["geography"],
"high_school_government_and_politics": ["politics"],
"high_school_macroeconomics": ["economics"],
"high_school_mathematics": ["math"],
"high_school_microeconomics": ["economics"],
"high_school_physics": ["physics"],
"high_school_psychology": ["psychology"],
"high_school_statistics": ["math"],
"high_school_us_history": ["history"],
"high_school_world_history": ["history"],
"human_aging": ["health"],
"human_sexuality": ["culture"],
"international_law": ["law"],
"jurisprudence": ["law"],
"logical_fallacies": ["philosophy"],
"machine_learning": ["computer science"],
"management": ["business"],
"marketing": ["business"],
"medical_genetics": ["health"],
"miscellaneous": ["other"],
"moral_disputes": ["philosophy"],
"moral_scenarios": ["philosophy"],
"nutrition": ["health"],
"philosophy": ["philosophy"],
"prehistory": ["history"],
"professional_accounting": ["other"],
"professional_law": ["law"],
"professional_medicine": ["health"],
"professional_psychology": ["psychology"],
"public_relations": ["politics"],
"security_studies": ["politics"],
"sociology": ["culture"],
"us_foreign_policy": ["politics"],
"virology": ["health"],
"world_religions": ["philosophy"],
},
'categories': {
"STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering"],
"humanities": ["history", "philosophy", "law"],
"social sciences": ["politics", "culture", "economics", "geography", "psychology"],
"other (business, health, misc.)": ["other", "business", "health"],
},
}
METAINFO['subcategories_list'] = list(set([subcat for subcats in METAINFO['subcategories'].values()
for subcat in subcats]))

def __init__(self, tokenizer, *args, **kwargs):
super().__init__(*args, **kwargs)
self.logger: MMLogger = MMLogger.get_current_instance()
tokenizer = TOKENIZER.build(tokenizer)
self.abcd_idx = [
tokenizer("A", add_special_tokens=False).input_ids[0],
tokenizer("B", add_special_tokens=False).input_ids[0],
tokenizer("C", add_special_tokens=False).input_ids[0],
tokenizer("D", add_special_tokens=False).input_ids[0],
]

@staticmethod
def ABCD_to_0123(abcd):
return {'A': 0, 'B': 1, 'C': 2, 'D': 3}[abcd]

@staticmethod
def accuracy(preds, gts):
"""Computes the accuracy for preds and gts"""
correct = [1 if pred == gt else 0 for pred, gt in zip(preds, gts)]
acc = np.mean(correct) * 100
return acc

def process(self, data_batch: Any, data_samples: Sequence[dict]) -> None:
"""Process one batch of data samples and predictions. The processed
results should be stored in ``self.results``, which will be used to
compute the metrics when all batches have been processed.

Args:
data_batch (Any): A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from
the model.
"""
subjects = data_batch['subject']
gts = [self.ABCD_to_0123(gt) for gt in data_batch['output']]
preds = []
for sample, subject, gt in zip(data_samples, subjects, gts):
pred_logits = sample['logits']
labels = sample['labels']
labels_non_zero_id = (labels != -100).nonzero()[0][0]
pred_logits_abcd = pred_logits[labels_non_zero_id-1, self.abcd_idx]
pred = torch.argmax(pred_logits_abcd).item()
preds.append(pred)
self.results.append((subject, pred, gt))

def compute_metrics(self, results: list) -> dict:
"""Compute the metrics from processed results.

Args:
results (list): The processed results of each batch.

Returns:
dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
"""
subjects_results = {subject: {'preds': [], 'gts': []} for subject in self.METAINFO['subcategories'].keys()}
subcats_results = {subcat: {'preds': [], 'gts': []} for subcat in self.METAINFO['subcategories_list']}
cats_results = {cat: {'preds': [], 'gts': []} for cat in self.METAINFO['categories'].keys()}
for subject, pred, gt in results:
subjects_results[subject]['preds'].append(pred)
subjects_results[subject]['gts'].append(gt)
subcats = self.METAINFO['subcategories'][subject]
for subcat in subcats:
subcats_results[subcat]['preds'].append(pred)
subcats_results[subcat]['gts'].append(gt)
for cat, subcats in self.METAINFO['categories'].items():
for subcat in subcats:
if subcat in subcats_results:
cats_results[cat]['preds'].extend(subcats_results[subcat]['preds'])
cats_results[cat]['gts'].extend(subcats_results[subcat]['gts'])

subjects_metrics = dict()
subcats_metrics = dict()
cats_metrics = dict()
for subject in self.METAINFO['subcategories'].keys():
assert len(subjects_results[subject]['preds']) == len(subjects_results[subject]['gts'])
if len(subjects_results[subject]['preds']) == 0:
self.logger.info(f'Skip subject {subject} for mmlu')
else:
score = self.accuracy(subjects_results[subject]['preds'], subjects_results[subject]['gts'])
subjects_metrics[f'{subject}'] = score
for subcat in self.METAINFO['subcategories_list']:
assert len(subcats_results[subcat]['preds']) == len(subcats_results[subcat]['gts'])
if len(subcats_results[subcat]['preds']) == 0:
self.logger.info(f'Skip subcategory {subcat} for mmlu')
else:
score = self.accuracy(subcats_results[subcat]['preds'], subcats_results[subcat]['gts'])
subcats_metrics[f'{subcat}'] = score
for cat in self.METAINFO['categories'].keys():
assert len(cats_results[cat]['preds']) == len(cats_results[cat]['gts'])
if len(cats_results[cat]['preds']) == 0:
self.logger.info(f'Skip category {cat} for mmlu')
else:
score = self.accuracy(cats_results[cat]['preds'], cats_results[cat]['gts'])
cats_metrics[f'{cat}'] = score

metrics = dict()
metrics.update(subjects_metrics)
metrics.update(subcats_metrics)
metrics.update(cats_metrics)
metrics['average'] = np.mean(list(subjects_metrics.values()))

table_metrics = dict()
table_metrics.update(cats_metrics)
table_metrics['average'] = np.mean(list(subjects_metrics.values()))
self._print_results(table_metrics)
return metrics

def _print_results(self, table_metrics: dict) -> None:
table_title = ' MMLU Benchmark '
table = Table(title=table_title)
console = Console()
table.add_column('Categories', justify='left')
table.add_column('Accuracy (%)', justify='right')
for cat, acc in table_metrics.items():
table.add_row(cat, '{:.1f}'.format(acc))
with console.capture() as capture:
console.print(table, end='')
self.logger.info('\n' + capture.get())
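The key step in `process` is that the answer letter occupies the first non-masked label position, so the metric reads the logits at the position immediately before it and compares only the token ids for "A" through "D". A toy reproduction of that step with made-up tensors (the token ids and shapes here are hypothetical, not taken from the PR):

```python
import torch

# Hypothetical token ids and shapes; for illustration only.
abcd_idx = [319, 350, 315, 360]        # pretend ids of "A", "B", "C", "D"
logits = torch.randn(10, 32000)        # (seq_len, vocab_size) for one sample
labels = torch.full((10,), -100)
labels[6] = abcd_idx[2]                # ground-truth answer "C" at position 6

first_label_pos = (labels != -100).nonzero()[0][0]
# The logits that predict the answer token sit one position earlier.
pred_logits_abcd = logits[first_label_pos - 1, abcd_idx]
pred = torch.argmax(pred_logits_abcd).item()   # 0..3 -> A..D
```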
Empty file removed mmchat/evaluation/mmlu.py
19 changes: 9 additions & 10 deletions mmchat/models/algorithms/sft.py
@@ -57,11 +57,12 @@ def __init__(self, llm, data_preprocessor):
self.llm = self._build_from_cfg_or_module(llm, LLM)
self.llm.config.use_cache = False
self.llm.config.torch_dtype = torch.float32
smart_tokenizer_and_embedding_resize(
special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
tokenizer=self.tokenizer,
model=self.llm,
)
if self.tokenizer._pad_token is None:
smart_tokenizer_and_embedding_resize(
special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
tokenizer=self.tokenizer,
model=self.llm,
)
from transformers.models.llama import LlamaTokenizer

if isinstance(self.tokenizer, LlamaTokenizer):
@@ -110,14 +111,12 @@ def _forward(self, data, data_samples=None):
return outputs

def predict(self, data, data_samples=None):

outputs = self.llm(**data)

return outputs

logits_dict = [{'labels': labels, 'logits': logits} \
for labels, logits in zip(data['labels'], outputs.logits)]
return logits_dict

def compute_loss(self, data, data_samples=None):

outputs = self.llm(**data)
# import pdb;pdb.set_trace()
loss_dict = {'loss_llm': outputs.loss}
2 changes: 1 addition & 1 deletion mmchat/models/utils/data_processor.py
@@ -70,5 +70,5 @@ def forward(self,instances: Sequence[Dict], training=True) -> Dict[str, torch.Te
if labels is not None:
data_dict['labels'] = labels

return {'data': data_dict, 'data_samples': None}
return self.cast_data({'data': data_dict, 'data_samples': None})
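Wrapping the return value in `cast_data` lets the preprocessor move the collated tensors onto its device before they reach the model. A minimal sketch of the MMEngine behaviour being relied on here (standalone example, not project code):

```python
import torch
from mmengine.model import BaseDataPreprocessor

# cast_data recursively walks dicts/lists and moves tensors to the
# preprocessor's device; non-tensor values such as None pass through unchanged.
processor = BaseDataPreprocessor()
batch = {'data': {'input_ids': torch.zeros(2, 8, dtype=torch.long)},
         'data_samples': None}
moved = processor.cast_data(batch)
print(moved['data']['input_ids'].device)  # cpu here; cuda after processor.to('cuda')
```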
