@@ -7,12 +7,18 @@ macro_rules! slice {
77
88mod gather;
99
10- use common:: utok;
10+ use common:: { f16 , utok} ;
1111use common_devices:: { Operators , SliceOn } ;
12+ use digit_layout:: types:: F16 ;
1213use operators:: {
13- fuesd_softmax:: common_cpu as softmax, mat_mul:: common_cpu as mat_mul,
14- reform:: common_cpu as reform, rms_norm:: common_cpu as rms_norm, rope:: common_cpu as rope,
15- swiglu:: common_cpu as swiglu, Operator , QueueOf ,
14+ fuesd_softmax:: common_cpu as softmax,
15+ mat_mul:: common_cpu as mat_mul,
16+ random_sample:: { common_cpu as random_sample, Args , KVPair , SampleArgs } ,
17+ reform:: common_cpu as reform,
18+ rms_norm:: common_cpu as rms_norm,
19+ rope:: common_cpu as rope,
20+ swiglu:: common_cpu as swiglu,
21+ Operator , QueueOf ,
1622} ;
1723use std:: ops:: { Deref , DerefMut } ;
1824use tensor:: Tensor ;
@@ -29,6 +35,23 @@ pub struct CpuKernels {
2935 rope : rope:: Operator ,
3036 softmax : softmax:: Operator ,
3137 swiglu : swiglu:: Operator ,
38+ sample : random_sample:: Operator ,
39+ }
40+
41+ impl CpuKernels {
42+ pub fn sample ( & self , temperature : f32 , top_p : f32 , top_k : usize , logits : & [ f16 ] ) -> utok {
43+ let mut kv_pair = KVPair :: new ( 0 , f16:: ZERO ) ;
44+ let mut args = Args :: < Cpu > :: new ( F16 , logits. len ( ) ) ;
45+ args. kv_pair_base = & mut kv_pair as * mut _ as _ ;
46+ args. data_base = logits. as_ptr ( ) as _ ;
47+ args. detail = SampleArgs {
48+ temperature,
49+ top_p,
50+ top_k,
51+ } ;
52+ self . sample . launch ( & args, & ThisThread ) . unwrap ( ) ;
53+ kv_pair. idx ( ) as _
54+ }
3255}
3356
3457impl Default for CpuKernels {
@@ -40,6 +63,7 @@ impl Default for CpuKernels {
4063 rope : rope:: Operator :: new ( & Cpu ) ,
4164 softmax : softmax:: Operator :: new ( & Cpu ) ,
4265 swiglu : swiglu:: Operator :: new ( & Cpu ) ,
66+ sample : random_sample:: Operator :: new ( & Cpu ) ,
4367 }
4468 }
4569}
0 commit comments