Commit e28f84b

Undo other random changes
1 parent 09d81b7 commit e28f84b

6 files changed (+18, -205 lines)

apps/on_policy_distillation/README.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -68,7 +68,7 @@ python -m apps.on-policy-distillation.main --config apps/on-policy-distillation/
 1. **Ensure proper initialization**: Load the SFT checkpoint before starting OPD
 2. **Use prompts only**: During OPD, sample completions from student, don't use dataset solutions
 3. **Teacher quality matters**: Better teachers provide better supervision
-4. **Monitor reverse KL**: Should decrease to near-zero as training progresses
+4. **Monitor reverse KL**: Should go to near-zero as training progresses
 
 ## References
```
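
The reverse-KL bullet edited above is the core training signal in on-policy distillation: completions are sampled from the student, and the objective drives KL(student || teacher) toward zero. A minimal monitoring sketch, assuming per-token log-probabilities of the sampled tokens are available from both models (the names below are illustrative, not from this repository):

```python
import torch

def reverse_kl_estimate(
    student_logprobs: torch.Tensor,  # [batch, seq] log p_student(sampled token)
    teacher_logprobs: torch.Tensor,  # [batch, seq] log p_teacher(sampled token)
) -> torch.Tensor:
    # Because the tokens were sampled from the student, the per-token difference
    # log p_student - log p_teacher is an unbiased estimate of KL(student || teacher).
    return (student_logprobs - teacher_logprobs).mean()

# During OPD this estimate should trend toward zero as the student matches the teacher.
```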

apps/sft/main.py

Lines changed: 6 additions & 54 deletions

```diff
@@ -21,14 +21,11 @@
 
 import torch
 
+import torchtitan.experiments.forge.train_spec as forge_train_spec
 from forge.controller import ForgeActor
 from forge.data.collate import collate_packed
 from forge.data.datasets.packed import PackedDataset, TextPacker
-from forge.data.datasets.sft_dataset import (
-    AlpacaToMessages,
-    OpenThoughtsToMessages,
-    sft_iterable_dataset,
-)
+from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
 from forge.data.tokenizer import HuggingFaceModelTokenizer
 from forge.observability import get_or_create_metric_logger, record_metric, Reduce
 from forge.util.config import parse
```

```diff
@@ -84,34 +81,8 @@ def __init__(self, config: DictConfig):
         self.gradient_accumulation_steps = 1  # Example value, adjust as needed
         self._rank = current_rank().rank
         self._size = math.prod(current_size().values())
-        self._init_dist()
         super().__init__(job_config)
 
-    def _init_dist(self):
-        """Initializes torch distributed.
-
-        torchrun normally hands this, but we need to do it ourselves
-        in monarch for now.
-
-        We should consider putting this into ForgeActor, but having this
-        be explicit for now.
-
-        """
-        env = {
-            "RANK": str(self._rank),
-            "LOCAL_RANK": str(self._rank),
-            "LOCAL_WORLD_SIZE": str(self._size),
-            "GROUP_RANK": str(self._size),
-            "GROUP_WORLD_SIZE": str(self._size),
-            "ROLE_RANK": str(self._rank),
-            "ROLE_WORLD_SIZE": str(self._size),
-            "ROLE_NAME": "rank",
-            "WORLD_SIZE": str(self._size),
-            "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
-        }
-        os.environ.update(env)
-        logger.info("env: {}".format(env))
-
     async def setup_metric_logger(self):
         """Initialization happens in the main process. Here we just retrieve it"""
         mlogger = await get_or_create_metric_logger()
```

```diff
@@ -168,32 +139,13 @@ def setup_data(self):
             ),
         )
 
-        # Get dataset configuration from job_config
-        dataset_config = self.job_config["dataset"]
-        dataset_path = dataset_config["path"]
-        dataset_split = dataset_config["split"]
-        message_transform_type = dataset_config.get("message_transform", "alpaca")
-        masking_strategy = dataset_config.get("masking_strategy", "train_on_assistant")
-
-        # Select the appropriate message transform
-        if message_transform_type == "openthoughts":
-            message_transform = OpenThoughtsToMessages(
-                masking_strategy=masking_strategy
-            )
-        elif message_transform_type == "alpaca":
-            message_transform = AlpacaToMessages(masking_strategy=masking_strategy)
-        else:
-            raise ValueError(
-                f"Unknown message_transform type: {message_transform_type}"
-            )
-
         dataset = sft_iterable_dataset(
             model_transform=tokenizer,
-            message_transform=message_transform,
-            path=dataset_path,
-            split=dataset_split,
+            message_transform=AlpacaToMessages(),
+            path="yahma/alpaca-cleaned",
+            split="train",
         )
-        packer = TextPacker(padding_idx=151643)
+        packer = TextPacker(padding_idx=0)
         dataset = PackedDataset(
             dataset=dataset,
             packer=packer,
```
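
With the config-driven transform selection reverted in setup_data above, the data path is hard-wired to Alpaca again. A rough sketch of how the restored pieces compose, using only the arguments visible in this diff (the tokenizer is whatever setup_data constructed earlier; the trailing PackedDataset arguments are elided, and the helper name is illustrative):

```python
from forge.data.datasets.packed import PackedDataset, TextPacker
from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset

def build_alpaca_dataset(tokenizer):
    # Raw Alpaca samples -> chat messages -> tokenized iterable dataset.
    dataset = sft_iterable_dataset(
        model_transform=tokenizer,
        message_transform=AlpacaToMessages(),
        path="yahma/alpaca-cleaned",
        split="train",
    )
    # Pack variable-length samples into fixed-length sequences; padding id 0
    # replaces the previously hard-coded 151643 (a tokenizer-specific pad/eos id).
    packer = TextPacker(padding_idx=0)
    return PackedDataset(
        dataset=dataset,
        packer=packer,
        # ... remaining PackedDataset arguments as in setup_data
    )
```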

src/forge/actors/generator.py

Lines changed: 6 additions & 5 deletions

```diff
@@ -239,11 +239,12 @@ def _spawn_fetchers(self):
         # TODO: this assumes the generator is on the same host as the worker
         # and only works for single host generators. Figure out how to support
         # generators with workers spanned across multiple hosts.
-        fetcher_procs = this_host().spawn_procs(
-            per_host={"procs": self.n_fetcher_procs}
-        )
-        self._fetcher_procs = fetcher_procs
-        self.weight_fetchers = fetcher_procs.spawn("weight_fetcher", _WeightFetcher)
+        pass
+        # fetcher_procs = this_host().spawn_procs(
+        #     per_host={"procs": self.n_fetcher_procs}
+        # )
+        # self._fetcher_procs = fetcher_procs
+        # self.weight_fetchers = fetcher_procs.spawn("weight_fetcher", _WeightFetcher)
 
     def _start_processing(self):
         if self._run_task is None or self._run_task.done():
```

src/forge/actors/reference_model.py

Lines changed: 4 additions & 68 deletions

```diff
@@ -13,9 +13,6 @@
 from dataclasses import dataclass, field, fields
 
 import torch
-import torch.nn.functional as F
-
-# from forge.util.ops import compute_logprobs
 from monarch.actor import current_rank, current_size, endpoint
 from torch.distributed.tensor import DTensor
 
```

```diff
@@ -33,6 +30,7 @@
 from forge.controller import ForgeActor
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
+from forge.util.ops import compute_logprobs
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
```

```diff
@@ -182,77 +180,15 @@ async def forward(
         with torch.inference_mode():
             logits = self.model(input_ids)
         self.step += 1
-        # if isinstance(logits, DTensor):
-        #     logits = logits.full_tensor()
+        if isinstance(logits, DTensor):
+            logits = logits.full_tensor()
         t.step("forward")
 
         if not return_logprobs:
             t.stop()
-            if isinstance(logits, DTensor):
-                return logits.full_tensor()
             return logits
         else:
-            logprobs = compute_logprobs_chunked(logits, input_ids[:, max_req_tokens:])
+            logprobs = compute_logprobs(logits, input_ids[:, max_req_tokens:])
             t.step("compute_logprobs")
             t.stop()
             return logprobs
-
-
-def compute_logprobs_chunked(
-    logits: torch.Tensor | DTensor,
-    input_ids: torch.Tensor,
-    temperature: float = 1.0,
-    align: bool = True,
-    chunk_size: int = 512,
-) -> torch.Tensor:
-    """
-    Memory-efficient version that processes logits in chunks along the sequence dimension.
-    Useful for very long sequences where even the DTensor operations might cause memory issues.
-
-    Args:
-        chunk_size: Number of tokens to process at once. Lower values use less memory.
-    """
-    is_dtensor = isinstance(logits, DTensor)
-
-    # Align logits with input_ids if requested
-    if align:
-        target_len = input_ids.size(1)
-        logits = logits[:, -target_len - 1 : -1, :]
-        if not is_dtensor:
-            logits = logits.to(input_ids.device)
-
-    batch_size, seq_len, vocab_size = logits.shape
-
-    # Initialize output tensor
-    logprobs = torch.zeros(
-        batch_size, seq_len, dtype=torch.float32, device=logits.device
-    )
-
-    # Process in chunks
-    for start_idx in range(0, seq_len, chunk_size):
-        end_idx = min(start_idx + chunk_size, seq_len)
-
-        # Get chunk of logits and input_ids
-        logits_chunk = logits[:, start_idx:end_idx, :]
-        input_chunk = input_ids[:, start_idx:end_idx]
-
-        # Scale and convert to fp32
-        scaled_chunk = (logits_chunk / temperature).float()
-
-        # Compute log probabilities for this chunk
-        chunk_size_actual = end_idx - start_idx
-        flat_logits = scaled_chunk.reshape(-1, vocab_size)
-        flat_targets = input_chunk.reshape(-1).long()
-
-        chunk_logprobs = -F.cross_entropy(
-            flat_logits,
-            flat_targets,
-            reduction="none",
-        )
-
-        # Store in output tensor
-        logprobs[:, start_idx:end_idx] = chunk_logprobs.reshape(
-            batch_size, chunk_size_actual
-        )
-
-    return logprobs
```
src/forge/data/datasets/__init__.py

Lines changed: 1 addition & 8 deletions

```diff
@@ -7,20 +7,13 @@
 from .dataset import DatasetInfo, InfiniteTuneIterableDataset, InterleavedDataset
 from .hf_dataset import HfIterableDataset
 from .packed import PackedDataset
-from .sft_dataset import (
-    AlpacaToMessages,
-    OpenThoughtsToMessages,
-    sft_iterable_dataset,
-    SFTOutputTransform,
-)
+from .sft_dataset import sft_iterable_dataset, SFTOutputTransform
 
 __all__ = [
-    "AlpacaToMessages",
     "DatasetInfo",
     "HfIterableDataset",
    "InterleavedDataset",
     "InfiniteTuneIterableDataset",
-    "OpenThoughtsToMessages",
     "PackedDataset",
     "SFTOutputTransform",
     "sft_iterable_dataset",
```

src/forge/data/datasets/sft_dataset.py

Lines changed: 0 additions & 69 deletions

```diff
@@ -105,75 +105,6 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
         return {"messages": messages}
 
 
-class OpenThoughtsToMessages:
-    """
-    Message transform class for OpenThoughts-style datasets with a "conversations" column
-    containing a list of dictionaries with "from" and "value" fields.
-
-    Args:
-        column_map (dict[str, str] | None): a mapping to change the expected "conversations"
-            column name to the actual column name in the dataset. Default is None,
-            keeping the default column name.
-        masking_strategy (str): masking strategy to use for model training.
-            Must be one of: `train_on_all`, `train_on_assistant`, `train_on_last`.
-            Default is "train_on_assistant".
-
-            - ``train_on_all``: both user and assistant messages are unmasked
-            - ``train_on_assistant``: user messages are masked, only assistant messages are unmasked
-            - ``train_on_last``: only the last assistant message is unmasked
-    """
-
-    def __init__(
-        self,
-        column_map: dict[str, str] | None = None,
-        masking_strategy: str = "train_on_assistant",
-    ):
-        self.masking_strategy = masking_strategy
-        if column_map:
-            if "conversations" not in column_map:
-                raise ValueError(
-                    f"Expected a key of 'conversations' in column_map but found {column_map.keys()}."
-                )
-            self._column_map = column_map
-        else:
-            self._column_map = {
-                "conversations": "conversations",
-            }
-
-    def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
-        conversations = sample[self._column_map["conversations"]]
-
-        if not isinstance(conversations, list):
-            raise ValueError(
-                f"Expected 'conversations' to be a list, got {type(conversations)}"
-            )
-
-        messages = []
-        for message_dict in conversations:
-            role = message_dict.get("from", "")
-            content = message_dict.get("value", "")
-
-            # Map OpenThoughts roles to standard roles
-            if role in ["human", "user"]:
-                role = "user"
-            elif role in ["gpt", "assistant", "model"]:
-                role = "assistant"
-            else:
-                # Skip unknown roles
-                continue
-
-            messages.append(
-                TuneMessage(
-                    role=role,
-                    content=content,
-                    eot=True,
-                )
-            )
-
-        mask_messages(messages, self.masking_strategy)
-        return {"messages": messages}
-
-
 class SFTOutputTransform:
     """Applied to each dataset sample to build the `"labels"` tensor for causal-LM SFT training.
```
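
For reference, the removed OpenThoughtsToMessages transform consumed samples with a "conversations" list of {"from", "value"} dicts. A hypothetical input and the mapping the removed code performed:

```python
# Hypothetical sample in the format the removed transform expected.
sample = {
    "conversations": [
        {"from": "human", "value": "What is 2 + 2?"},
        {"from": "gpt", "value": "2 + 2 = 4."},
    ]
}
# __call__ mapped "human"/"user" -> "user" and "gpt"/"assistant"/"model" -> "assistant",
# skipped unknown roles, wrapped each turn in TuneMessage(..., eot=True), then applied
# mask_messages(messages, masking_strategy) before returning {"messages": messages}.
```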
