Merged
Changes from 3 commits
10 changes: 5 additions & 5 deletions src/forge/cli/config.py
@@ -56,22 +56,22 @@ def _merge_yaml_and_cli_args(yaml_args: Namespace, cli_args: list[str]) -> DictC
cli args, respectively) and merges them into a single OmegaConf DictConfig.

If a cli arg overrides a yaml arg with a _component_ field, the cli arg can
be specified with the parent field directly, e.g., model=torchtune.models.lora_llama2_7b
instead of model._component_=torchtune.models.lora_llama2_7b. Nested fields within the
be specified with the parent field directly, e.g., model=my_module.models.my_model
instead of model._component_=my_module.models.my_model. Nested fields within the
component should be specified with dot notation, e.g., model.lora_rank=16.

Example:
>>> config.yaml:
>>> a: 1
>>> b:
>>> _component_: torchtune.models.my_model
>>> _component_: my_module.models.my_model
>>> c: 3

>>> tune full_finetune --config config.yaml b=torchtune.models.other_model b.c=4
>>> python main.py --config config.yaml b=my_module.models.other_model b.c=4
Contributor:

Does this actually work or is it just an example?

Member Author:

It is really working 😅

>>> yaml_args, cli_args = parser.parse_known_args()
>>> conf = _merge_yaml_and_cli_args(yaml_args, cli_args)
>>> print(conf)
>>> {"a": 1, "b": {"_component_": "torchtune.models.other_model", "c": 4}}
>>> {"a": 1, "b": {"_component_": "my_module.models.other_model", "c": 4}}

Args:
yaml_args (Namespace): Namespace containing args from yaml file, components
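For reference, the override behavior documented above can be sketched directly with OmegaConf. This is a minimal illustration, not the actual implementation; the helper name `_expand_component_shorthand` is hypothetical.

```python
from omegaconf import OmegaConf

def _expand_component_shorthand(dotlist: list[str], base) -> list[str]:
    # Hypothetical helper: if a CLI override targets a key whose yaml value
    # carries a _component_ field, rewrite "key=value" as
    # "key._component_=value" so the component's other fields are preserved.
    expanded = []
    for item in dotlist:
        key, _, value = item.partition("=")
        node = OmegaConf.select(base, key)
        if node is not None and OmegaConf.is_dict(node) and "_component_" in node:
            expanded.append(f"{key}._component_={value}")
        else:
            expanded.append(item)
    return expanded

yaml_cfg = OmegaConf.create(
    {"a": 1, "b": {"_component_": "my_module.models.my_model", "c": 3}}
)
cli_args = ["b=my_module.models.other_model", "b.c=4"]
override = OmegaConf.from_dotlist(_expand_component_shorthand(cli_args, yaml_cfg))
merged = OmegaConf.merge(yaml_cfg, override)
print(OmegaConf.to_container(merged))
# {'a': 1, 'b': {'_component_': 'my_module.models.other_model', 'c': 4}}
```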
2 changes: 1 addition & 1 deletion src/forge/data/datasets/dataset.py
@@ -61,7 +61,7 @@ class DatasetInfo:


class TuneIterableDataset(IterableDataset, ABC):
"""Base class for all torchtune iterable datasets.
"""Base class for all forge iterable datasets.

Datasets are composable, enabling complex structures such as:
``PackedDataset(InterleavedDataset([InterleavedDataset([ds1, ds2]), ds3]))``
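The composability mentioned in the docstring can be illustrated with a small sketch. `SimpleInterleavedDataset` here is a toy stand-in with an assumed constructor, not the real forge class, which also tracks checkpointable state.

```python
from torch.utils.data import IterableDataset

class SimpleInterleavedDataset(IterableDataset):
    # Toy stand-in: round-robins over child iterables until all are exhausted.
    def __init__(self, datasets):
        self.datasets = datasets

    def __iter__(self):
        iterators = [iter(d) for d in self.datasets]
        while iterators:
            for it in list(iterators):
                try:
                    yield next(it)
                except StopIteration:
                    iterators.remove(it)

# Nesting works because each dataset is itself just an iterable:
nested = SimpleInterleavedDataset(
    [SimpleInterleavedDataset([range(3), range(10, 13)]), range(100, 103)]
)
print(list(nested))
```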
3 changes: 1 addition & 2 deletions src/forge/data/datasets/sft_dataset.py
@@ -22,8 +22,7 @@ class AlpacaToMessages(Transform):
(or equivalent fields specified in column_map) columns. User messages are formed from the
instruction + input columns and assistant messages are formed from the output column. Prompt
templating is conditional on the presence of the "input" column, and thus is handled directly
in this transform class instead of a dedicated :class:`~torchtune.data.PromptTemplate` class
due to this custom logic.
in this transform class.

Args:
column_map (dict[str, str] | None): a mapping to change the expected "instruction", "input",
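As a rough sketch of the mapping the docstring describes, with the prompt strings and message shape as assumptions rather than the exact forge implementation:

```python
# Prompt templating is conditional on the presence of a non-empty "input" column.
ALPACA_WITH_INPUT = (
    "Below is an instruction that describes a task, paired with an input that "
    "provides further context.\n\n### Instruction:\n{instruction}\n\n"
    "### Input:\n{input}\n\n### Response:\n"
)
ALPACA_NO_INPUT = (
    "Below is an instruction that describes a task.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:\n"
)

def alpaca_to_messages(sample: dict) -> list[dict]:
    # User message from instruction (+ optional input), assistant from output.
    if sample.get("input"):
        prompt = ALPACA_WITH_INPUT.format(**sample)
    else:
        prompt = ALPACA_NO_INPUT.format(instruction=sample["instruction"])
    return [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": sample["output"]},
    ]
```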
4 changes: 2 additions & 2 deletions src/forge/data/tokenizer.py
@@ -20,7 +20,7 @@
class HuggingFaceBaseTokenizer(BaseTokenizer):
"""
A wrapper around Hugging Face tokenizers. See https://github.com/huggingface/tokenizers
This can be used to load from a Hugging Face tokenizer.json file into a torchtune BaseTokenizer.
This can be used to load from a Hugging Face tokenizer.json file into a forge BaseTokenizer.

This class will load the tokenizer.json file from tokenizer_json_path. It will
attempt to infer BOS and EOS token IDs from config.json if possible, and if not
@@ -210,7 +210,7 @@ class HuggingFaceModelTokenizer(ModelTokenizer):
Then, it will load all special tokens and chat template from tokenizer config file.

It can be used to tokenize messages with correct chat template, and it eliminates the requirement of
the specific ModelTokenizer and custom PromptTemplate.
the specific ModelTokenizer.

Args:
tokenizer_json_path (str): Path to tokenizer.json file
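For illustration, the loading and BOS/EOS inference described above looks roughly like this with the `tokenizers` library; the file paths and the config.json fallback shown here are assumptions, not the wrapper's exact logic.

```python
import json
from tokenizers import Tokenizer

tok = Tokenizer.from_file("tokenizer.json")

# Infer BOS/EOS ids from config.json when present (illustrative fallback only).
with open("config.json") as f:
    cfg = json.load(f)
bos_id = cfg.get("bos_token_id")
eos_id = cfg.get("eos_token_id")

ids = tok.encode("Hello world!").ids
if bos_id is not None:
    ids = [bos_id] + ids
if eos_id is not None:
    ids = ids + [eos_id]
print(tok.decode(ids, skip_special_tokens=True))
```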
5 changes: 2 additions & 3 deletions src/forge/data/utils.py
@@ -32,7 +32,7 @@ class TuneMessage:
"""
This class represents individual messages in a fine-tuning dataset. It supports
text-only content, text with interleaved images, and tool calls. The
:class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` will tokenize
:class:`~forge.interfaces.ModelTokenizer` will tokenize
the content of the message using ``tokenize_messages`` and attach the appropriate
special tokens based on the flags set in this class.

@@ -61,8 +61,7 @@ class TuneMessage:
- All ipython messages (tool call returns) should set ``eot=False``.

Note:
TuneMessage class expects any image content to be a ``torch.Tensor``, as output
by e.g. :func:`~torchtune.data.load_image`
TuneMessage class expects any image content to be a ``torch.Tensor``.
"""

def __init__(
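A hypothetical construction consistent with the docstring above; the keyword arguments shown are assumptions, so check the actual `TuneMessage.__init__` signature.

```python
import torch

# Hypothetical usage; argument names are assumptions based on the docstring.
image_message = TuneMessage(
    role="user",
    content=[
        {"type": "image", "content": torch.zeros(3, 224, 224)},  # images must be tensors
        {"type": "text", "content": "Describe this image."},
    ],
)

tool_return = TuneMessage(
    role="ipython",
    content='{"result": 42}',
    eot=False,  # tool call returns set eot=False per the note above
)
```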
5 changes: 2 additions & 3 deletions src/forge/interfaces.py
@@ -97,8 +97,7 @@ async def update_weights(self, policy_version: int):
class BaseTokenizer(ABC):
"""
Abstract token encoding model that implements ``encode`` and ``decode`` methods.
See :class:`~torchtune.modules.transforms.tokenizers.SentencePieceBaseTokenizer` and
:class:`~torchtune.modules.transforms.tokenizers.TikTokenBaseTokenizer` for example implementations of this protocol.
See :class:`forge.data.HuggingFaceModelTokenizer` for an example implementation of this protocol.
"""

@abstractmethod
@@ -133,7 +132,7 @@ def decode(self, token_ids: list[int], **kwargs: dict[str, Any]) -> str:
class ModelTokenizer(ABC):
"""
Abstract tokenizer that implements model-specific special token logic in
the ``tokenize_messages`` method. See :class:`~torchtune.models.llama3.Llama3Tokenizer`
the ``tokenize_messages`` method. See :class:`forge.data.HuggingFaceModelTokenizer`
for an example implementation of this protocol.
"""

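A minimal, purely illustrative `BaseTokenizer` implementation showing the encode/decode contract; only `decode`'s signature is visible in the hunk above, so the `encode` signature is an assumption.

```python
from typing import Any

from forge.interfaces import BaseTokenizer

class WhitespaceTokenizer(BaseTokenizer):
    # Toy implementation: splits on whitespace against a fixed vocabulary.
    def __init__(self, vocab: list[str]):
        self.token_to_id = {tok: i for i, tok in enumerate(vocab)}
        self.id_to_token = {i: tok for tok, i in self.token_to_id.items()}

    def encode(self, text: str, **kwargs: dict[str, Any]) -> list[int]:
        return [self.token_to_id[tok] for tok in text.split()]

    def decode(self, token_ids: list[int], **kwargs: dict[str, Any]) -> str:
        return " ".join(self.id_to_token[i] for i in token_ids)

tok = WhitespaceTokenizer(["hello", "world"])
assert tok.decode(tok.encode("hello world")) == "hello world"
```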
7 changes: 5 additions & 2 deletions src/forge/util/logging.py
@@ -20,14 +20,17 @@ def get_logger(level: str | None = None) -> logging.Logger:
Example:
>>> logger = get_logger("INFO")
>>> logger.info("Hello world!")
INFO:torchtune.utils._logging:Hello world!
INFO:forge.util.logging: Hello world!

Returns:
logging.Logger: The logger.
"""
logger = logging.getLogger(__name__)
if not logger.hasHandlers():
logger.addHandler(logging.StreamHandler())
handler = logging.StreamHandler()
formatter = logging.Formatter("%(levelname)s:%(name)s: %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
if level is not None:
level = getattr(logging, level.upper())
logger.setLevel(level)
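With the formatter added above, the docstring example lines up as follows; this is just a usage sketch, with the import path mirroring the file location.

```python
from forge.util.logging import get_logger

logger = get_logger("INFO")
logger.info("Hello world!")
# INFO:forge.util.logging: Hello world!
```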
4 changes: 2 additions & 2 deletions src/forge/util/metric_logging.py
@@ -178,7 +178,7 @@ class WandBLogger(MetricLogger):
If int, all metrics will be logged at this frequency.
If Mapping, calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
log_dir (str | None): WandB log directory.
project (str): WandB project name. Default is `torchtune`.
project (str): WandB project name. Default is `TorchForge`.
entity (str | None): WandB entity name. If you don't specify an entity,
the run will be sent to your default entity, which is usually your username.
group (str | None): WandB group name for grouping runs together. If you don't
@@ -205,7 +205,7 @@ class WandBLogger(MetricLogger):
def __init__(
self,
freq: Union[int, Mapping[str, int]],
project: str,
project: str = "TorchForge",
log_dir: str = "metrics_log",
entity: str | None = None,
group: str | None = None,
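An illustrative use of the per-metric frequency and the new default project; the `log(name, value, step)` call shape is an assumption, not confirmed by this diff.

```python
from forge.util.metric_logging import WandBLogger

# freq as a Mapping: each metric is logged only when step % freq[name] == 0;
# project defaults to "TorchForge" after this change.
logger = WandBLogger(freq={"loss": 10, "lr": 100}, group="sft-runs")
# logger.log("loss", 0.42, step=20)   # would log (20 % 10 == 0)
# logger.log("lr", 3e-4, step=20)     # would be skipped (20 % 100 != 0)
```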