Skip to content

Commit 8205914

Browse files
committed
quantized-qwen: leverage ConcatKV cache; add Q8_0 variant to example main
1 parent 8b9e6b2 commit 8205914

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

candle-examples/examples/quantized-qwen3/main.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ const DEFAULT_PROMPT: &str = "Write a Rust function to calculate the factorial o
2121
enum Which {
2222
#[value(name = "0.6b")]
2323
W3_0_6b,
24+
#[value(name = "0.6b8_0")]
25+
W3_0_6b8_0,
2426
#[value(name = "1.7b")]
2527
W3_1_7b,
2628
#[value(name = "4b")]
@@ -103,6 +105,7 @@ impl Args {
103105
let api = hf_hub::api::sync::Api::new()?;
104106
let repo = match self.which {
105107
Which::W3_0_6b => "Qwen/Qwen3-0.6B",
108+
Which::W3_0_6b8_0 => "Qwen/Qwen3-0.6B",
106109
Which::W3_1_7b => "Qwen/Qwen3-1.7B",
107110
Which::W3_4b => "Qwen/Qwen3-4B",
108111
Which::W3_8b => "Qwen/Qwen3-8B",
@@ -122,6 +125,7 @@ impl Args {
122125
None => {
123126
let (repo, filename, revision) = match self.which {
124127
Which::W3_0_6b => ("unsloth/Qwen3-0.6B-GGUF", "Qwen3-0.6B-Q4_K_M.gguf", "main"),
128+
Which::W3_0_6b8_0 => ("unsloth/Qwen3-0.6B-GGUF", "Qwen3-0.6B-Q8_0.gguf", "main"),
125129
Which::W3_1_7b => ("unsloth/Qwen3-1.7B-GGUF", "Qwen3-1.7B-Q4_K_M.gguf", "main"),
126130
Which::W3_4b => ("unsloth/Qwen3-4B-GGUF", "Qwen3-4B-Q4_K_M.gguf", "main"),
127131
Which::W3_8b => ("unsloth/Qwen3-8B-GGUF", "Qwen3-8B-Q4_K_M.gguf", "main"),

candle-transformers/src/models/quantized_qwen3.rs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use super::with_tracing::QMatMul;
1010
use crate::{quantized_nn::RmsNorm, utils::repeat_kv};
1111
use candle::quantized::{gguf_file, QTensor};
1212
use candle::{DType, Device, Result, Tensor};
13-
use candle_nn::{kv_cache::KvCache, Activation, Embedding, Module};
13+
use candle_nn::{kv_cache::ConcatKvCache, Activation, Embedding, Module};
1414
use std::io::{Read, Seek};
1515
use std::sync::Arc;
1616

@@ -136,7 +136,7 @@ struct AttentionWeights {
136136
num_kv_groups: usize,
137137
head_dim: usize,
138138
rotary_emb: Arc<RotaryEmbedding>,
139-
kv_cache: KvCache,
139+
kv_cache: ConcatKvCache,
140140
span_attn: tracing::Span,
141141
}
142142

@@ -160,9 +160,7 @@ impl AttentionWeights {
160160
let q_norm = gg.rms_norm(&format!("{prefix}.attn_q_norm.weight"), rms_norm_eps)?;
161161
let k_norm = gg.rms_norm(&format!("{prefix}.attn_k_norm.weight"), rms_norm_eps)?;
162162

163-
// Initialize KV cache with 512 tokens capacity to reduce initial memory allocation.
164-
// The cache will grow in chunks of 512 tokens when needed.
165-
let kv_cache = KvCache::new(2, 512);
163+
let kv_cache = ConcatKvCache::new(2);
166164

167165
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
168166

@@ -211,10 +209,6 @@ impl AttentionWeights {
211209

212210
let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;
213211

214-
// Reset KV cache if we're at the first position
215-
if offset == 0 {
216-
self.kv_cache.reset();
217-
}
218212
let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
219213

220214
// Make tensor contiguous to avoid some strided copies

0 commit comments

Comments (0)