@@ -24,6 +24,14 @@ impl Debug for LlamaSampler {
2424}
2525
2626impl LlamaSampler {
27+ /// Create a new `LlamaSampler`.
28+ /// ```
29+ /// # use llama_cpp_2::sampler_chain::{LlamaSampler, params::LlamaSamplerChainParams};
30+ /// let mut chain = LlamaSampler::new(LlamaSamplerChainParams::default());
31+ /// chain = chain.add_temp(0.7);
32+ /// chain = chain.add_dist(42);
33+ /// assert_eq!(chain.len(), 2);
34+ /// ```
2735 pub fn new ( sampler_chain_params : params:: LlamaSamplerChainParams ) -> Self {
2836 let sampler = unsafe {
2937 NonNull :: new ( llama_cpp_sys_2:: llama_sampler_chain_init (
@@ -120,22 +128,6 @@ impl LlamaSampler {
120128 self
121129 }
122130
123- /// Initialize a tail-free sampler with the given z value and add it to the sampler chain.
124- ///
125- /// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
126- pub fn add_tail_free ( self , z : f32 , min_keep : usize ) -> Self {
127- unsafe {
128- let tail_free_sampler =
129- NonNull :: new ( llama_cpp_sys_2:: llama_sampler_init_tail_free ( z, min_keep) )
130- . expect ( "llama_sampler_chain_init_tail_free returned null" ) ;
131- llama_cpp_sys_2:: llama_sampler_chain_add (
132- self . sampler . as_ptr ( ) ,
133- tail_free_sampler. as_ptr ( ) ,
134- ) ;
135- }
136- self
137- }
138-
139131 /// Initialize a typical-p sampler with the given value and add it to the sampler chain.
140132 pub fn add_typical_p ( self , p : f32 , min_keep : usize ) -> Self {
141133 unsafe {
@@ -209,6 +201,23 @@ impl LlamaSampler {
209201 self
210202 }
211203
204+ /// Initialize an XTC sampler with the given values and add it to the sampler chain.
205+ pub fn add_xtc ( self , p : f32 , t : f32 , min_keep : usize , seed : u32 ) -> Self {
206+ unsafe {
207+ let xtc_sampler = NonNull :: new ( llama_cpp_sys_2:: llama_sampler_init_xtc (
208+ p, t, min_keep, seed,
209+ ) )
210+ . expect ( "llama_sampler_chain_init_xtc returned null" ) ;
211+ llama_cpp_sys_2:: llama_sampler_chain_add ( self . sampler . as_ptr ( ) , xtc_sampler. as_ptr ( ) ) ;
212+ }
213+ self
214+ }
215+
216+ /// Get the number of samplers in the chain.
217+ pub fn len ( & self ) -> i32 {
218+ unsafe { llama_cpp_sys_2:: llama_sampler_chain_n ( self . sampler . as_ptr ( ) ) }
219+ }
220+
212221 /// Reset the sampler chain.
213222 pub fn reset ( & self ) {
214223 unsafe {
@@ -225,13 +234,6 @@ impl LlamaSampler {
225234 LlamaToken ( token)
226235 }
227236
228- /// Accept a sampled token.
229- pub fn accept ( & self , token : LlamaToken ) {
230- unsafe {
231- llama_cpp_sys_2:: llama_sampler_accept ( self . sampler . as_ptr ( ) , token. 0 ) ;
232- }
233- }
234-
235237 /// Reset the timings for the sampler.
236238 pub fn reset_timings ( & self ) {
237239 unsafe {
0 commit comments