Commit bd27b86

add fallback for collecting hidden states from datasets without conversation_id
Signed-off-by: Benjamin Chislett <[email protected]>
1 parent f74bf59 commit bd27b86

1 file changed: +7 -5 lines changed

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py

Lines changed: 7 additions & 5 deletions
@@ -106,12 +106,14 @@ async def main(args: argparse.Namespace) -> None:
     num_total_conversations = min(
         len(all_conversations), args.debug_max_num_conversations or len(all_conversations)
     )
-    for entry in tqdm(
-        all_conversations[: args.debug_max_num_conversations],
-        desc="Processing conversations",
-        total=num_total_conversations,
+    for idx, entry in enumerate(
+        tqdm(
+            all_conversations[: args.debug_max_num_conversations],
+            desc="Processing conversations",
+            total=num_total_conversations,
+        )
     ):
-        conversation_id = entry["conversation_id"]
+        conversation_id = entry.get("conversation_id", "{:08d}".format(idx))
         conversations = entry["conversations"]
         if not conversations or not isinstance(conversations, list):
             num_invalid += 1
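
To illustrate the fallback this change introduces, here is a minimal, self-contained sketch. The helper name `resolve_conversation_id` and the sample entries are hypothetical and do not appear in the script; only the `entry.get("conversation_id", "{:08d}".format(idx))` expression is taken from the diff.

```python
# Hypothetical helper that isolates the fallback expression from the diff.
def resolve_conversation_id(entry: dict, idx: int) -> str:
    """Return the entry's own ID, or a zero-padded index if the key is absent."""
    return entry.get("conversation_id", "{:08d}".format(idx))


# Made-up sample data: the second entry has no "conversation_id" key.
sample_entries = [
    {"conversation_id": "abc123", "conversations": [{"role": "user", "content": "hi"}]},
    {"conversations": [{"role": "user", "content": "hello"}]},
]

# enumerate() supplies the position that is used as the fallback ID.
resolved_ids = [
    resolve_conversation_id(entry, idx) for idx, entry in enumerate(sample_entries)
]
print(resolved_ids)  # ['abc123', '00000001']
```

Note that `dict.get` only falls back when the key is missing. An entry with `"conversation_id": None` would still yield `None`, and the index is the position within the (possibly truncated) slice being iterated, not an ID from the dataset itself.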
