@@ -7,7 +7,10 @@ extern crate accelerate_src;
77use anyhow:: { Error as E , Result } ;
88use clap:: Parser ;
99
10- use candle_transformers:: models:: helium:: { Config , Model } ;
10+ use candle_transformers:: models:: helium:: { Config as ConfigPreview , Model as ModelPreview } ;
11+ use candle_transformers:: models:: llama:: {
12+ Cache as CacheV1 , Llama as ModelV1 , LlamaConfig as ConfigV1 , LlamaEosToks ,
13+ } ;
1114
1215use candle:: { DType , Device , Tensor } ;
1316use candle_examples:: token_output_stream:: TokenOutputStream ;
@@ -16,6 +19,44 @@ use candle_transformers::generation::{LogitsProcessor, Sampling};
1619use hf_hub:: { api:: sync:: Api , Repo , RepoType } ;
1720use tokenizers:: Tokenizer ;
1821
22+ #[ derive( Debug , Clone ) ]
23+ enum Model {
24+ V1 { model : ModelV1 , cache : CacheV1 } ,
25+ Preview ( ModelPreview ) ,
26+ }
27+
28+ impl Model {
29+ fn forward ( & mut self , input : & Tensor , start_pos : usize ) -> Result < Tensor > {
30+ let model = match self {
31+ Model :: V1 { model, cache } => model. forward ( input, start_pos, cache) ?,
32+ Model :: Preview ( m) => m. forward ( input, start_pos) ?,
33+ } ;
34+ Ok ( model)
35+ }
36+ }
37+
38+ #[ derive( Debug , Clone ) ]
39+ enum Config {
40+ V1 ( ConfigV1 ) ,
41+ Preview ( ConfigPreview ) ,
42+ }
43+
44+ impl Config {
45+ fn bos_token_id ( & self ) -> Option < u32 > {
46+ match self {
47+ Config :: V1 ( c) => c. bos_token_id ,
48+ Config :: Preview ( c) => Some ( c. bos_token_id ) ,
49+ }
50+ }
51+
52+ fn eos_token_id ( & self ) -> Option < LlamaEosToks > {
53+ match self {
54+ Config :: V1 ( c) => c. eos_token_id . clone ( ) ,
55+ Config :: Preview ( c) => Some ( LlamaEosToks :: Single ( c. eos_token_id ) ) ,
56+ }
57+ }
58+ }
59+
1960struct TextGeneration {
2061 model : Model ,
2162 device : Device ,
@@ -106,7 +147,15 @@ impl TextGeneration {
106147 let next_token = self . logits_processor . sample ( & logits) ?;
107148 tokens. push ( next_token) ;
108149 generated_tokens += 1 ;
109- if next_token == self . config . bos_token_id || next_token == self . config . eos_token_id {
150+ let is_eos = self
151+ . config
152+ . eos_token_id ( )
153+ . as_ref ( )
154+ . is_some_and ( |v| match v {
155+ LlamaEosToks :: Single ( eos) => * eos == next_token,
156+ LlamaEosToks :: Multiple ( eos) => eos. contains ( & next_token) ,
157+ } ) ;
158+ if Some ( next_token) == self . config . bos_token_id ( ) || is_eos {
110159 break ;
111160 }
112161 if let Some ( t) = self . tokenizer . next_token ( next_token) ? {
@@ -131,6 +180,8 @@ impl TextGeneration {
131180enum Which {
132181 #[ value( name = "v1-preview" ) ]
133182 V1Preview ,
183+ #[ value( name = "v1" ) ]
184+ V1 ,
134185}
135186
136187#[ derive( Parser , Debug ) ]
@@ -144,9 +195,6 @@ struct Args {
144195 #[ arg( long) ]
145196 tracing : bool ,
146197
147- #[ arg( long) ]
148- use_flash_attn : bool ,
149-
150198 #[ arg( long) ]
151199 prompt : String ,
152200
@@ -171,7 +219,7 @@ struct Args {
171219 sample_len : usize ,
172220
173221 /// The model size to use.
174- #[ arg( long, default_value = "v1-preview " ) ]
222+ #[ arg( long, default_value = "v1" ) ]
175223 which : Which ,
176224
177225 #[ arg( long) ]
@@ -230,6 +278,7 @@ fn main() -> Result<()> {
230278 None => {
231279 let name = match args. which {
232280 Which :: V1Preview => "kyutai/helium-1-preview-2b" ,
281+ Which :: V1 => "kyutai/helium-1-2b" ,
233282 } ;
234283 name. to_string ( )
235284 }
@@ -254,18 +303,27 @@ fn main() -> Result<()> {
254303 let tokenizer = Tokenizer :: from_file ( tokenizer_filename) . map_err ( E :: msg) ?;
255304
256305 let start = std:: time:: Instant :: now ( ) ;
257- let config: Config = match args. config {
258- Some ( config_file) => serde_json:: from_slice ( & std:: fs:: read ( config_file) ?) ?,
259- None => {
260- let config_file = repo. get ( "config.json" ) ?;
261- serde_json:: from_slice ( & std:: fs:: read ( config_file) ?) ?
262- }
306+ let config_file = match args. config {
307+ Some ( config_file) => std:: path:: PathBuf :: from ( config_file) ,
308+ None => repo. get ( "config.json" ) ?,
309+ } ;
310+ let config = match args. which {
311+ Which :: V1Preview => Config :: Preview ( serde_json:: from_slice ( & std:: fs:: read ( config_file) ?) ?) ,
312+ Which :: V1 => Config :: V1 ( serde_json:: from_slice ( & std:: fs:: read ( config_file) ?) ?) ,
263313 } ;
264314 let device = candle_examples:: device ( args. cpu ) ?;
265315 let ( model, device) = {
266316 let dtype = device. bf16_default_to_f32 ( ) ;
267317 let vb = unsafe { VarBuilder :: from_mmaped_safetensors ( & filenames, dtype, & device) ? } ;
268- let model = Model :: new ( & config, vb) ?;
318+ let model = match & config {
319+ Config :: V1 ( c) => {
320+ let c = c. clone ( ) . into_config ( false ) ;
321+ let model = ModelV1 :: load ( vb, & c) ?;
322+ let cache = CacheV1 :: new ( true , dtype, & c, & device) ?;
323+ Model :: V1 { model, cache }
324+ }
325+ Config :: Preview ( c) => Model :: Preview ( ModelPreview :: new ( c, vb) ?) ,
326+ } ;
269327 ( model, device)
270328 } ;
271329
0 commit comments