1- from dataclasses import dataclass
1+ from abc import ABC , abstractmethod
22from typing import List , Optional
33
44from graphgen .bases import Token
55
66
7- @dataclass
8- class TopkTokenModel :
9- do_sample : bool = False
10- temperature : float = 0
11- max_tokens : int = 4096
12- repetition_penalty : float = 1.05
13- num_beams : int = 1
14- topk : int = 50
15- topp : float = 0.95
16-
17- topk_per_token : int = 5 # number of topk tokens to generate for each token
7+ class TopkTokenModel (ABC ):
8+ def __init__ (
9+ self ,
10+ do_sample : bool = False ,
11+ temperature : float = 0 ,
12+ max_tokens : int = 4096 ,
13+ repetition_penalty : float = 1.05 ,
14+ num_beams : int = 1 ,
15+ topk : int = 50 ,
16+ topp : float = 0.95 ,
17+ topk_per_token : int = 5 ,
18+ ):
19+ self .do_sample = do_sample
20+ self .temperature = temperature
21+ self .max_tokens = max_tokens
22+ self .repetition_penalty = repetition_penalty
23+ self .num_beams = num_beams
24+ self .topk = topk
25+ self .topp = topp
26+ self .topk_per_token = topk_per_token
1827
28+ @abstractmethod
1929 async def generate_topk_per_token (self , text : str ) -> List [Token ]:
2030 """
2131 Generate prob, text and candidates for each token of the model's output.
2232 This function is used to visualize the inference process.
2333 """
2434 raise NotImplementedError
2535
36+ @abstractmethod
2637 async def generate_inputs_prob (
2738 self , text : str , history : Optional [List [str ]] = None
2839 ) -> List [Token ]:
@@ -32,6 +43,7 @@ async def generate_inputs_prob(
3243 """
3344 raise NotImplementedError
3445
46+ @abstractmethod
3547 async def generate_answer (
3648 self , text : str , history : Optional [List [str ]] = None
3749 ) -> str :
0 commit comments