Skip to content

Commit cf87e5f

Browse files
authored
Show context length error msg (#262)
1 parent 21f053a commit cf87e5f

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

src/inference/InferenceEngine.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,14 @@ class FinishReason(StrEnum):
2929
# The model took longer than our timeout to return the first token
3030
ModelOverloaded = "model overloaded"
3131

32+
# Encountered RPC error from inferD
33+
BadConnection = "bad connection"
34+
35+
# A value error can occur, e.g. when the context length is too long
36+
ValueError = "value error"
37+
38+
# General exceptions
39+
Unknown = "unknown"
3240

3341
@dataclass
3442
class InferenceEngineMessage:

src/message/create_message_service.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,14 +392,21 @@ def map_chunk(chunk: InferenceEngineChunk):
392392
yield map_chunk(chunk)
393393

394394
except grpc.RpcError as e:
395+
finish_reason = FinishReason.BadConnection
395396
err = f"inference failed: {e}"
396397
yield format_message(message.MessageStreamError(reply.id, err, "grpc inference failed"))
397398

398399
except multiprocessing.TimeoutError:
399400
finish_reason = FinishReason.ModelOverloaded
400401

401-
gen = time_ns() - start_gen
402-
gen //= 1000000
402+
except ValueError as e:
403+
finish_reason = FinishReason.ValueError
404+
# a value error can occur, e.g. when the context length is too long
405+
yield format_message(message.MessageStreamError(reply.id, f"{e}", "value error from inference result"))
406+
407+
except Exception as e:
408+
finish_reason = FinishReason.Unknown
409+
yield format_message(message.MessageStreamError(reply.id, f"{e}", "general exception"))
403410

404411
match finish_reason:
405412
case FinishReason.UnclosedStream:
@@ -430,6 +437,10 @@ def map_chunk(chunk: InferenceEngineChunk):
430437
# The generation is complete. Store it.
431438
# TODO: InferD should store this so that we don't have to.
432439
# TODO: capture InferD request input instead of our manifestation of the prompt format
440+
441+
gen = time_ns() - start_gen
442+
gen //= 1000000
443+
433444
prompt = create_prompt_from_engine_input(chain)
434445
output, logprobs = create_output_from_chunks(chunks)
435446

0 commit comments

Comments
 (0)