Adds TitanRefModel in place of HF-based Reference Model #94
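In short: this PR exposes a TorchTitan-backed reference model (TitanRefModel) via forge.actors, adds a sandbox script that exercises it alongside the existing HF-based reference actors, and has the Policy return the full vLLM RequestOutput instead of only its .outputs. A minimal sketch of the intended call pattern, assuming the spawn_service / ServiceConfig API and the forward endpoint shown in the sandbox script below; the helper name compute_ref_logprobs is illustrative only:

import asyncio

from forge.actors.reference_actor import TitanRefModel
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service


async def compute_ref_logprobs(request_tokens, response_tokens):
    # Spawn a single-replica TitanRefModel service on GPU (mirrors the sandbox script).
    ref_model = await spawn_service(
        ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
        TitanRefModel,
    )
    # Score the policy's response tokens under the reference model.
    logprobs = await ref_model.forward.choose(
        request=request_tokens, response=response_tokens
    )
    await shutdown_service(ref_model)
    return logprobs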
Changes from 5 commits
@@ -0,0 +1,126 @@
import asyncio

from datasets import load_dataset

from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.actors.reference_actor import HuggingFaceRefModel, RefModel, TitanRefModel

from forge.controller.actor import ForgeActor
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from monarch.actor import endpoint
from torchtitan.config.job_config import Model


class DatasetActor(ForgeActor):
    """Actor wrapper for HuggingFace dataset to provide async interface."""

    def __init__(
        self, path: str, config_name: str, split: str, streaming: bool, **kwargs
    ):
        super().__init__()

        def gsm8k_to_messages(sample):
            question = sample["question"]
            full_answer: str = sample["answer"]
            answer = full_answer.split("#### ")[1]
            return {"question": question, "answer": answer}

        ds = load_dataset(path, config_name, split=split, streaming=streaming)
        ds = ds.map(gsm8k_to_messages)
        ds = ds.shuffle()
        self._iterator = iter(ds)

    @endpoint
    async def __next__(self) -> dict[str, str] | None:
        return next(self._iterator)


# Sandbox; will be removed
async def main():
    group_size = 1

    # For Torchtitan
    model = "Qwen/Qwen3-1.7B"
    # model = "meta-llama/Meta-Llama-3.1-8B"

    # Spawn Reference "Agents"

    # # Joe
    # hf_model = await spawn_service(
    #     ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
    #     HuggingFaceRefModel,
    #     model_name=model,
    # )

    # # Philip
    # hf_model = await spawn_service(
    #     ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
    #     RefModel,
    #     model_name=model,
    # )

    titan_model = await spawn_service(
        ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
        TitanRefModel,
    )

    # Spawn Policy for getting responses
    policy = await spawn_service(
        ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
        Policy,
        config=PolicyConfig(
            worker_params=WorkerConfig(model=model),
            sampling_params=SamplingOverrides(num_samples=group_size, max_tokens=16),
        ),
    )

    # Load Dataset
    dataloader = await spawn_service(
        ServiceConfig(procs_per_replica=1, num_replicas=1),
        DatasetActor,
        path="openai/gsm8k",
        config_name="main",
        split="train",
        streaming=True,
    )
    sample = await dataloader.__next__.choose()
    prompt, target = sample["question"], sample["answer"]
    print("Sample: ", sample)

    # Generate output from policy, then pass to reference agents
    responses = await policy.generate.choose(prompt)
    actions = responses.outputs
    for action in actions:
        request_tokens = responses.prompt_token_ids
        response_tokens = action.token_ids

        print("request_tokens: ", request_tokens)
        print("response_tokens: ", response_tokens)

        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # print("HuggingFace Results")
        # hf_logprobs = await hf_model.forward.choose(
        #     request=request_tokens, response=response_tokens
        # )
        # print("HF logprob: ", hf_logprobs)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    await asyncio.gather(
        shutdown_service(policy),
        shutdown_service(dataloader),
        # shutdown_service(hf_model),
    )

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    print("Titan Results")
    titan_logprobs: float = await titan_model.forward.choose(
        request=request_tokens, response=response_tokens
    )
    print("Titan logprob: ", titan_logprobs)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # await shutdown_service(titan_model)


if __name__ == "__main__":
    asyncio.run(main())
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

__all__ = ["Policy", "PolicyRouter", "RLTrainer", "ReplayBuffer"]
__all__ = ["Policy", "PolicyRouter", "RLTrainer", "ReplayBuffer", "TitanRefModel"]


def __getattr__(name):

@@ -24,5 +24,9 @@ def __getattr__(name):
        from .replay_buffer import ReplayBuffer

        return ReplayBuffer
    elif name == "TitanRefModel":
        from .reference_actor import TitanRefModel

        return TitanRefModel
    else:
        raise AttributeError(f"module {__name__} has no attribute {name}")

Review comment on the new elif branch: 🙃
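For context, the __getattr__ hook above makes the export lazy: forge.actors.reference_actor is only imported the first time TitanRefModel is actually requested, so importing the package root stays cheap. A minimal sketch of the caller side, assuming the package layout in this diff:

# Resolved via forge.actors.__getattr__("TitanRefModel") on first access;
# the reference_actor module is not imported until this line runs.
from forge.actors import TitanRefModel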
@@ -13,6 +13,12 @@
from typing import Dict, List

import torch

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh

from forge.data.sharding import VLLMSharding
from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig
from monarch.actor import current_rank, endpoint, ProcMesh
from torchstore import MultiProcessStore
from torchstore._state_dict_utils import DELIM

@@ -37,12 +43,6 @@
from vllm.v1.structured_output import StructuredOutputManager
from vllm.worker.worker_base import WorkerWrapperBase

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh

from forge.data.sharding import VLLMSharding
from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig


logger = logging.getLogger(__name__)


@@ -310,7 +310,8 @@ async def run(self):
        for request_output in processed_outputs.request_outputs:
            if request_output.finished:
                _, fut = self.requests.pop(request_output.request_id)
                # fut.set_result(request_output.outputs)
                fut.set_result(request_output)

Review thread on this change:
- Adopted from #97
- Why this instead of raw outputs?
- Pragmatically: less merge conflict with Philip's PR. I don't have a strong preference, but it does make the output self-contained, which is nice when we need to pass the results around.
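To make the trade-off concrete: returning the whole RequestOutput keeps the prompt token ids and the sampled completions together, so callers do not have to carry the prompt around separately. A rough sketch of the caller-side difference, assuming the RequestOutput fields (prompt_token_ids, outputs[i].token_ids) used by the sandbox script above:

# With fut.set_result(request_output.outputs), the caller would only receive the
# completions and would need to pair them with the prompt tokens some other way.
# With fut.set_result(request_output), the result is self-contained:
responses = await policy.generate.choose(prompt)
request_tokens = responses.prompt_token_ids       # prompt travels with the result
response_tokens = responses.outputs[0].token_ids  # tokens of the first sampled completion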
    @endpoint
    async def update_weights(self):