Skip to content

Commit 38fc866

Browse files
Add support for Helium-v1. (#2932)
1 parent 5029ac5 commit 38fc866

File tree

1 file changed

+71
-13
lines changed
  • candle-examples/examples/helium

1 file changed

+71
-13
lines changed

candle-examples/examples/helium/main.rs

Lines changed: 71 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@ extern crate accelerate_src;
77
use anyhow::{Error as E, Result};
88
use clap::Parser;
99

10-
use candle_transformers::models::helium::{Config, Model};
10+
use candle_transformers::models::helium::{Config as ConfigPreview, Model as ModelPreview};
11+
use candle_transformers::models::llama::{
12+
Cache as CacheV1, Llama as ModelV1, LlamaConfig as ConfigV1, LlamaEosToks,
13+
};
1114

1215
use candle::{DType, Device, Tensor};
1316
use candle_examples::token_output_stream::TokenOutputStream;
@@ -16,6 +19,44 @@ use candle_transformers::generation::{LogitsProcessor, Sampling};
1619
use hf_hub::{api::sync::Api, Repo, RepoType};
1720
use tokenizers::Tokenizer;
1821

22+
/// Runtime dispatch between the two supported Helium checkpoints:
/// the original preview release and the Llama-architecture v1 release.
#[derive(Debug, Clone)]
enum Model {
    // Helium-v1 reuses the Llama implementation; the KV cache is stored
    // alongside the model because `Llama::forward` takes it explicitly.
    V1 { model: ModelV1, cache: CacheV1 },
    // The original helium-1-preview model (manages its cache internally —
    // its forward takes no cache argument).
    Preview(ModelPreview),
}
27+
28+
impl Model {
29+
fn forward(&mut self, input: &Tensor, start_pos: usize) -> Result<Tensor> {
30+
let model = match self {
31+
Model::V1 { model, cache } => model.forward(input, start_pos, cache)?,
32+
Model::Preview(m) => m.forward(input, start_pos)?,
33+
};
34+
Ok(model)
35+
}
36+
}
37+
38+
/// Configuration wrapper mirroring the `Model` enum: one variant per
/// supported checkpoint format.
#[derive(Debug, Clone)]
enum Config {
    // Llama-style config shipped with the Helium-v1 checkpoint.
    V1(ConfigV1),
    // Config of the original helium-1-preview release.
    Preview(ConfigPreview),
}
43+
44+
impl Config {
45+
fn bos_token_id(&self) -> Option<u32> {
46+
match self {
47+
Config::V1(c) => c.bos_token_id,
48+
Config::Preview(c) => Some(c.bos_token_id),
49+
}
50+
}
51+
52+
fn eos_token_id(&self) -> Option<LlamaEosToks> {
53+
match self {
54+
Config::V1(c) => c.eos_token_id.clone(),
55+
Config::Preview(c) => Some(LlamaEosToks::Single(c.eos_token_id)),
56+
}
57+
}
58+
}
59+
1960
struct TextGeneration {
2061
model: Model,
2162
device: Device,
@@ -106,7 +147,15 @@ impl TextGeneration {
106147
let next_token = self.logits_processor.sample(&logits)?;
107148
tokens.push(next_token);
108149
generated_tokens += 1;
109-
if next_token == self.config.bos_token_id || next_token == self.config.eos_token_id {
150+
let is_eos = self
151+
.config
152+
.eos_token_id()
153+
.as_ref()
154+
.is_some_and(|v| match v {
155+
LlamaEosToks::Single(eos) => *eos == next_token,
156+
LlamaEosToks::Multiple(eos) => eos.contains(&next_token),
157+
});
158+
if Some(next_token) == self.config.bos_token_id() || is_eos {
110159
break;
111160
}
112161
if let Some(t) = self.tokenizer.next_token(next_token)? {
@@ -131,6 +180,8 @@ impl TextGeneration {
131180
enum Which {
132181
#[value(name = "v1-preview")]
133182
V1Preview,
183+
#[value(name = "v1")]
184+
V1,
134185
}
135186

136187
#[derive(Parser, Debug)]
@@ -144,9 +195,6 @@ struct Args {
144195
#[arg(long)]
145196
tracing: bool,
146197

147-
#[arg(long)]
148-
use_flash_attn: bool,
149-
150198
#[arg(long)]
151199
prompt: String,
152200

@@ -171,7 +219,7 @@ struct Args {
171219
sample_len: usize,
172220

173221
/// The model size to use.
174-
#[arg(long, default_value = "v1-preview")]
222+
#[arg(long, default_value = "v1")]
175223
which: Which,
176224

177225
#[arg(long)]
@@ -230,6 +278,7 @@ fn main() -> Result<()> {
230278
None => {
231279
let name = match args.which {
232280
Which::V1Preview => "kyutai/helium-1-preview-2b",
281+
Which::V1 => "kyutai/helium-1-2b",
233282
};
234283
name.to_string()
235284
}
@@ -254,18 +303,27 @@ fn main() -> Result<()> {
254303
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
255304

256305
let start = std::time::Instant::now();
257-
let config: Config = match args.config {
258-
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
259-
None => {
260-
let config_file = repo.get("config.json")?;
261-
serde_json::from_slice(&std::fs::read(config_file)?)?
262-
}
306+
let config_file = match args.config {
307+
Some(config_file) => std::path::PathBuf::from(config_file),
308+
None => repo.get("config.json")?,
309+
};
310+
let config = match args.which {
311+
Which::V1Preview => Config::Preview(serde_json::from_slice(&std::fs::read(config_file)?)?),
312+
Which::V1 => Config::V1(serde_json::from_slice(&std::fs::read(config_file)?)?),
263313
};
264314
let device = candle_examples::device(args.cpu)?;
265315
let (model, device) = {
266316
let dtype = device.bf16_default_to_f32();
267317
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
268-
let model = Model::new(&config, vb)?;
318+
let model = match &config {
319+
Config::V1(c) => {
320+
let c = c.clone().into_config(false);
321+
let model = ModelV1::load(vb, &c)?;
322+
let cache = CacheV1::new(true, dtype, &c, &device)?;
323+
Model::V1 { model, cache }
324+
}
325+
Config::Preview(c) => Model::Preview(ModelPreview::new(c, vb)?),
326+
};
269327
(model, device)
270328
};
271329

0 commit comments

Comments
 (0)