@@ -17,6 +17,77 @@ namespace torch {
1717namespace  executor  {
1818
1919namespace  native  {
// Forward declaration of sdpa_with_kv_cache_out_no_context. The same
// signature begins again further down in this file (presumably the
// definition; its body is not visible from here).
//
// Inputs q_projected / k_projected / v_projected are taken by const
// reference; key_cache, value_cache and output are taken by mutable
// reference; the function returns a Tensor&.
// NOTE(review): presumably the returned reference is `output` and the caches
// are updated in place starting at `start_pos` -- confirm against the
// definition.
20+ Tensor& sdpa_with_kv_cache_out_no_context (
21+     const  Tensor& q_projected,
22+     const  Tensor& k_projected,
23+     const  Tensor& v_projected,
24+     Tensor& key_cache,
25+     Tensor& value_cache,
26+     const  int64_t  start_pos,
27+     const  int64_t  seq_len,
    // attn_mask is an optional taken by value; the two suppressions below
    // silence the clang-tidy checks that would flag that.
28+     //  @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
29+     //  @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
30+     const  optional<Tensor> attn_mask,
31+     const  double  dropout_p,
32+     const  bool  is_causal,
    // scale is likewise an optional taken by value.
33+     //  @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
34+     const  optional<double > scale,
35+     Tensor& output);
36+ 
// Forward declaration of sdpa_with_kv_cache_aten.
//
// Takes the same parameter list as sdpa_with_kv_cache_out_no_context, but
// spelled with at::Tensor / std::optional, without the trailing `output`
// parameter, and returning an at::Tensor by value instead of a reference.
// NOTE(review): presumably an ATen-facing wrapper around the out variant --
// confirm against the definition, which is not visible here.
37+ at::Tensor sdpa_with_kv_cache_aten (
38+     const  at::Tensor& q_projected,
39+     const  at::Tensor& k_projected,
40+     const  at::Tensor& v_projected,
41+     at::Tensor& key_cache,
42+     at::Tensor& value_cache,
43+     const  int64_t  start_pos,
44+     const  int64_t  seq_len,
    // attn_mask is an optional taken by value; the two suppressions below
    // silence the clang-tidy checks that would flag that.
45+     //  @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
46+     //  @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
47+     const  std::optional<at::Tensor> attn_mask,
48+     const  double  dropout_p,
49+     const  bool  is_causal,
    // scale is likewise an optional taken by value.
50+     //  @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
51+     const  std::optional<double > scale);
52+ 
// Forward declaration of custom_sdpa_out_no_context.
//
// Unlike sdpa_with_kv_cache_out_no_context, this signature has no
// key_cache / value_cache / seq_len parameters: it takes q, k and v by const
// reference plus start_pos, an optional mask, dropout_p, is_causal, an
// optional scale, and a mutable `output`; it returns a Tensor&.
// NOTE(review): presumably the returned reference is `output` -- confirm
// against the definition, which is not visible here.
53+ Tensor& custom_sdpa_out_no_context (
54+     const  Tensor& q,
55+     const  Tensor& k,
56+     const  Tensor& v,
57+     const  int64_t  start_pos,
    // attn_mask is an optional taken by value; the two suppressions below
    // silence the clang-tidy checks that would flag that.
58+     //  @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
59+     //  @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
60+     const  optional<Tensor> attn_mask,
61+     const  double  dropout_p,
62+     const  bool  is_causal,
    // scale is likewise an optional taken by value.
63+     //  @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
64+     const  optional<double > scale,
65+     Tensor& output);
66+ 
// Forward declaration of custom_sdpa_aten.
//
// Takes the same parameter list as custom_sdpa_out_no_context, but spelled
// with at::Tensor / std::optional, without the trailing `output` parameter,
// and returning an at::Tensor by value instead of a reference.
// NOTE(review): presumably an ATen-facing wrapper around the out variant --
// confirm against the definition, which is not visible here.
67+ at::Tensor custom_sdpa_aten (
68+     const  at::Tensor& q,
69+     const  at::Tensor& k,
70+     const  at::Tensor& v,
71+     const  int64_t  start_pos,
    // attn_mask is an optional taken by value; the two suppressions below
    // silence the clang-tidy checks that would flag that.
72+     //  @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
73+     //  @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
74+     const  std::optional<at::Tensor> attn_mask,
75+     const  double  dropout_p,
76+     const  bool  is_causal,
    // scale is likewise an optional taken by value.
77+     //  @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
78+     const  std::optional<double > scale);
79+ 
// Forward declaration of update_cache_out_no_context.
//
// Takes `value` by const reference, `cache` and `output` by mutable
// reference, and an integer start_pos; returns a Tensor&.
// NOTE(review): presumably writes `value` into `cache` at `start_pos` and
// returns `output` -- confirm against the definition, which is not visible
// here.
80+ Tensor& update_cache_out_no_context (
81+     const  Tensor& value,
82+     Tensor& cache,
83+     const  int64_t  start_pos,
84+     Tensor& output);
85+ 
// Forward declaration of update_cache_aten.
//
// Takes the same parameter list as update_cache_out_no_context, but spelled
// with at::Tensor, without the trailing `output` parameter, and returning an
// at::Tensor by value instead of a reference.
// NOTE(review): presumably an ATen-facing wrapper around the out variant --
// confirm against the definition, which is not visible here.
86+ at::Tensor update_cache_aten (
87+     const  at::Tensor& value,
88+     at::Tensor& cache,
89+     const  int64_t  start_pos);
90+ 
2091Tensor& sdpa_with_kv_cache_out_no_context (
2192    const  Tensor& q_projected,
2293    const  Tensor& k_projected,
0 commit comments