Open
Description
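When all GUESS_SIZE tokens in myguess match the first GUESS_SIZE elements of correct, every guessed token has been accepted. In that case the last element of correct is also a correct token, so it should be added to the hits list. The original code drops it: the for loop leaves gg at GUESS_SIZE - 1 both when every guess matches and when only the last guess mismatches, and hits has only GUESS_SIZE slots.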
Original code
hits = [first_guess] + [0] * (GUESS_SIZE - 1)  # multi-level window is filled
# match guess tokens
if guess_tokens is not None:
    guess_results = torch.argmax(outputs.guess_logits, dim=-1)[0].tolist()
    for eg in range(len(guess_results) // GUESS_SIZE):
        egx = eg * GUESS_SIZE
        correct = [first_guess] + guess_results[egx:egx + GUESS_SIZE]
        myguess = guess_tokens[egx:egx + GUESS_SIZE]
        gg = 0
        for gg in range(len(myguess)):
            if myguess[gg] != correct[gg]:
                break
        if gg > max_hit:
            max_hit = gg
            max_hit_idx = eg
            hits[:max_hit + 1] = correct[:max_hit + 1]
# max_hit is the length of longest accepted sequence in verification branch
Modified code
hits = [first_guess] + [0] * GUESS_SIZE  # multi-level window is filled
# match guess tokens
if guess_tokens is not None:
    guess_results = torch.argmax(outputs.guess_logits, dim=-1)[0].tolist()
    for eg in range(len(guess_results) // GUESS_SIZE):
        egx = eg * GUESS_SIZE
        correct = [first_guess] + guess_results[egx:egx + GUESS_SIZE]
        myguess = guess_tokens[egx:egx + GUESS_SIZE]
        gg = 0
        while gg < len(myguess):
            if myguess[gg] != correct[gg]:
                break
            gg += 1
        if gg > max_hit:
            max_hit = gg
            max_hit_idx = eg
            hits[:max_hit + 1] = correct[:max_hit + 1]
# max_hit is the length of longest accepted sequence in verification branch
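To make the difference between the two loops concrete, here is a minimal sketch that isolates just the matching step, without torch or the model. The helper names `count_hits_for` and `count_hits_while`, the value of `GUESS_SIZE`, and the token ids are all hypothetical and chosen only for illustration; they are not part of the repository.

```python
def count_hits_for(myguess, correct):
    """Original loop: returns the last loop index reached, not a match count."""
    gg = 0
    for gg in range(len(myguess)):
        if myguess[gg] != correct[gg]:
            break
    return gg


def count_hits_while(myguess, correct):
    """Modified loop: returns how many leading guess tokens matched."""
    gg = 0
    while gg < len(myguess):
        if myguess[gg] != correct[gg]:
            break
        gg += 1
    return gg


GUESS_SIZE = 3            # hypothetical value for illustration
first_guess = 11          # hypothetical token ids
verified = [12, 13, 14]   # stand-in for guess_results[egx:egx + GUESS_SIZE]
correct = [first_guess] + verified

full_match = [11, 12, 13]     # every guessed token agrees with correct
last_mismatch = [11, 12, 99]  # only the final guessed token is wrong

# The for loop cannot tell these two cases apart.
full_for = count_hits_for(full_match, correct)
last_for = count_hits_for(last_mismatch, correct)

# The while loop can, so a full match accepts one extra token.
full_while = count_hits_while(full_match, correct)
last_while = count_hits_while(last_mismatch, correct)

accepted_original = correct[:full_for + 1]
accepted_modified = correct[:full_while + 1]

print(full_for, last_for)      # 2 2
print(full_while, last_while)  # 3 2
print(accepted_original)       # [11, 12, 13]
print(accepted_modified)       # [11, 12, 13, 14]
```

With the while loop, gg can reach GUESS_SIZE, which is why the modified code also grows hits to GUESS_SIZE + 1 slots.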