@@ -17,6 +17,77 @@ namespace torch {
1717namespace executor {
1818
1919namespace native {
20+ Tensor& sdpa_with_kv_cache_out_no_context (
21+ const Tensor& q_projected,
22+ const Tensor& k_projected,
23+ const Tensor& v_projected,
24+ Tensor& key_cache,
25+ Tensor& value_cache,
26+ const int64_t start_pos,
27+ const int64_t seq_len,
28+ // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
29+ // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
30+ const optional<Tensor> attn_mask,
31+ const double dropout_p,
32+ const bool is_causal,
33+ // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
34+ const optional<double > scale,
35+ Tensor& output);
36+
37+ at::Tensor sdpa_with_kv_cache_aten (
38+ const at::Tensor& q_projected,
39+ const at::Tensor& k_projected,
40+ const at::Tensor& v_projected,
41+ at::Tensor& key_cache,
42+ at::Tensor& value_cache,
43+ const int64_t start_pos,
44+ const int64_t seq_len,
45+ // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
46+ // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
47+ const std::optional<at::Tensor> attn_mask,
48+ const double dropout_p,
49+ const bool is_causal,
50+ // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
51+ const std::optional<double > scale);
52+
53+ Tensor& custom_sdpa_out_no_context (
54+ const Tensor& q,
55+ const Tensor& k,
56+ const Tensor& v,
57+ const int64_t start_pos,
58+ // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
59+ // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
60+ const optional<Tensor> attn_mask,
61+ const double dropout_p,
62+ const bool is_causal,
63+ // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
64+ const optional<double > scale,
65+ Tensor& output);
66+
67+ at::Tensor custom_sdpa_aten (
68+ const at::Tensor& q,
69+ const at::Tensor& k,
70+ const at::Tensor& v,
71+ const int64_t start_pos,
72+ // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
73+ // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
74+ const std::optional<at::Tensor> attn_mask,
75+ const double dropout_p,
76+ const bool is_causal,
77+ // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
78+ const std::optional<double > scale);
79+
80+ Tensor& update_cache_out_no_context (
81+ const Tensor& value,
82+ Tensor& cache,
83+ const int64_t start_pos,
84+ Tensor& output);
85+
86+ at::Tensor update_cache_aten (
87+ const at::Tensor& value,
88+ at::Tensor& cache,
89+ const int64_t start_pos);
90+
2091Tensor& sdpa_with_kv_cache_out_no_context (
2192 const Tensor& q_projected,
2293 const Tensor& k_projected,
0 commit comments