[BUG Report] jacobi_greedy_search_multilevel function bug #56

@yangbohust

Description

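When the first GUESS_SIZE elements of the correct list match the myguess list, every guessed token has been verified. In that case the last element of the correct list is also a correct token, so it should be added to the hits list. The original for loop leaves gg at GUESS_SIZE - 1 after a full match, which is the same value it has when only the last guess is wrong, so that extra token is never copied into hits.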

https://github.com/hao-ai-lab/LookaheadDecoding/blob/9d50de4a81d1b473bfce104ace18fbbbb6dc3255/lade/decoding.py#L1068C1-L1085C88

Original code

hits = [first_guess] + [0] * (GUESS_SIZE - 1)
#multi-level window is filled
#match guess tokens
if guess_tokens is not None:
    guess_results = torch.argmax(outputs.guess_logits, dim=-1)[0].tolist()
    for eg in range(len(guess_results) // GUESS_SIZE):
        egx = eg * GUESS_SIZE
        correct = [first_guess] + guess_results[egx:egx + GUESS_SIZE]
        myguess = guess_tokens[egx:egx + GUESS_SIZE]
        gg = 0
        for gg in range(len(myguess)):
            if myguess[gg] != correct[gg]:
                break
        if gg > max_hit:
            max_hit = gg
            max_hit_idx = eg
            hits[:max_hit + 1] = correct[:max_hit + 1]
#max_hit is the length of longest accepted sequence in verification branch

Modified code

hits = [first_guess] + [0] * GUESS_SIZE
#multi-level window is filled
#match guess tokens
if guess_tokens is not None:
    guess_results = torch.argmax(outputs.guess_logits, dim=-1)[0].tolist()
    for eg in range(len(guess_results) // GUESS_SIZE):
        egx = eg * GUESS_SIZE
        correct = [first_guess] + guess_results[egx:egx + GUESS_SIZE]
        myguess = guess_tokens[egx:egx + GUESS_SIZE]
        gg = 0
        while gg < len(myguess):
            if myguess[gg] != correct[gg]:
                break
            gg += 1
        if gg > max_hit:
            max_hit = gg
            max_hit_idx = eg
            hits[:max_hit + 1] = correct[:max_hit + 1]
#max_hit is the length of longest accepted sequence in verification branch
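
To make the difference easy to check without loading a model, here is a minimal standalone sketch of the two counting loops. The helper names count_hits_for and count_hits_while, the token ids, and GUESS_SIZE = 3 are made up for illustration and are not part of lade/decoding.py; guess_results stands in for the argmax over outputs.guess_logits.

```python
def count_hits_for(myguess, correct):
    """Original counting: the for-loop index doubles as the hit count."""
    gg = 0
    for gg in range(len(myguess)):
        if myguess[gg] != correct[gg]:
            break
    # After a full match gg is len(myguess) - 1, not len(myguess).
    return gg


def count_hits_while(myguess, correct):
    """Proposed counting: gg advances only after a confirmed match."""
    gg = 0
    while gg < len(myguess):
        if myguess[gg] != correct[gg]:
            break
        gg += 1
    return gg


GUESS_SIZE = 3                    # illustrative value (assumption)
first_guess = 11                  # made-up token ids
guess_results = [12, 13, 14]      # stand-in for one verification branch
correct = [first_guess] + guess_results   # GUESS_SIZE + 1 tokens

full_match = [11, 12, 13]         # every guess agrees with correct
last_wrong = [11, 12, 99]         # only the final guess is wrong

# The for loop cannot tell these two cases apart.
for_full = count_hits_for(full_match, correct)
for_last = count_hits_for(last_wrong, correct)

# The while loop can, so the full match accepts one more token.
while_full = count_hits_while(full_match, correct)
while_last = count_hits_while(last_wrong, correct)

accepted_for = correct[:for_full + 1]
accepted_while = correct[:while_full + 1]

print(for_full, for_last, while_full, while_last)
print(accepted_for, accepted_while)
```

With the while loop a full match yields gg == GUESS_SIZE, so hits needs GUESS_SIZE + 1 slots, which is why the modified code also changes the initial length of hits.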
