
Cannot retrieve log probabilities for generated tokens #738

@mhoangvslev

Description

  • This is actually a bug report.
  • I am not getting good LLM Results
  • I have tried asking for help in the community on discord or discussions and have not received a response.
  • I have tried searching the documentation and have not found an answer.

What Model are you using?

  • [x] gpt-3.5-turbo
  • [ ] gpt-4-turbo
  • [ ] gpt-4
  • [x] Other (Mixtral 8x7B Instruct)

Describe the bug
I was trying to build a binary classifier that answers only "Yes" or "No", and I also want to retrieve the log probability of the answer.
This should be as simple as setting logprobs=2 in create_completion(). Alas, in the raw response, the logprobs field is None.
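
For reference, the parameter shape differs between the two endpoints: the legacy completion API takes an integer (logprobs=2 returns the top-2 alternatives per token), while the chat completions API takes logprobs=True plus an optional top_logprobs integer. A minimal sketch of both against a plain OpenAI-compatible client, without instructor (the model name and prompt are placeholders):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8084/v1", api_key="sk-xxx")

# Legacy completion endpoint: logprobs is an int.
completion = client.completions.create(
    model="mixtral-8x7b-instruct-v0.1",
    prompt="Answer Yes or No: is this spam?",
    max_tokens=1,
    logprobs=2,
)

# Chat completions endpoint: logprobs is a bool, top_logprobs an int.
chat = client.chat.completions.create(
    model="mixtral-8x7b-instruct-v0.1",
    messages=[{"role": "user", "content": "Answer Yes or No: is this spam?"}],
    logprobs=True,
    top_logprobs=2,
)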

To Reproduce

import enum
from typing import Literal

import httpx
import instructor
from openai import OpenAI
from pydantic import BaseModel

from llama_cpp import Llama
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding


class Labels(str, enum.Enum):
    """Enumeration for single-label text classification."""

    SPAM = "spam"
    NOT_SPAM = "not_spam"


class SinglePrediction(BaseModel):
    """
    Class for a single class label prediction.
    """

    class_label: Labels

host = "localhost"
port = 8084

class BinaryPrediction(BaseModel):
    # label: Literal["TOKPOS", "TOKNEG"]
    label: Labels

# Server mode
llm = OpenAI(
    base_url=f"http://{host}:{port}/v1",
    api_key="sk-xxx",
    http_client=httpx.Client(
        transport=httpx.HTTPTransport(local_address="0.0.0.0"),
    ),
)

client = instructor.patch(client=llm, mode=instructor.Mode.JSON_SCHEMA)

# Offline mode
# llm = Llama(
#     model_path=".models/models--TheBloke--Mixtral-8x7B-Instruct-v0.1-GGUF/snapshots/fa1d3835c5d45a3a74c0b68805fcdc133dba2b6a/mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf", 
#     draft_model=LlamaPromptLookupDecoding(num_pred_tokens=10), # 10 is good for GPU (see https://python.useinstructor.com/hub/llama-cpp-python/#llama-cpp-python)
#     n_ctx=16385,
#     n_gpu_layers=15,
#     logits_all = True,
#     offload_kqv = True,
#     server_mode = True
# )

# client = instructor.patch(
#     create=llm.create_chat_completion_openai_v1, 
#     mode=instructor.Mode.JSON_SCHEMA
# )
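
# NB: logits_all=True above is what allows llama-cpp-python to compute
# per-token logprobs at all; as far as I can tell, create_completion()
# refuses a logprobs request on a model loaded with logits_all=False.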


def classify(data: str) -> BinaryPrediction:
    """Perform single-label classification on the input text."""

    return client.chat.completions.create(
        model="mixtral-8x7b-instruct-v0.1",
        response_model=BinaryPrediction, 
        messages=[
            {
                "role": "user",
                "content": f"Classify the following text: {data}",
            },
        ],
        logprobs=True
    )

    # return client(
    #     response_model=BinaryPrediction, 
    #     messages=[
    #         {
    #             "role": "user",
    #             "content": f"Classify the following text: {data}",
    #         },
    #     ],
    #     logprobs=2,
    # )

# Test single-label classification
prediction = classify("Hello there I'm a Nigerian prince and I want to give you money")
print(prediction._raw_response)

Running this prints the following; note logprobs=None in the Choice:

$ python markup/misc/test_binary_classifier.py
ChatCompletion(id='chatcmpl-a6963005-f37d-456d-86eb-3d564616118c', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='{ "label": "spam"}', role='assistant', function_call=None, tool_calls=None))], created=1718024981, model='gpt-3.5-turbo-16k', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=9, prompt_tokens=235, total_tokens=244))
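
To check whether instructor is dropping the parameter or the server simply never returns it, one can compare a direct call on a fresh, unpatched client (a debugging sketch; plain_client is new here, and it assumes the llama-cpp-python server honours logprobs on the chat endpoint):

plain_client = OpenAI(base_url=f"http://{host}:{port}/v1", api_key="sk-xxx")

raw = plain_client.chat.completions.create(
    model="mixtral-8x7b-instruct-v0.1",
    messages=[{"role": "user", "content": "Classify the following text: free money!!!"}],
    logprobs=True,
    top_logprobs=2,
)

# Still None -> the server never returns logprobs;
# populated -> instructor (or the patched create) is dropping them.
print(raw.choices[0].logprobs)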

Expected behavior
logprob = response._raw_response.choices[0].logprobs.content[0].logprob # expect a number
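
In other words, with logprobs requested, something along these lines should succeed (a sketch with None-guards, following the OpenAI chat-completions response shape):

choice = prediction._raw_response.choices[0]
if choice.logprobs is not None and choice.logprobs.content:
    logprob = choice.logprobs.content[0].logprob  # a float such as -0.0123
else:
    raise RuntimeError("no logprobs returned despite logprobs=True")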
