|
8 | 8 | from dataclasses import dataclass
|
9 | 9 | from typing import Dict, Iterator, List, Optional, Sequence, Union
|
10 | 10 |
|
| 11 | +import jsonlines |
11 | 12 | import torch
|
12 | 13 | import torch.nn.functional as F
|
13 | 14 | from coati.dataset.utils import chuncate_sequence, pad_to_max_len
|
@@ -345,3 +346,77 @@ def __len__(self) -> int:
|
345 | 346 |
|
def set_start_index(self, start_index: int) -> None:
    """Set the index from which subsequent dataset reads start.

    NOTE(review): presumably used to resume iteration mid-epoch —
    confirm against the caller that manages `start_index`.
    """
    self.start_index = start_index
|
| 349 | + |
| 350 | + |
def apply_chat_template_and_mask(
    tokenizer: PreTrainedTokenizer,
    chat: List[Dict[str, str]],
    max_length: Optional[int] = None,
    padding: bool = True,
    truncation: bool = True,
    ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]:
    """Tokenize a multi-turn conversation and build loss-masked SFT targets.

    Every message is rendered through the tokenizer's chat template on its
    own; the resulting ids are concatenated, then labels are set to
    ``ignore_idx`` everywhere except on tokens belonging to ``assistant``
    turns, so the loss is computed only over assistant replies.

    Args:
        tokenizer: Object exposing ``apply_chat_template``, ``bos_token_id``,
            ``pad_token_id`` and ``padding_side``.
        chat: Conversation as a list of ``{"role": ..., "content": ...}`` dicts.
        max_length: If given, pad and/or truncate the sequence to this length.
        padding: Pad with ``pad_token_id`` (honoring ``padding_side``) when
            the sequence is shorter than ``max_length``.
        truncation: Cut the sequence down when longer than ``max_length``.
        ignore_idx: Label value ignored by the loss function.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors
        (all ``torch.long``, same length).
    """
    token_ids: List[int] = []
    is_assistant: List[bool] = []
    for turn_idx, message in enumerate(chat):
        turn_ids = tokenizer.apply_chat_template([message], tokenize=True)
        # Templates may prepend a BOS token to every rendered turn; keep it
        # only on the very first message.
        if turn_idx > 0 and turn_ids[0] == tokenizer.bos_token_id:
            turn_ids = turn_ids[1:]
        token_ids += turn_ids
        is_assistant += [message["role"] == "assistant"] * len(turn_ids)
    attn = [1] * len(token_ids)
    if max_length is not None:
        shortfall = max_length - len(token_ids)
        if padding and shortfall > 0:
            pad_ids = [tokenizer.pad_token_id] * shortfall
            pad_flags = [False] * shortfall
            pad_attn = [0] * shortfall
            if tokenizer.padding_side == "right":
                token_ids += pad_ids
                is_assistant += pad_flags
                attn += pad_attn
            else:
                token_ids = pad_ids + token_ids
                is_assistant = pad_flags + is_assistant
                attn = pad_attn + attn
        if truncation and len(token_ids) > max_length:
            token_ids = token_ids[:max_length]
            is_assistant = is_assistant[:max_length]
            attn = attn[:max_length]
    input_ids = torch.tensor(token_ids, dtype=torch.long)
    labels = input_ids.clone()
    # Mask out everything that is not an assistant token.
    labels[~torch.tensor(is_assistant, dtype=torch.bool)] = ignore_idx

    return {
        "input_ids": input_ids,
        "attention_mask": torch.tensor(attn, dtype=torch.long),
        "labels": labels,
    }
| 397 | + |
| 398 | + |
class RawConversationDataset(Dataset):
    """Lazily-tokenized dataset over a JSON-lines file of conversations.

    Each line of ``input_file`` is one conversation, forwarded unchanged to
    ``apply_chat_template_and_mask`` — i.e. a list of
    ``{"role": ..., "content": ...}`` message dicts (TODO confirm the file
    schema against the data-preparation script). Tokenization happens on
    first access and the result is cached per index.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int) -> None:
        """Read all raw conversations eagerly; tokenize them lazily.

        Args:
            tokenizer: Tokenizer passed through to ``apply_chat_template_and_mask``.
            input_file: Path to a JSON-lines file, one conversation per line.
            max_length: Pad/truncate target length for each tokenized sample.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        with jsonlines.open(input_file) as reader:
            self.raw_texts = list(reader)
        # One cache slot per conversation, filled on first __getitem__.
        self.tokenized_texts = [None] * len(self.raw_texts)

    def __len__(self) -> int:
        return len(self.raw_texts)

    def __getitem__(self, index: int):
        """Return the tokenized sample at ``index``, tokenizing once on demand."""
        cached = self.tokenized_texts[index]
        if cached is None:
            encoded = apply_chat_template_and_mask(self.tokenizer, self.raw_texts[index], self.max_length)
            cached = dict(encoded)
            self.tokenized_texts[index] = cached
        return cached
0 commit comments