Skip to content

Commit 5add4cb

Browse files
committed
Remove OfflineEagleRegistry; Benjamin: Fix conversation ID indexing
Signed-off-by: h-guo18 <[email protected]>
1 parent 84ab2f8 commit 5add4cb

File tree

4 files changed

+13
-20
lines changed

4 files changed

+13
-20
lines changed

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -84,10 +84,6 @@ async def main(args: argparse.Namespace) -> None:
8484
with args.input_file.open("r", encoding="utf-8") as f:
8585
all_conversations.extend([json.loads(line) for line in f if line.strip()])
8686

87-
if any(not entry.get("conversation_id") for entry in all_conversations):
88-
msg = "All conversations must have a 'conversation_id' field."
89-
raise ValueError(msg)
90-
9187
print("Loaded", len(all_conversations), "conversations from", args.input_file)
9288

9389
model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto")

examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -95,10 +95,6 @@ async def main(args: argparse.Namespace) -> None:
9595
with args.input_file.open("r", encoding="utf-8") as f:
9696
all_conversations.extend([json.loads(line) for line in f if line.strip()])
9797

98-
if any(not entry.get("conversation_id") for entry in all_conversations):
99-
msg = "All conversations must have a 'conversation_id' field."
100-
raise ValueError(msg)
101-
10298
print("Loaded", len(all_conversations), "conversations from", args.input_file)
10399

104100
client: AsyncOpenAI = AsyncOpenAI(
@@ -127,12 +123,14 @@ async def main(args: argparse.Namespace) -> None:
127123
num_total_conversations = min(
128124
len(all_conversations), args.debug_max_num_conversations or len(all_conversations)
129125
)
130-
for entry in tqdm(
131-
all_conversations[: args.debug_max_num_conversations],
132-
desc="Processing conversations",
133-
total=num_total_conversations,
126+
for idx, entry in enumerate(
127+
tqdm(
128+
all_conversations[: args.debug_max_num_conversations],
129+
desc="Processing conversations",
130+
total=num_total_conversations,
131+
)
134132
):
135-
conversation_id = entry["conversation_id"]
133+
conversation_id = entry.get("conversation_id", "{:08d}".format(idx))
136134
conversations = entry["conversations"]
137135
if not conversations or not isinstance(conversations, list):
138136
num_invalid += 1

examples/speculative_decoding/eagle_utils.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -269,12 +269,12 @@ def make_eagle_supervised_data_module(
269269

270270
# Filter to conversations that exist in the offline data and in the provided json
271271
valid_entries = []
272-
for entry in data_json:
273-
conv_id = entry.get("conversation_id") or entry.get("id")
274-
if not conv_id:
275-
raise ValueError(
276-
"Each entry in the data json must have a 'conversation_id' or 'id' field."
277-
)
272+
for idx, entry in enumerate(data_json):
273+
conv_id = entry.get("conversation_id")
274+
if conv_id is None:
275+
conv_id = entry.get("id")
276+
if conv_id is None:
277+
conv_id = "{:08d}".format(idx)
278278
file_path = str(offline_data_path / f"{conv_id}.pt")
279279
if file_path in all_files:
280280
valid_entries.append((entry, file_path))

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -330,7 +330,6 @@ def forward(
330330

331331

332332
@EagleDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"})
333-
@OfflineEagleDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"})
334333
class HFEagleModel(EagleModel):
335334
"""Eagle Model Class for huggingface models."""
336335

0 commit comments

Comments
 (0)