Commit 7b9ba6b

Merge branch 'main' into a3c-implementation
Merge from main
2 parents 72eea77 + 95637f3 commit 7b9ba6b

83 files changed: +8802 −1114 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@ htmlcov/
 .coverage
 .coverage.*
 .cache
+.neptune
 nosetests.xml
 coverage.xml
 *.cover

docs/source/reference/llms.rst

Lines changed: 48 additions & 1 deletion
@@ -10,9 +10,29 @@ TorchRL offers a set of tools for LLM post-training, as well as some examples fo
 Collectors
 ----------
 
-TorchRL offers a specialized collector class (:class:`~torchrl.collectors.llm.LLMCollector`) that is tailored for LLM
+TorchRL offers specialized collector classes (:class:`~torchrl.collectors.llm.LLMCollector` and :class:`~torchrl.collectors.llm.RayLLMCollector`) that are tailored for LLM
 use cases. We also provide dedicated updaters for some inference engines.
 
+LLM collectors allow tracking the version of the policy, which is useful for some use cases.
+This is done by adding a :class:`~torchrl.envs.llm.transforms.PolicyVersion` transform to the environment, which is
+then incremented by the collector after each weight update. To do this, one either provides the stateful version of the
+transform, or a boolean to the collector constructor.
+
+>>> from torchrl.envs.llm.transforms import PolicyVersion
+>>> from torchrl.collectors.llm import LLMCollector
+>>> from torchrl.collectors.llm.weight_update import vLLMUpdater
+>>> env = make_env()  # place your code here
+>>> policy = make_policy()  # place your code here
+>>> collector = LLMCollector(env, policy=policy, weight_updater=vLLMUpdater(), track_policy_version=True)
+>>> # init the updater
+>>> collector.weight_updater.init(...)
+>>> # the version is incremented after each weight update
+>>> collector.update_policy_weights_(state_dict=...)
+>>> print(collector.policy_version_tracker.version)
+>>> # the policy version is written in the data
+>>> for data in collector:
+...     print(data["policy_version"])
+
 .. currentmodule:: torchrl.collectors.llm
 
 .. autosummary::
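
The added paragraph above describes two ways to enable version tracking: passing track_policy_version=True (as in the doctest) or providing the stateful PolicyVersion transform yourself. Below is a minimal sketch of the second path; it assumes the transform can be composed with the environment via TorchRL's standard TransformedEnv and that LLMCollector accepts the pre-built transform instance through the same track_policy_version argument (both are assumptions, not confirmed by this diff).

from torchrl.collectors.llm import LLMCollector
from torchrl.envs import TransformedEnv
from torchrl.envs.llm.transforms import PolicyVersion

# Placeholders, as in the doctest above.
env = make_env()
policy = make_policy()

# Assumption: a stateful PolicyVersion transform holds the version counter and
# can be composed with the environment like any other TorchRL transform.
version = PolicyVersion()
env = TransformedEnv(env, version)

# Assumption: passing the transform instance (rather than a boolean) makes the
# collector increment this exact counter after each weight update.
collector = LLMCollector(env, policy=policy, track_policy_version=version)
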
@@ -21,6 +41,7 @@ use cases. We also provide dedicated updaters for some inference engines.
 
     vLLMUpdater
     LLMCollector
+    RayLLMCollector
 
 
 Data structures
@@ -179,9 +200,11 @@ transforms).
 
     DataLoadingPrimer
     KLRewardTransform
+    RetrieveLogProb
     MCPToolTransform
     BrowserTransform
     PythonInterpreter
+    PolicyVersion
     TemplateTransform
     Tokenizer
     as_nested_tensor
@@ -234,6 +257,9 @@ LLM post training require some appropriate versions of the losses implemented in
 GRPO
 ~~~~
 
+The :class:`~torchrl.objectives.llm.GRPOLoss` class is a thin wrapper around the :class:`~torchrl.objectives.PPOLoss` class
+that implements the LLM-specific functionalities.
+
 .. currentmodule:: torchrl.objectives.llm
 
 .. autosummary::
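
The "thin wrapper" phrasing in the added paragraph can be made concrete with a short sketch: a subclass that inherits the PPO objective and only layers LLM-specific handling on top. This is illustrative only, not the actual GRPOLoss implementation; the class name and the token-level comment are hypothetical.

from torchrl.objectives import PPOLoss


class LLMWrappedPPOLoss(PPOLoss):
    """Hypothetical sketch of a thin wrapper over PPOLoss.

    The real GRPOLoss lives in torchrl.objectives.llm; this stand-in only
    illustrates the structure the documentation describes: reuse the PPO
    objective and add LLM-specific handling around it.
    """

    def forward(self, tensordict):
        # LLM-specific preprocessing would go here, e.g. aligning token-level
        # log-probabilities with the advantage estimates (hypothetical step),
        # before delegating to the parent PPO loss computation.
        return super().forward(tensordict)
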
@@ -243,3 +269,24 @@ GRPO
     GRPOLoss
     GRPOLossOutput
     MCAdvantage
+
+
+SFT
+~~~
+
+.. currentmodule:: torchrl.objectives.llm
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    SFTLoss
+    SFTLossOutput
+
+.. currentmodule:: torchrl.data.llm
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    TopKRewardSelector
