Commit eb651c8

add clear kv cache to quantized qwen3 weights (#3189)
1 parent 549eacb

1 file changed, +14 -0

candle-transformers/src/models/quantized_qwen3.rs

@@ -233,6 +233,10 @@ impl AttentionWeights {
             .reshape((b, l, self.num_heads * self.head_dim))?;
         self.o_proj.forward(&reshaped_ctx)
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.kv_cache.reset();
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -283,6 +287,10 @@ impl LayerWeights {
         let h2 = h2.apply(&self.mlp)?;
         x + h2
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.self_attn.clear_kv_cache();
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -416,4 +424,10 @@ impl ModelWeights {
         let last_hidden = h.narrow(1, l - 1, 1)?;
         self.lm_head.forward(&last_hidden)?.squeeze(1)
     }
+
+    pub fn clear_kv_cache(&mut self) {
+        for layer in &mut self.layers {
+            layer.clear_kv_cache();
+        }
+    }
 }
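
The change is a delegation chain that mirrors the module hierarchy: ModelWeights::clear_kv_cache walks every transformer layer, each LayerWeights forwards to its attention block, and AttentionWeights resets the underlying KvCache. Below is a minimal sketch of how a caller might use the new public method to reuse one model instance across independent prompts; run_prompt and its decode loop are hypothetical stand-ins, and only clear_kv_cache comes from this commit.

// Sketch only: serving several unrelated prompts from one quantized Qwen3
// instance. run_prompt is a hypothetical helper; the only API taken from
// this commit is ModelWeights::clear_kv_cache.
use candle_transformers::models::quantized_qwen3::ModelWeights;

fn run_prompt(model: &mut ModelWeights, tokens: &[u32]) -> candle_core::Result<String> {
    // ... the usual prefill + decode loop; each forward call appends
    // keys/values to the per-layer KV caches ...
    todo!()
}

fn run_batch(model: &mut ModelWeights, prompts: &[Vec<u32>]) -> candle_core::Result<()> {
    for tokens in prompts {
        let _completion = run_prompt(model, tokens)?;
        // Drop the cached keys/values so the next prompt starts from an empty
        // context instead of silently attending to the previous sequence.
        model.clear_kv_cache();
    }
    Ok(())
}

Resetting at request boundaries keeps the quantized weights, the expensive part to load, resident while discarding only the per-sequence state.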

0 comments