Merged
8 changes: 8 additions & 0 deletions src/inference/InferenceEngine.py
@@ -29,6 +29,14 @@ class FinishReason(StrEnum):
# The model took longer than our timeout to return the first token
ModelOverloaded = "model overloaded"

# Encountered an RPC error from inferD
BadConnection = "bad connection"

# The request was invalid, e.g. the context length was too long
ValueError = "value error"

# Catch-all for any other exception
Unknown = "unknown"

@dataclass
class InferenceEngineMessage:
15 changes: 13 additions & 2 deletions src/message/create_message_service.py
@@ -392,14 +392,21 @@ def map_chunk(chunk: InferenceEngineChunk):
yield map_chunk(chunk)

except grpc.RpcError as e:
finish_reason = FinishReason.BadConnection
err = f"inference failed: {e}"
yield format_message(message.MessageStreamError(reply.id, err, "grpc inference failed"))

except multiprocessing.TimeoutError:
finish_reason = FinishReason.ModelOverloaded
Contributor: Why was this removed?

Contributor Author (@yensung, Apr 2, 2025): It was changed to yield the error message directly. I can change it back if you prefer the previous approach.

Contributor: Doesn't this affect how we save the message at the end?

Contributor Author: Good call. Curious why only the TimeoutError exception is saved as finish_reason. Should we also save RpcError and ValueError in finish_reason when they happen?

Contributor: ValueError may be a good finish reason, maybe something like FinishReason.InvalidRequest? I think RpcError is caused by a bad connection to InferD, so we may not need to save that?

Contributor Author: I'll name it BadConnection.

Contributor: You may want to keep that as-is or change the UI as well. We have some code tied to these values. https://github.com/allenai/olmo-ui/blob/dev/src/api/Message.ts#L99-L119

Contributor Author: I'll raise a PR to add them.

gen = time_ns() - start_gen
gen //= 1_000_000  # convert nanoseconds to milliseconds
except ValueError as e:
finish_reason = FinishReason.ValueError
# A ValueError can occur e.g. when the context length is too long
yield format_message(message.MessageStreamError(reply.id, f"{e}", "value error from inference result"))

except Exception as e:
finish_reason = FinishReason.Unknown
yield format_message(message.MessageStreamError(reply.id, f"{e}", "general exception"))

match finish_reason:
case FinishReason.UnclosedStream:
@@ -430,6 +437,10 @@ def map_chunk(chunk: InferenceEngineChunk):
# The generation is complete. Store it.
# TODO: InferD should store this so that we don't have to.
# TODO: capture InferD request input instead of our manifestation of the prompt format

gen = time_ns() - start_gen
gen //= 1_000_000  # convert nanoseconds to milliseconds

prompt = create_prompt_from_engine_input(chain)
output, logprobs = create_output_from_chunks(chunks)
