
Commit 573d371

clean up torchtune
1 parent fd33e3a commit 573d371

File tree

9 files changed, +26 / -20 lines changed


src/forge/cli/config.py

Lines changed: 5 additions & 5 deletions
@@ -56,22 +56,22 @@ def _merge_yaml_and_cli_args(yaml_args: Namespace, cli_args: list[str]) -> DictC
     cli args, respectively) and merges them into a single OmegaConf DictConfig.
 
     If a cli arg overrides a yaml arg with a _component_ field, the cli arg can
-    be specified with the parent field directly, e.g., model=torchtune.models.lora_llama2_7b
-    instead of model._component_=torchtune.models.lora_llama2_7b. Nested fields within the
+    be specified with the parent field directly, e.g., model=my_module.models.my_model
+    instead of model._component_=my_module.models.my_model. Nested fields within the
     component should be specified with dot notation, e.g., model.lora_rank=16.
 
     Example:
         >>> config.yaml:
         >>>     a: 1
         >>>     b:
-        >>>       _component_: torchtune.models.my_model
+        >>>       _component_: my_module.models.my_model
         >>>       c: 3
 
-        >>> tune full_finetune --config config.yaml b=torchtune.models.other_model b.c=4
+        >>> python main.py --config config.yaml b=my_module.models.other_model b.c=4
         >>> yaml_args, cli_args = parser.parse_known_args()
         >>> conf = _merge_yaml_and_cli_args(yaml_args, cli_args)
         >>> print(conf)
-        >>> {"a": 1, "b": {"_component_": "torchtune.models.other_model", "c": 4}}
+        >>> {"a": 1, "b": {"_component_": "my_module.models.other_model", "c": 4}}
 
     Args:
         yaml_args (Namespace): Namespace containing args from yaml file, components
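
The override behavior described in this docstring can be reproduced with plain OmegaConf. The sketch below is illustrative rather than forge's actual parser code; the `_component_` shorthand expansion is reimplemented by hand here.

```python
from omegaconf import OmegaConf

yaml_cfg = OmegaConf.create(
    {"a": 1, "b": {"_component_": "my_module.models.my_model", "c": 3}}
)

# Overrides as they might arrive on the command line.
cli_overrides = ["b=my_module.models.other_model", "b.c=4"]

# Expand "parent=value" into "parent._component_=value" when the target node has a
# _component_ field, mimicking the docstring's shorthand by hand.
expanded = []
for arg in cli_overrides:
    key, value = arg.split("=", 1)
    node = OmegaConf.select(yaml_cfg, key)
    if OmegaConf.is_dict(node) and "_component_" in node:
        key = f"{key}._component_"
    expanded.append(f"{key}={value}")

merged = OmegaConf.merge(yaml_cfg, OmegaConf.from_dotlist(expanded))
print(OmegaConf.to_container(merged))
# {'a': 1, 'b': {'_component_': 'my_module.models.other_model', 'c': 4}}
```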

src/forge/data/datasets/dataset.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ class DatasetInfo:
 
 
 class TuneIterableDataset(IterableDataset, ABC):
-    """Base class for all torchtune iterable datasets.
+    """Base class for all forge iterable datasets.
 
     Datasets are composable, enabling complex structures such as:
     ``PackedDataset(InterleavedDataset([InterleavedDataset([ds1, ds2]), ds3]))``
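
The composability claim rests on every node itself being an iterable dataset. Below is a toy sketch of that pattern; it is an illustrative stand-in, not forge's InterleavedDataset or PackedDataset.

```python
from torch.utils.data import IterableDataset


class ToyInterleaved(IterableDataset):
    """Round-robins over child datasets, stopping when any child is exhausted (illustration only)."""

    def __init__(self, children: list):
        self.children = children

    def __iter__(self):
        iterators = [iter(child) for child in self.children]
        while True:
            for it in iterators:
                try:
                    yield next(it)
                except StopIteration:
                    return


# Because every node is an IterableDataset, wrappers nest like the docstring example:
# ToyInterleaved([ToyInterleaved([ds1, ds2]), ds3])
```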

src/forge/data/datasets/sft_dataset.py

Lines changed: 1 addition & 2 deletions
@@ -22,8 +22,7 @@ class AlpacaToMessages(Transform):
     (or equivalent fields specified in column_map) columns. User messages are formed from the
     instruction + input columns and assistant messages are formed from the output column. Prompt
     templating is conditional on the presence of the "input" column, and thus is handled directly
-    in this transform class instead of a dedicated :class:`~torchtune.data.PromptTemplate` class
-    due to this custom logic.
+    in this transform class.
 
     Args:
         column_map (dict[str, str] | None): a mapping to change the expected "instruction", "input",
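
As a rough illustration of the mapping this transform performs, here is a hypothetical stand-alone function; it is not forge's AlpacaToMessages and omits column_map handling.

```python
# Rough stand-alone illustration of the instruction/input/output -> messages mapping.
def alpaca_sample_to_messages(sample: dict) -> list[dict]:
    # Prompt templating is conditional on a non-empty "input" column.
    if sample.get("input"):
        user_content = f"{sample['instruction']}\n\n{sample['input']}"
    else:
        user_content = sample["instruction"]
    return [
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": sample["output"]},
    ]


print(alpaca_sample_to_messages(
    {"instruction": "Add the numbers.", "input": "2 and 3", "output": "5"}
))
```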

src/forge/data/tokenizer.py

Lines changed: 2 additions & 2 deletions
@@ -20,7 +20,7 @@
 class HuggingFaceBaseTokenizer(BaseTokenizer):
     """
     A wrapper around Hugging Face tokenizers. See https://github.com/huggingface/tokenizers
-    This can be used to load from a Hugging Face tokenizer.json file into a torchtune BaseTokenizer.
+    This can be used to load from a Hugging Face tokenizer.json file into a forge BaseTokenizer.
 
     This class will load the tokenizer.json file from tokenizer_json_path. It will
     attempt to infer BOS and EOS token IDs from config.json if possible, and if not
@@ -210,7 +210,7 @@ class HuggingFaceModelTokenizer(ModelTokenizer):
     Then, it will load all special tokens and chat template from tokenizer config file.
 
     It can be used to tokenize messages with correct chat template, and it eliminates the requirement of
-    the specific ModelTokenizer and custom PromptTemplate.
+    the specific ModelTokenizer.
 
     Args:
         tokenizer_json_path (str): Path to tokenizer.json file
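
A hedged usage sketch: only `tokenizer_json_path` is confirmed by the docstring and Args above, the import path is assumed to mirror src/forge/data/tokenizer.py, and `encode`/`decode` may accept additional keyword arguments (e.g. for BOS/EOS handling).

```python
# Hedged usage sketch; argument and module names beyond tokenizer_json_path are assumptions.
from forge.data.tokenizer import HuggingFaceBaseTokenizer

tokenizer = HuggingFaceBaseTokenizer(tokenizer_json_path="/path/to/tokenizer.json")
token_ids = tokenizer.encode("Hello world!")
print(tokenizer.decode(token_ids))
```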

src/forge/data/utils.py

Lines changed: 2 additions & 3 deletions
@@ -32,7 +32,7 @@ class TuneMessage:
     """
     This class represents individual messages in a fine-tuning dataset. It supports
     text-only content, text with interleaved images, and tool calls. The
-    :class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` will tokenize
+    :class:`~forge.interfaces.ModelTokenizer` will tokenize
     the content of the message using ``tokenize_messages`` and attach the appropriate
     special tokens based on the flags set in this class.
 
@@ -61,8 +61,7 @@ class TuneMessage:
     - All ipython messages (tool call returns) should set ``eot=False``.
 
     Note:
-        TuneMessage class expects any image content to be a ``torch.Tensor``, as output
-        by e.g. :func:`~torchtune.data.load_image`
+        TuneMessage class expects any image content to be a ``torch.Tensor``.
     """
 
     def __init__(
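
A hypothetical construction example: the keyword names below are inferred from the flags discussed in the docstring and may not match TuneMessage's real `__init__` signature.

```python
# Hypothetical keyword arguments; module path assumed from the file location.
import torch

from forge.data.utils import TuneMessage

tool_return = TuneMessage(role="ipython", content="[1, 2, 3]", eot=False)  # tool-call returns set eot=False
image_turn = TuneMessage(role="user", content=torch.zeros(3, 224, 224), eot=True)  # image content must be a torch.Tensor
```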

src/forge/interfaces.py

Lines changed: 2 additions & 3 deletions
@@ -97,8 +97,7 @@ async def update_weights(self, policy_version: int):
 class BaseTokenizer(ABC):
     """
     Abstract token encoding model that implements ``encode`` and ``decode`` methods.
-    See :class:`~torchtune.modules.transforms.tokenizers.SentencePieceBaseTokenizer` and
-    :class:`~torchtune.modules.transforms.tokenizers.TikTokenBaseTokenizer` for example implementations of this protocol.
+    See :class:`forge.data.HuggingFaceModelTokenizer` for an example implementation of this protocol.
     """
 
     @abstractmethod
@@ -133,7 +132,7 @@ def decode(self, token_ids: list[int], **kwargs: dict[str, Any]) -> str:
 class ModelTokenizer(ABC):
     """
     Abstract tokenizer that implements model-specific special token logic in
-    the ``tokenize_messages`` method. See :class:`~torchtune.models.llama3.Llama3Tokenizer`
+    the ``tokenize_messages`` method. See :class:`forge.data.HuggingFaceModelTokenizer`
     for an example implementation of this protocol.
     """
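
To make the encode/decode protocol concrete, here is a toy tokenizer in the shape BaseTokenizer describes; `decode`'s signature is visible in the hunk header above, while `encode`'s exact signature in forge.interfaces is assumed.

```python
# Toy BaseTokenizer-style implementation for illustration only.
class WhitespaceTokenizer:
    """Maps whitespace-separated words to integer ids."""

    def __init__(self):
        self.vocab: dict[str, int] = {}
        self.reverse: dict[int, str] = {}

    def encode(self, text: str) -> list[int]:
        ids = []
        for word in text.split():
            if word not in self.vocab:
                idx = len(self.vocab)
                self.vocab[word] = idx
                self.reverse[idx] = word
            ids.append(self.vocab[word])
        return ids

    def decode(self, token_ids: list[int]) -> str:
        return " ".join(self.reverse[i] for i in token_ids)


tok = WhitespaceTokenizer()
assert tok.decode(tok.encode("hello world")) == "hello world"
```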

src/forge/util/logging.py

Lines changed: 5 additions & 2 deletions
@@ -20,14 +20,17 @@ def get_logger(level: str | None = None) -> logging.Logger:
     Example:
         >>> logger = get_logger("INFO")
         >>> logger.info("Hello world!")
-        INFO:torchtune.utils._logging:Hello world!
+        INFO:forge.util.logging: Hello world!
 
     Returns:
         logging.Logger: The logger.
     """
     logger = logging.getLogger(__name__)
     if not logger.hasHandlers():
-        logger.addHandler(logging.StreamHandler())
+        handler = logging.StreamHandler()
+        formatter = logging.Formatter("%(levelname)s:%(name)s: %(message)s")
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
     if level is not None:
         level = getattr(logging, level.upper())
         logger.setLevel(level)
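
The practical effect of this change: a bare StreamHandler falls back to the default Formatter, which prints only the message, while the explicit format string prepends level and logger name, matching the updated docstring example. A self-contained sketch of the difference (standard library only; the logger name is chosen for illustration):

```python
import logging
import sys

plain = logging.StreamHandler(sys.stdout)   # no formatter: prints just the message
fancy = logging.StreamHandler(sys.stdout)
fancy.setFormatter(logging.Formatter("%(levelname)s:%(name)s: %(message)s"))

log = logging.getLogger("forge.util.logging")
log.setLevel(logging.INFO)
log.addHandler(plain)
log.addHandler(fancy)
log.info("Hello world!")
# Output:
# Hello world!
# INFO:forge.util.logging: Hello world!
```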

src/forge/util/metric_logging.py

Lines changed: 2 additions & 2 deletions
@@ -178,7 +178,7 @@ class WandBLogger(MetricLogger):
             If int, all metrics will be logged at this frequency.
             If Mapping, calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
         log_dir (str | None): WandB log directory.
-        project (str): WandB project name. Default is `torchtune`.
+        project (str): WandB project name. Default is `torchforge`.
         entity (str | None): WandB entity name. If you don't specify an entity,
             the run will be sent to your default entity, which is usually your username.
         group (str | None): WandB group name for grouping runs together. If you don't
@@ -205,7 +205,7 @@ class WandBLogger(MetricLogger):
     def __init__(
         self,
         freq: Union[int, Mapping[str, int]],
-        project: str,
+        project: str = "torchforge",
         log_dir: str = "metrics_log",
         entity: str | None = None,
         group: str | None = None,
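
With the new default, constructing a WandBLogger no longer requires passing `project` explicitly. A hedged usage sketch based only on the parameters visible in this diff; the `log`/`log_dict` argument names are assumptions, and wandb must be installed.

```python
# Hedged usage sketch; log()/log_dict() argument names are assumptions.
from forge.util.metric_logging import WandBLogger

metric_logger = WandBLogger(freq=10)  # project now defaults to "torchforge"

# Per the docstring, calls may be ignored unless step % freq == 0.
metric_logger.log("loss", 0.42, step=10)
metric_logger.log_dict({"loss": 0.42, "lr": 1e-4}, step=20)
```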

test.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+import logging
+
+from forge.util import get_logger
+
+logger = get_logger("INFO")
+logger.info("Hello world!")
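
Assuming the logging change above is in effect and forge.util re-exports get_logger, running `python test.py` should print something close to `INFO:forge.util.logging: Hello world!` (the exact logger name comes from the module that calls `logging.getLogger(__name__)`).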
