|
7 | 7 | import abc
|
8 | 8 | import collections
|
9 | 9 | import importlib
|
| 10 | +from typing import TypeVar |
10 | 11 |
|
11 | 12 | import numpy as np
|
12 | 13 | import torch
|
13 |
| -from tensordict import TensorDict |
| 14 | +from tensordict import TensorClass, TensorDict |
14 | 15 | from torch import nn, Tensor
|
15 | 16 | from torch.nn import functional as F
|
16 | 17 |
|
@@ -541,3 +542,88 @@ def step_scheduler(self):
|
541 | 542 | # remove all values
|
542 | 543 | while len(self._kl_queue):
|
543 | 544 | self._kl_queue.remove(self._kl_queue[0])
|
| 545 | + |
| 546 | +LLMInpOut = TypeVar("LLMInpOut") |
| 547 | + |
| 548 | +class LLMInput(TensorClass["nocast"]): |
| 549 | + """Represents the input to a Large Language Model (LLM). |
| 550 | +
|
| 551 | + Attributes: |
| 552 | + tokens (torch.Tensor): The input tokens as a tensor. |
| 553 | + attention_mask (torch.Tensor, optional): The attention mask for the input tokens. Default to `None`. |
| 554 | + token_list (list[int] | list[list[int]], optional): The input tokens as a list of integers or a list of lists of integers. Default to `None`. |
| 555 | + text (str | list[str], optional): The input text as a string or a list of strings. Default to `None`. |
| 556 | +
|
| 557 | + .. seealso:: :class:`~torchrl.data.LLMOutput` and :class:`~torchrl.data.LLMData`. |
| 558 | +
|
| 559 | + """ |
| 560 | + tokens: torch.Tensor |
| 561 | + attention_mask: torch.Tensor | None = None |
| 562 | + token_list: list[int] | list[list[int]] | None = None |
| 563 | + text: str | list[str] | None = None |
| 564 | + |
| 565 | +class LLMOutput(TensorClass["nocast"]): |
| 566 | + """Represents the output from a Large Language Model (LLM). |
| 567 | +
|
| 568 | + Attributes: |
| 569 | + tokens (torch.Tensor): The output tokens as a tensor. |
| 570 | + tokens_response (torch.Tensor, optional): The response tokens generated by the model. Default to `None`. |
| 571 | +
|
| 572 | + .. note:: the reponse is the sequence of tokens output by a model, excluding the input |
| 573 | + tokens. |
| 574 | +
|
| 575 | + token_list (list[int] | list[list[int]], optional): The output tokens as a list of integers or a list of lists of integers. Default to `None`. |
| 576 | + tokens_response_list (list[list[int]], optional): The response tokens generated by the model as a list of lists of integers. Default to `None`. |
| 577 | + logits (torch.Tensor, optional): The logits of the output tokens. Default to `None`. |
| 578 | + log_probs (torch.Tensor, optional): The log probabilities of the output tokens. Default to `None`. |
| 579 | + text (str | list[str], optional): The output text as a string or a list of strings. Default to `None`. |
| 580 | +
|
| 581 | + .. seealso:: :class:`~torchrl.data.LLMInput` and :class:`~torchrl.data.LLMData`. |
| 582 | +
|
| 583 | + """ |
| 584 | + tokens: torch.Tensor |
| 585 | + tokens_response: torch.Tensor | None = None |
| 586 | + token_list: list[int] | list[list[int]] | None = None |
| 587 | + tokens_response_list: list[list[int]] | None = None |
| 588 | + logits: torch.Tensor | None = None |
| 589 | + log_probs: torch.Tensor | None = None |
| 590 | + text: str | list[str] | None = None |
| 591 | + |
| 592 | + @classmethod |
| 593 | + def from_vllm_output(cls: type[LLMInpOut], vllm_output) -> LLMInpOut: |
| 594 | + # placeholder |
| 595 | + raise NotImplementedError |
| 596 | + |
| 597 | +class LLMData(TensorClass["nocast"]): |
| 598 | + """Represents the input or output of a Large Language Model (LLM). |
| 599 | +
|
| 600 | + Other algorithm-specific attributes such as `reward`, `advantages` or done states are handled automatically by the |
| 601 | + envs and, therefore, are not included in this class. |
| 602 | +
|
| 603 | + Attributes: |
| 604 | + tokens (torch.Tensor): The input/output tokens as a tensor. |
| 605 | + attention_mask (torch.Tensor, optional): The attention mask for the input tokens. Default to `None`. |
| 606 | + tokens_response (torch.Tensor, optional): The response tokens generated by the model. Default to `None`. |
| 607 | +
|
| 608 | + .. note:: the reponse is the sequence of tokens output by a model, excluding the input |
| 609 | + tokens. |
| 610 | +
|
| 611 | + token_list (list[int] | list[list[int]], optional): The output tokens as a list of integers or a list of lists |
| 612 | + of integers. Default to `None`. |
| 613 | + tokens_response_list (list[list[int]], optional): The response tokens generated by the model as a list of |
| 614 | + lists of integers. Default to `None`. |
| 615 | + logits (torch.Tensor, optional): The logits of the output tokens. Default to `None`. |
| 616 | + log_probs (torch.Tensor, optional): The log probabilities of the output tokens. Default to `None`. |
| 617 | + text (str | list[str], optional): The output text as a string or a list of strings. Default to `None`. |
| 618 | +
|
| 619 | + .. seealso:: :class:`~torchrl.data.LLMInput` and :class:`~torchrl.data.LLMOutput`. |
| 620 | +
|
| 621 | + """ |
| 622 | + tokens: torch.Tensor |
| 623 | + tokens_response: torch.Tensor | None = None |
| 624 | + attention_mask: torch.Tensor | None = None |
| 625 | + token_list: list[int] | list[list[int]] | None = None |
| 626 | + tokens_response_list: list[list[int]] | None = None |
| 627 | + logits: torch.Tensor | None = None |
| 628 | + log_probs: torch.Tensor | None = None |
| 629 | + text: str | list[str] | None = None |
0 commit comments