You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
NOTE(lry89757) We use rpyc to transport param between client and server.
36
+
Rpyc only support the type of `POD` in python as the param, so we should take some smart ways to transport the data like tensor or some sophisticated classes.
37
+
Drawing on the logic of `__setstate__`, `__getstate__`, we will let some classes(will be rpc param later) inherit this base class, and rewrite the to_rpc_param and from_rpc_param. We will invoke `to_rpc_param` in client to pass the params and recover the param in server side by `from_rpc_param`.
38
+
"""
39
+
40
+
@abstractmethod
41
+
defto_rpc_param(self):
42
+
returnNotImplementedError
43
+
44
+
@staticmethod
45
+
@abstractmethod
46
+
deffrom_rpc_param():
47
+
returnNotImplementedError
48
+
49
+
33
50
@dataclass
34
-
classInputMetaData:
51
+
classInputMetaData(RPC_PARAM):
35
52
"""The input info for a single step
36
53
37
54
Args:
@@ -48,6 +65,7 @@ class InputMetaData:
48
65
dtype (torch.dtype, optional): The computation type of tensor, Defaults to torch.float32.
49
66
use_spec_dec (bool): Indicate whether to use speculative decoding.
50
67
num_tokens_to_verify (int): The number of tokens to verify in speculative decoding. Only valid when `use_spec_dec` is set to True.
68
+
batch_token_ids (List[List[int]], optional): input_token_ids + output_token_ids of current batch. Only used for `repetition_penalty`, `no_repeat_ngram_size` in sampler process.
51
69
"""
52
70
53
71
block_tables: torch.Tensor=None
@@ -63,6 +81,54 @@ class InputMetaData:
63
81
dtype: torch.dtype=torch.float32
64
82
use_spec_dec: bool=False
65
83
num_tokens_to_verify: int=0
84
+
batch_token_ids: Optional[
85
+
List[List[int]]
86
+
] =None# for `repetition_penalty`, `no_repeat_ngram_size` in sampler process
0 commit comments