Skip to content

Commit d47ed3e

Browse files
committed
load chat_template from file
1 parent 2ad822d commit d47ed3e

File tree

5 files changed

+67
-15
lines changed

5 files changed

+67
-15
lines changed

mlx-lm-utils/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ license.workspace = true
1010
documentation.workspace = true
1111

1212
[dependencies]
13-
minijinja = "2"
13+
minijinja = { version = "2", features = ["loader"] }
1414
serde = { version = "1", features = ["derive"] }
15+
serde_json = "1"
1516
thiserror = "2"
1617
tokenizers = "0.21"

mlx-lm-utils/src/tokenizer.rs

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
// set, will return a dict of tokenizer outputs instead.
6666
// """
6767

68-
use std::{borrow::Cow, collections::HashMap};
68+
use std::{borrow::Cow, collections::HashMap, fs::read_to_string, path::Path};
6969

7070
use minijinja::{context, Environment, Template};
7171
use serde::{Deserialize, Serialize};
@@ -74,7 +74,7 @@ use crate::error::Error;
7474

7575
#[derive(Serialize)]
7676
#[serde(untagged)]
77-
pub enum Content<T: Serialize=()> {
77+
pub enum Content<T: Serialize = ()> {
7878
String(String),
7979
Map(HashMap<String, String>),
8080
Typed(T),
@@ -115,6 +115,21 @@ pub struct TokenizeOptions {
115115
pub return_assistant_tokens_mask: Option<bool>,
116116
}
117117

118+
pub fn load_chat_template_from_str(content: &str) -> std::io::Result<Option<String>> {
119+
serde_json::from_str::<serde_json::Value>(content).map(|value| {
120+
value
121+
.get("chat_template")
122+
.and_then(|value| value.as_str())
123+
.map(ToString::to_string)
124+
})
125+
.map_err(Into::into)
126+
}
127+
128+
pub fn load_chat_template_from_file(file: impl AsRef<Path>) -> std::io::Result<Option<String>> {
129+
let content = read_to_string(file)?;
130+
load_chat_template_from_str(&content)
131+
}
132+
118133
// chat_template = self.get_chat_template(chat_template, tools)
119134

120135
// if isinstance(conversation, (list, tuple)) and (
@@ -192,7 +207,6 @@ pub struct TokenizeOptions {
192207
// else:
193208
// return rendered_chat
194209

195-
196210
// def render_jinja_template(
197211
// conversations: list[list[dict[str, str]]],
198212
// tools: Optional[list[Union[dict, Callable]]] = None,
@@ -286,7 +300,6 @@ pub struct TokenizeOptions {
286300

287301
// return rendered, all_generation_indices
288302

289-
290303
pub fn apply_chat_template<'a>(
291304
env: &'a mut Environment<'a>,
292305
model_template: &'a str,
@@ -311,19 +324,20 @@ pub fn apply_chat_template<'a>(
311324
Ok(template) => template,
312325
Err(_) => {
313326
env.add_template(model_id, model_template)?;
314-
env.get_template(model_id).expect("Newly added template must be present")
315-
},
327+
env.get_template(model_id)
328+
.expect("Newly added template must be present")
329+
}
316330
},
317331
};
318332

319333
// TODO: what about list of list of conversations
320-
334+
321335
// TODO: handle tool
322336

323337
// TODO: handle documents``
324338

325339
// TODO: allow return_generation_indices
326-
340+
327341
let rendered_chat = template.render(context! {
328342
messages => conversations,
329343
documents => documents,

mlx-lm/src/utils/mod.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ fn index_out_of_bound_exception() -> Exception {
5858

5959
#[allow(non_snake_case)]
6060
pub(crate) fn quantized_scaled_dot_product_attention(
61-
queries: &Array,
61+
queries: Array,
6262
mut q_keys: QuantizedKeys,
6363
mut q_values: QuantizedValues,
6464
scale: f32,
@@ -97,8 +97,8 @@ pub(crate) fn quantized_scaled_dot_product_attention(
9797
// TODO: handle str type mask
9898

9999
if mask.dtype() == Dtype::Bool {
100-
// scores = mlx_rs::ops::r#where(mask, scores, b)
101-
todo!("need to add finfo.min equivalent to Dtype")
100+
let finfo_min = scores.dtype().finfo_min()?;
101+
scores = mlx_rs::ops::r#where(mask, scores, Array::from_f64(finfo_min))?;
102102
} else {
103103
scores += mask;
104104
}
@@ -113,8 +113,6 @@ pub(crate) fn quantized_scaled_dot_product_attention(
113113
Ok(out)
114114
}
115115

116-
// type QuantizedKeyValue = (Array, Array, Array);
117-
118116
pub struct QuantizedKeys {
119117
pub keys: Array,
120118
pub scales: Array,
@@ -162,7 +160,7 @@ impl From<QuantizedValues> for MaybeQuantizedValues {
162160
}
163161

164162
pub(crate) fn scaled_dot_product_attention<C>(
165-
queries: &Array,
163+
queries: Array,
166164
keys: impl Into<MaybeQuantizedKeys>,
167165
values: impl Into<MaybeQuantizedValues>,
168166
cache: Option<C>,

mlx-rs/src/dtype.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
use half::{bf16, f16};
12
use mlx_internal_macros::generate_test_cases;
23
use strum::EnumIter;
34

5+
use crate::error::{Exception, InexactDtypeError};
6+
47
generate_test_cases! {
58
/// Array element type
69
#[derive(
@@ -83,6 +86,30 @@ impl Dtype {
8386
pub fn from_promoting_types(a: Dtype, b: Dtype) -> Self {
8487
a.promote_with(b)
8588
}
89+
90+
/// Minimum value of the float point types. Returns `Err(_)` if the type is not
91+
/// float point
92+
pub fn finfo_min(&self) -> Result<f64, InexactDtypeError> {
93+
match self {
94+
Dtype::Float16 => Ok(f16::MIN.to_f64_const()),
95+
Dtype::Float32 => Ok(f32::MIN as f64),
96+
Dtype::Complex64 => Ok(f32::MIN as f64),
97+
Dtype::Bfloat16 => Ok(bf16::MIN.to_f64_const()),
98+
_ => Err(InexactDtypeError(*self))
99+
}
100+
}
101+
102+
/// Maximum value of the float point types. Returns `Err(_)` if the type is not
103+
/// float point
104+
pub fn finfo_max(&self) -> Result<f64, InexactDtypeError> {
105+
match self {
106+
Dtype::Float16 => Ok(f16::MAX.to_f64_const()),
107+
Dtype::Float32 => Ok(f32::MAX as f64),
108+
Dtype::Complex64 => Ok(f32::MAX as f64),
109+
Dtype::Bfloat16 => Ok(bf16::MAX.to_f64_const()),
110+
_ => Err(InexactDtypeError(*self))
111+
}
112+
}
86113
}
87114

88115
pub(crate) trait TypePromotion {

mlx-rs/src/error.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,18 @@ impl From<MultiHeadAttentionBuildError> for TransformerBulidError {
378378
}
379379
}
380380

381+
/// The dtype is not a float-point type
382+
#[derive(Debug, Error)]
383+
#[error("[finfo] dtype {:?} is not inexact", .0)]
384+
pub struct InexactDtypeError(pub Dtype);
385+
386+
impl From<InexactDtypeError> for Exception {
387+
#[track_caller]
388+
fn from(value: InexactDtypeError) -> Self {
389+
Exception::custom(value.to_string())
390+
}
391+
}
392+
381393
#[cfg(test)]
382394
mod tests {
383395
use crate::array;

0 commit comments

Comments
 (0)