Changes from 11 commits
6 changes: 3 additions & 3 deletions bsmetadata/deepspeed_configs/v2.json
@@ -37,12 +37,12 @@
"reduce_scatter": true,
"reduce_bucket_size": 500000000,
"contiguous_gradients": true,
"cpu_offload": false
"cpu_offload": true
},
"gradient_accumulation_steps": 1,
"gradient_accumulation_steps": 16,
"gradient_clipping": "auto",
"steps_per_print": 100,
"train_batch_size": 256,
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
}
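For context on how these two changes interact: DeepSpeed requires `train_batch_size = micro_batch_per_gpu × gradient_accumulation_steps × world_size`, so with accumulation raised from 1 to 16 the `"auto"` micro-batch size must shrink accordingly. A quick sanity check of that identity (the world size of 16 GPUs is an assumption, not stated in this PR):

```python
# DeepSpeed batch-size identity implied by this config (a sketch).
train_batch_size = 256            # from the config above
gradient_accumulation_steps = 16  # raised from 1 in this PR
world_size = 16                   # hypothetical GPU count

micro_batch_per_gpu = train_batch_size // (gradient_accumulation_steps * world_size)
assert micro_batch_per_gpu * gradient_accumulation_steps * world_size == train_batch_size
print(micro_batch_per_gpu)  # 1 -- what "auto" should resolve to here
```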
221 changes: 205 additions & 16 deletions bsmetadata/evaluation.py
@@ -1,29 +1,50 @@
import argparse
import functools
import itertools
import json
from typing import Dict

import rich
import torch
import torch.nn.functional as F
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from rich.text import Text
from tqdm.auto import tqdm

from bsmetadata.metadata_utils import add_metadata_and_chunk_examples
from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

from bsmetadata.metadata_utils import add_metadata_and_chunk_examples

def format_by_one_mask(input_ids, mask, tokenizer):
i = 0
data = []
for key, igroup in itertools.groupby(mask):
size = len(list(igroup))
text = tokenizer.decode(input_ids[i : i + size])
i += size
data.append((text, "green" if key else None))
return Text.assemble(*data)
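A minimal usage sketch of this helper (the sentence and mask here are made up for illustration):

```python
import rich
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Title: Hello world")["input_ids"]
mask = [1, 1] + [0] * (len(input_ids) - 2)  # pretend the first two tokens are metadata

# Metadata tokens render in green, everything else in the default color.
rich.print(format_by_one_mask(input_ids, mask, tokenizer))
```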


@torch.no_grad()
def ppl_fn(
batch: Dict[str, torch.Tensor], outputs: CausalLMOutputWithCrossAttentions, metadata_mask: torch.Tensor = None
batch: Dict[str, torch.Tensor],
outputs: CausalLMOutputWithCrossAttentions,
metadata_mask: torch.Tensor = None,
save_data: bool = False,
idx: int = None,
) -> torch.Tensor:
"""Calculates the perplexity for a given batch.

Args:
batch: A dict with keys "input_ids" and "attention_mask".
outputs: The model outputs for the batch.
metadata_mask: 1 for tokens corresponding to metadata and 0 for all other tokens.
save_data: Whether to save the tokens and per-token losses to disk.
idx: The index of the batch.

Returns:
The perplexity of the given batch.
@@ -35,21 +56,88 @@ def ppl_fn(

shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

if metadata_mask is not None:
loss_mask = torch.logical_and(attention_mask, ~metadata_mask)
metadata_mask = metadata_mask.bool()
nonmetadata_cumsum = torch.cumsum(~metadata_mask, dim=-1)
first_nonmetadata = nonmetadata_cumsum == 1
rich.print(f"{~(metadata_mask.bool())=}")
rich.print("(attention_mask.bool())")
rich.print(attention_mask.bool())
loss_mask = torch.logical_and(attention_mask.bool(), ~(metadata_mask.bool()))
loss_mask = torch.logical_and(loss_mask, ~first_nonmetadata)
rich.print(f"{loss_mask=}")
else:
loss_mask = attention_mask

loss_mask = attention_mask.bool()
shift_mask = loss_mask[..., 1:].contiguous()

"""

max len: 10
(label, by convention, is unshifted)
label: a b c d e f g x x x
input: a b c d e f g x x x
mask : 1 1 1 1 1 1 1 0 0 0

shift label : b c d e f g x x x
shift logit : a b c d e f g x x
shift a mask: 1 1 1 1 1 1 0 0 0


calculated part
input: a b c d e f
label: b c d e f g

metdata example:
label : M M a b c d e f g x
input : M M a b c d e f g x
a mask: 1 1 1 1 1 1 1 1 1 0
m mask: 1 1 0 0 0 0 0 0 0 0
a & !m: 0 0 1 1 1 1 1 1 1 0

shift label : M a b c d e f g x
shift logit : M M a b c d e f g
shift a mask: 1 1 1 1 1 1 1 1 0
shift (a&!m): 0 1 1 1 1 1 1 1 0
diff (bug) : x

# fix: mask out the loss if ((the source token is metadata) or (the target token is padding))
#

shift m mask:
ideal mask :
"""

# if metadata_mask is not None:
# shift_metadata_mask = metadata_mask[..., 1:].contiguous().bool()
# shift_mask = torch.logical_and(shift_mask, ~shift_metadata_mask)
rich.print(f"shift_mask{shift_mask}")
rich.print(f"{shift_mask.sum()=}")

# Flatten the tokens
loss = F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
reduction="none",
).view(b, -1)

if save_data:
# Save the batch's token ids and per-token losses
suffix = "_meta" if metadata_mask is not None else ""
torch.save(
batch["input_ids"],
f"{idx}_input_ids{suffix}.pt",
)
torch.save(
loss.cpu().squeeze(),
f"{idx}_loss{suffix}.pt",
)

loss = loss.cpu().squeeze().numpy().tolist()
shift_mask = shift_mask.cpu().squeeze().numpy().tolist()

return loss, shift_mask, shift_labels.cpu().squeeze().numpy().tolist()
return loss, shift_mask

# Normalize to avoid an overflow when there are many tokens
normed_loss_weights = shift_mask / shift_mask.sum()
loss = (loss * normed_loss_weights).sum()
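The weighting above is, in exact arithmetic, just the mean loss over unmasked positions, computed so that the weights are normalized before the sum rather than after; perplexity is then presumably the exponential of that mean (the final lines of `ppl_fn` are collapsed in this view). A sketch of the equivalence:

```python
import torch

loss = torch.rand(1, 1023)        # per-token losses, as returned by cross_entropy
shift_mask = torch.zeros(1, 1023)
shift_mask[:, :900] = 1.0         # pretend the first 900 positions count

normed = (loss * (shift_mask / shift_mask.sum())).sum()  # normalize first
naive = (loss * shift_mask).sum() / shift_mask.sum()     # divide last
assert torch.allclose(normed, naive)
ppl = torch.exp(normed)           # perplexity = exp(mean NLL)
```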
@@ -61,23 +149,28 @@ def ppl_fn(


@torch.no_grad()
def get_ppl(batch: Dict[str, torch.Tensor]) -> torch.Tensor:
def get_ppl(
batch: Dict[str, torch.Tensor],
save_data: bool = False,
idx: int = None,
) -> torch.Tensor:
"""Prepares the arguments for perplexity calculation and passes them to the perplexity function.

Args:
batch: A dict with keys "input_ids", "attention_mask" and "metadata_mask", where:
- the input ids are a list of token ids corresponding to the input text with metadata;
- the attention mask is 0 for padding tokens and 1 everywhere else;
- the metadata mask is 1 for tokens corresponding to metadata and 0 for all other tokens.

save_data: Whether to save the tokens and per-token losses to disk.
idx: The index of the batch, used in the saved file names.
Returns:
The perplexity of the given batch.
"""
labels = batch.pop("labels")
metadata_mask = batch.pop("metadata_mask", None)
outputs = model(**batch)
batch["labels"] = labels
ppl = ppl_fn(batch, outputs, metadata_mask)
ppl = ppl_fn(batch, outputs, metadata_mask, save_data=save_data, idx=idx)
return ppl
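When `--save_data` is set, each batch writes two tensors next to the script; a sketch of loading them afterwards (the file names follow the f-strings in `ppl_fn`, with the `_meta` suffix marking metadata batches):

```python
import torch

# Hypothetical post-hoc inspection of batch 0 from a metadata run:
input_ids = torch.load("0_input_ids_meta.pt")
per_token_loss = torch.load("0_loss_meta.pt")
print(input_ids.shape, per_token_loss.shape)
```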


@@ -104,18 +197,50 @@ def get_ppl(batch: Dict[str, torch.Tensor]) -> torch.Tensor:
action="store_true",
help="If set to true, the script runs in test mode and only takes 10 examples per dataset",
)
parser.add_argument(
"--save_data",
action="store_true",
help="If set to true, save tokens & losses",
)
parser.add_argument(
"--local",
action="store_true",
help="If set to true, the script runs in test mode and only takes 10 examples per dataset",
)
parser.add_argument(
"--metadata_to_test",
type=str,
default="html,entity,entity_paragraph,website_desc,generation_datasource,timestamp,title,generation_length_sentence,generation_length_text,url,paragraph",
help="metadata types to test",
)
parser.add_argument(
"--untrained",
action="store_true",
help="If set to true, will load gpt2-xl",
)

args = parser.parse_args()
print(f"Parameters: {args}")

# Load config
config_file_path = hf_hub_download(repo_id=args.repo_id, filename="actual_config.yaml", use_auth_token=True)
if args.local:
import os

config_file_path = os.path.join(args.repo_id, "actual_config.yaml")
else:
config_file_path = hf_hub_download(repo_id=args.repo_id, filename="actual_config.yaml", use_auth_token=True)
repo_args = OmegaConf.load(config_file_path)
data_config = repo_args.data_config

# make sure loss (ppl) masking is on for local metadata
data_config.metadata_config.treat_local_metadata_as_regular_text = False

# Load model
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(args.repo_id, subfolder=args.subfolder, use_auth_token=True)
if args.untrained:
model = AutoModelForCausalLM.from_pretrained("gpt2-xl")
else:
model = AutoModelForCausalLM.from_pretrained(args.repo_id, subfolder=args.subfolder, use_auth_token=True)
model.eval().cuda() if not args.no_cuda else model.eval()

# Load tokenizer
@@ -130,6 +255,7 @@ def get_ppl(batch: Dict[str, torch.Tensor]) -> torch.Tensor:
preprocess_fn = functools.partial(add_metadata_and_chunk_examples, tokenizer=tokenizer, cfg=cfg)

# Validation datasets

dataset_paths = [
"bs-modeling-metadata/c4-en-html-with-validation_metadata_html",
"bs-modeling-metadata/c4-en-html-with-validation_metadata_entity",
@@ -143,6 +269,7 @@ def get_ppl(batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"bs-modeling-metadata/c4-en-html-with-validation_metadata_url",
"bs-modeling-metadata/c4-en-html-with-validation_metadata_paragraph",
]
dataset_paths = [path for path in dataset_paths if path.split("_metadata_")[1] in args.metadata_to_test.split(",")]
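Each repo is keyed by the substring after `"_metadata_"`, so the comma-separated `--metadata_to_test` list selects validation sets by that suffix; a quick check of the keying:

```python
path = "bs-modeling-metadata/c4-en-html-with-validation_metadata_html"
assert path.split("_metadata_")[1] == "html"  # matches the "html" entry in --metadata_to_test
```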

for path in dataset_paths:
n_examples = 0
@@ -158,7 +285,7 @@ def get_ppl(batch: Dict[str, torch.Tensor]) -> torch.Tensor:
split = "validation" if not args.test else "validation[:10]"
validation_dataset = load_dataset(path, use_auth_token=True, split=split)

for example in tqdm(validation_dataset, desc=f"Calculating perplexity for {metadata_type}..."):
for idx, example in tqdm(enumerate(validation_dataset), desc=f"Calculating perplexity for {metadata_type}..."):
# Preprocess examples
examples = {k: [v] for k, v in example.items()}
try:
@@ -176,6 +303,10 @@ def get_ppl(batch: Dict[str, torch.Tensor]) -> torch.Tensor:
normal_example = tokenizer(examples["text"][0])
normal_example_len = len(normal_example["input_ids"])
metadata_example = {k: v[0] for k, v in processed_examples.items()}
# rich.print(f"{metadata_example['attention_mask']=}")
# rich.print(f"{normal_example['attention_mask']=}")
# import sys
# sys.exit()
metadata_example_len = len(metadata_example["input_ids"])
min_seq_len = min(normal_example_len, metadata_example_len)
max_seq_len = max(normal_example_len, metadata_example_len)
@@ -197,12 +328,70 @@ def get_ppl(batch: Dict[str, torch.Tensor]) -> torch.Tensor:
if not args.no_cuda:
normal_batch = {k: v.cuda() for k, v in normal_batch.items()}
metadata_batch = {k: v.cuda() for k, v in metadata_batch.items()}
if n_examples == 1:
ex = format_by_one_mask(normal_batch["input_ids"][0], normal_batch["attention_mask"][0], tokenizer)
rich.print(f"Normal example:")
rich.print(ex)

ex = format_by_one_mask(
metadata_batch["input_ids"][0], metadata_batch["metadata_mask"][0], tokenizer
)
rich.print(f"Metadata example:")
rich.print(ex)
rich.print(tokenizer.decode(metadata_batch["input_ids"][0]))

# Calculate ppl
normal_ppl = get_ppl(normal_batch)
total_normal_ppl += float(normal_ppl) * normal_example_len
metadata_ppl = get_ppl(metadata_batch)
total_metadata_ppl += float(metadata_ppl) * metadata_example_len
normal_ppl = get_ppl(normal_batch, save_data=args.save_data, idx=idx)
# total_normal_ppl += float(normal_ppl) * normal_example_len
metadata_ppl = get_ppl(metadata_batch, save_data=args.save_data, idx=idx)
# total_metadata_ppl += float(metadata_ppl) * metadata_example_len
Contributor: Why did you comment it out? Same above

Collaborator (Author): Hmmm, I don't remember. I think I just wanted to see the ppl of the first example, but I don't know why I commented it out.

if n_examples == 1:
loss, mask, shift_labels = normal_ppl
print("normal ppl")
printed = 0
for i, (l, m, sl) in enumerate(zip(loss, mask, shift_labels)):
if m:
if printed < 10:
rich.print(f"Loss {json.dumps(tokenizer.decode(sl))}: {l}")
printed += 1

unmasked_labels = [label for label, m in zip(shift_labels, mask) if m]
# print(f"first 10 unmasked labels: {[tokenizer.decode(x) for x in unmasked_labels[:10]]}")
print(f"first 10 unmasked labels: {tokenizer.decode(unmasked_labels[:10])}")
# ex = format_by_one_mask(normal_batch["input_ids"][0], mask, tokenizer)
# rich.print(ex)

loss, mask, shift_labels = metadata_ppl
printed = 0
print("metadata ppl")
for i, (l, m, sl) in enumerate(zip(loss, mask, shift_labels)):
if m:
if printed < 10:
rich.print(f"Loss {json.dumps(tokenizer.decode(sl))}: {l}")
printed += 1

unmasked_labels = [label for label, m in zip(shift_labels, mask) if m]
print(f"first 10 unmasked labels: {tokenizer.decode(unmasked_labels[:10])}")
# ex = format_by_one_mask(metadata_batch["input_ids"][0], mask, tokenizer)
# rich.print(ex)

# ex = format_by_one_mask(normal_batch["input_ids"][0], normal_batch["attention_mask"][0], tokenizer)
# rich.print(ex)
# rich.print(f"Normal example: (ppl={normal_ppl[0]})")

# ex = format_by_one_mask(
# metadata_batch["input_ids"][0], metadata_batch["metadata_mask"][0], tokenizer
# )
# rich.print(ex)
# rich.print(f"Metadata example: (ppl={metadata_ppl[0]})")
# rich.print(f"Normal example: (mask={normal_ppl[1]})")
# rich.print(f"Metadata example: (mask={metadata_ppl[1]})")
import sys

sys.exit()

if n_examples > 1000:
break

if exit_flag:
continue
9 changes: 0 additions & 9 deletions bsmetadata/experiments/datasetv2.py
@@ -265,15 +265,6 @@
"c4-en-html_cc-main-2019-18_pq01-009.jsonl.gz",
"c4-en-html_cc-main-2019-18_pq01-010.jsonl.gz",
"c4-en-html_cc-main-2019-18_pq01-011.jsonl.gz",
"c4-en-html_cc-main-2019-18_pq01-012.jsonl.gz",
"c4-en-html_cc-main-2019-18_pq01-013.jsonl.gz",
"c4-en-html_cc-main-2019-18_pq01-014.jsonl.gz",
"c4-en-html_cc-main-2019-18_pq01-016.jsonl.gz",
"c4-en-html_cc-main-2019-18_pq01-017.jsonl.gz",
"c4-en-html_cc-main-2019-18_pq01-018.jsonl.gz",
"c4-en-html_cc-main-2019-18_pq01-019.jsonl.gz",
"c4-en-html_cc-main-2019-18_pq01-020.jsonl.gz",
"c4-en-html_cc-main-2019-18_pq01-021.jsonl.gz",
]

features = {