Skip to content

Commit 32c9814

Browse files
committed
Add in validation for compile
1 parent b7652a9 commit 32c9814

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

apps/sft/main.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,13 @@ def record_batch_metrics(self, data_metrics: list):
9595

9696
@endpoint
9797
async def setup(self):
98+
# Validate that compile is only used with flex attention
99+
if self.job_config.training.compile:
100+
raise ValueError(
101+
"training.compile=True is not currently supported. "
102+
"Compile is only supported with flex attention enabled, which requires PyTorch nightly. "
103+
"Please set training.compile=false in your config."
104+
)
98105

99106
# all ranks should record loss, except when PP=True. Then, only the last stage should record loss.
100107
self.rank_should_record_loss = True
@@ -150,6 +157,7 @@ def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader:
150157
Raises:
151158
ValueError: If multiple datasets provided (not yet supported)
152159
"""
160+
153161
# TODO felipemello: Currently only support single dataset
154162
if len(dataset_configs) > 1:
155163
raise ValueError(

src/forge/data/collate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def collate_padded(batch: list[dict[str, Any]]) -> dict[str, Any]:
4646
tokens_list.append(padded_tokens)
4747

4848
# Pad labels with CROSS_ENTROPY_IGNORE_IDX (-100)
49-
padded_labels = F.pad(sample["labels"], (0, pad_len), value=CROSS_ENTROPY_IGNORE_IDX)
49+
padded_labels = F.pad(
50+
sample["labels"], (0, pad_len), value=CROSS_ENTROPY_IGNORE_IDX
51+
)
5052
labels_list.append(padded_labels)
5153

5254
# Stack into batch

tests/unit_tests/datasets/test_packed.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1125,4 +1125,3 @@ def test_metrics_flattened(self):
11251125
assert "metrics" in result
11261126
# Should be flattened from [[metric1, metric2], [metric3]] to [metric1, metric2, metric3]
11271127
assert len(result["metrics"]) == 3
1128-

0 commit comments

Comments
 (0)