Skip to content

Commit f5998d9

Browse files
committed
Use std::optional
Signed-off-by: cyy <[email protected]>
1 parent b56235c commit f5998d9

File tree

6 files changed

+92
-92
lines changed

6 files changed

+92
-92
lines changed

extension/kernel_util/test/make_boxed_from_unboxed_functor_test.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ add_tensor_out(KernelRuntimeContext& ctx, ArrayRef<Tensor> a, Tensor& out) {
5454

5555
Tensor& add_optional_scalar_out(
5656
KernelRuntimeContext& ctx,
57-
optional<int64_t> s1,
58-
optional<int64_t> s2,
57+
std::optional<int64_t> s1,
58+
std::optional<int64_t> s2,
5959
Tensor& out) {
6060
(void)ctx;
6161
if (s1.has_value()) {
@@ -182,7 +182,7 @@ TEST_F(MakeBoxedFromUnboxedFunctorTest, UnboxOptionalArrayRef) {
182182

183183
// prepare optional tensors.
184184
torch::executor::testing::TensorFactory<ScalarType::Int> tf;
185-
optional<Tensor> storage[2];
185+
std::optional<Tensor> storage[2];
186186
EValue evalues[2] = {EValue(tf.ones({5})), EValue()};
187187
EValue* values_p[2] = {&evalues[0], &evalues[1]};
188188
BoxedEvalueList<optional<Tensor>> a_box(values_p, storage, 2);

extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ bool validate_flash_attention_args(
3333
const Tensor& query,
3434
const Tensor& key,
3535
const Tensor& value,
36-
const optional<Tensor>& attn_mask) {
36+
const std::optional<Tensor>& attn_mask) {
3737
ET_CHECK_OR_RETURN_FALSE(query.dim() == 4, "query must be a 4D tensor");
3838
ET_CHECK_OR_RETURN_FALSE(key.dim() == 4, "key must be a 4D tensor");
3939
ET_CHECK_OR_RETURN_FALSE(value.dim() == 4, "value must be a 4D tensor");
@@ -245,11 +245,11 @@ Tensor& flash_attention_kernel_out(
245245
const Tensor& query,
246246
const Tensor& key,
247247
const Tensor& value,
248-
const optional<Tensor>& attn_mask,
248+
const std::optional<Tensor>& attn_mask,
249249
const double dropout_p,
250250
const bool is_causal,
251251
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
252-
const optional<double> scale,
252+
const std::optional<double> scale,
253253
Tensor& output) {
254254
(void)ctx;
255255
ET_KERNEL_CHECK(
@@ -281,12 +281,12 @@ Tensor& flash_attention_kernel_out(
281281
is_causal,
282282
attn_mask,
283283
scale,
284-
nullopt,
285-
nullopt,
286-
nullopt,
287-
nullopt,
288-
nullopt,
289-
nullopt);
284+
std::nullopt,
285+
std::nullopt,
286+
std::nullopt,
287+
std::nullopt,
288+
std::nullopt,
289+
std::nullopt);
290290
} else if (seq_len >= 192) {
291291
sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
292292
output,
@@ -297,12 +297,12 @@ Tensor& flash_attention_kernel_out(
297297
is_causal,
298298
attn_mask,
299299
scale,
300-
nullopt,
301-
nullopt,
302-
nullopt,
303-
nullopt,
304-
nullopt,
305-
nullopt);
300+
std::nullopt,
301+
std::nullopt,
302+
std::nullopt,
303+
std::nullopt,
304+
std::nullopt,
305+
std::nullopt);
306306
} else {
307307
sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
308308
output,
@@ -313,12 +313,12 @@ Tensor& flash_attention_kernel_out(
313313
is_causal,
314314
attn_mask,
315315
scale,
316-
nullopt,
317-
nullopt,
318-
nullopt,
319-
nullopt,
320-
nullopt,
321-
nullopt);
316+
std::nullopt,
317+
std::nullopt,
318+
std::nullopt,
319+
std::nullopt,
320+
std::nullopt,
321+
std::nullopt);
322322
}
323323
});
324324
return output;
@@ -330,18 +330,18 @@ Tensor& custom_sdpa_out_impl(
330330
const Tensor& k,
331331
const Tensor& v,
332332
const int64_t start_pos,
333-
const optional<Tensor>& attn_mask,
333+
const std::optional<Tensor>& attn_mask,
334334
const double dropout_p,
335335
const bool is_causal,
336336
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
337-
const optional<double> scale,
337+
const std::optional<double> scale,
338338
Tensor& output,
339-
const optional<Tensor>& q_zero_points = nullopt,
340-
const optional<Tensor>& q_scales = nullopt,
341-
const optional<Tensor>& k_zero_points = nullopt,
342-
const optional<Tensor>& k_scales = nullopt,
343-
const optional<Tensor>& v_zero_points = nullopt,
344-
const optional<Tensor>& v_scales = nullopt,
339+
const std::optional<Tensor>& q_zero_points = std::nullopt,
340+
const std::optional<Tensor>& q_scales = std::nullopt,
341+
const std::optional<Tensor>& k_zero_points = std::nullopt,
342+
const std::optional<Tensor>& k_scales = std::nullopt,
343+
const std::optional<Tensor>& v_zero_points = std::nullopt,
344+
const std::optional<Tensor>& v_scales = std::nullopt,
345345
bool is_seq_at_dim_2 = false) {
346346
ET_KERNEL_CHECK_MSG(
347347
ctx,
@@ -484,17 +484,17 @@ Tensor& custom_quantized_sdpa_out(
484484
const Tensor& k,
485485
const Tensor& v,
486486
const int64_t start_pos,
487-
const optional<Tensor>& attn_mask,
487+
const std::optional<Tensor>& attn_mask,
488488
const double dropout_p,
489489
const bool is_causal,
490490
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
491-
const optional<double> scale,
492-
const optional<Tensor>& q_zero_points,
493-
const optional<Tensor>& q_scales,
494-
const optional<Tensor>& k_zero_points,
495-
const optional<Tensor>& k_scales,
496-
const optional<Tensor>& v_zero_points,
497-
const optional<Tensor>& v_scales,
491+
const std::optional<double> scale,
492+
const std::optional<Tensor>& q_zero_points,
493+
const std::optional<Tensor>& q_scales,
494+
const std::optional<Tensor>& k_zero_points,
495+
const std::optional<Tensor>& k_scales,
496+
const std::optional<Tensor>& v_zero_points,
497+
const std::optional<Tensor>& v_scales,
498498
const bool is_seq_at_dim_2,
499499
Tensor& output) {
500500
return custom_sdpa_out_impl(
@@ -538,11 +538,11 @@ Tensor& custom_sdpa_out(
538538
const Tensor& k,
539539
const Tensor& v,
540540
const int64_t start_pos,
541-
const optional<Tensor>& attn_mask,
541+
const std::optional<Tensor>& attn_mask,
542542
const double dropout_p,
543543
const bool is_causal,
544544
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
545-
const optional<double> scale,
545+
const std::optional<double> scale,
546546
Tensor& output) {
547547
return custom_sdpa_out_impl(
548548
ctx, q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output);
@@ -572,11 +572,11 @@ Tensor& sdpa_with_kv_cache_out(
572572
Tensor& value_cache,
573573
const int64_t start_pos,
574574
const int64_t seq_len,
575-
const optional<Tensor>& attn_mask,
575+
const std::optional<Tensor>& attn_mask,
576576
const double dropout_p,
577577
const bool is_causal,
578578
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
579-
const optional<double> scale,
579+
const std::optional<double> scale,
580580
Tensor& output) {
581581
(void)ctx;
582582
ET_KERNEL_CHECK(

extension/llm/custom_ops/op_sdpa.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ Tensor& sdpa_with_kv_cache_out(
2424
Tensor& value_cache,
2525
const int64_t start_pos,
2626
const int64_t seq_len,
27-
const optional<Tensor>& attn_mask,
27+
const std::optional<Tensor>& attn_mask,
2828
const double dropout_p,
2929
const bool is_causal,
3030
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
31-
const optional<double> scale,
31+
const std::optional<double> scale,
3232
Tensor& output);
3333

3434
Tensor& custom_sdpa_out(
@@ -37,23 +37,23 @@ Tensor& custom_sdpa_out(
3737
const Tensor& k,
3838
const Tensor& v,
3939
const int64_t start_pos,
40-
const optional<Tensor>& attn_mask,
40+
const std::optional<Tensor>& attn_mask,
4141
const double dropout_p,
4242
const bool is_causal,
4343
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
44-
const optional<double> scale,
44+
const std::optional<double> scale,
4545
Tensor& output);
4646

4747
Tensor& flash_attention_kernel_out(
4848
KernelRuntimeContext& ctx,
4949
const Tensor& query,
5050
const Tensor& key,
5151
const Tensor& value,
52-
const optional<Tensor>& attn_mask,
52+
const std::optional<Tensor>& attn_mask,
5353
const double dropout_p,
5454
const bool is_causal,
5555
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
56-
const optional<double> scale,
56+
const std::optional<double> scale,
5757
Tensor& output);
5858

5959
Tensor& custom_quantized_sdpa_out(
@@ -62,17 +62,17 @@ Tensor& custom_quantized_sdpa_out(
6262
const Tensor& k,
6363
const Tensor& v,
6464
const int64_t start_pos,
65-
const optional<Tensor>& attn_mask,
65+
const std::optional<Tensor>& attn_mask,
6666
const double dropout_p,
6767
const bool is_causal,
6868
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
69-
const optional<double> scale,
70-
const optional<Tensor>& q_zero_points,
71-
const optional<Tensor>& q_scales,
72-
const optional<Tensor>& k_zero_points,
73-
const optional<Tensor>& k_scales,
74-
const optional<Tensor>& v_zero_points,
75-
const optional<Tensor>& v_scales,
69+
const std::optional<double> scale,
70+
const std::optional<Tensor>& q_zero_points,
71+
const std::optional<Tensor>& q_scales,
72+
const std::optional<Tensor>& k_zero_points,
73+
const std::optional<Tensor>& k_scales,
74+
const std::optional<Tensor>& v_zero_points,
75+
const std::optional<Tensor>& v_scales,
7676
const bool is_seq_at_dim_1,
7777
Tensor& output);
7878
} // namespace native

extension/llm/custom_ops/op_sdpa_aot.cpp

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@ Tensor& sdpa_with_kv_cache_out_no_context(
2727
const int64_t seq_len,
2828
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
2929
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
30-
const optional<Tensor> attn_mask,
30+
const std::optional<Tensor> attn_mask,
3131
const double dropout_p,
3232
const bool is_causal,
3333
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
34-
const optional<double> scale,
34+
const std::optional<double> scale,
3535
Tensor& output);
3636

3737
at::Tensor sdpa_with_kv_cache_aten(
@@ -57,11 +57,11 @@ Tensor& custom_sdpa_out_no_context(
5757
const int64_t start_pos,
5858
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
5959
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
60-
const optional<Tensor> attn_mask,
60+
const std::optional<Tensor> attn_mask,
6161
const double dropout_p,
6262
const bool is_causal,
6363
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
64-
const optional<double> scale,
64+
const std::optional<double> scale,
6565
Tensor& output);
6666

6767
at::Tensor custom_sdpa_aten(
@@ -84,17 +84,17 @@ Tensor& custom_quantized_sdpa_out_no_context(
8484
const int64_t start_pos,
8585
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
8686
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
87-
const optional<Tensor> attn_mask,
87+
const std::optional<Tensor> attn_mask,
8888
const double dropout_p,
8989
const bool is_causal,
9090
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
91-
const optional<double> scale,
92-
const optional<Tensor> q_zero_points,
93-
const optional<Tensor> q_scales,
94-
const optional<Tensor> k_zero_points,
95-
const optional<Tensor> k_scales,
96-
const optional<Tensor> v_zero_points,
97-
const optional<Tensor> v_scales,
91+
const std::optional<double> scale,
92+
const std::optional<Tensor> q_zero_points,
93+
const std::optional<Tensor> q_scales,
94+
const std::optional<Tensor> k_zero_points,
95+
const std::optional<Tensor> k_scales,
96+
const std::optional<Tensor> v_zero_points,
97+
const std::optional<Tensor> v_scales,
9898
const bool is_seq_at_dim_2,
9999
Tensor& output);
100100

@@ -153,11 +153,11 @@ Tensor& sdpa_with_kv_cache_out_no_context(
153153
const int64_t seq_len,
154154
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
155155
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
156-
const optional<Tensor> attn_mask,
156+
const std::optional<Tensor> attn_mask,
157157
const double dropout_p,
158158
const bool is_causal,
159159
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
160-
const optional<double> scale,
160+
const std::optional<double> scale,
161161
Tensor& output) {
162162
executorch::runtime::KernelRuntimeContext context{};
163163
return torch::executor::native::sdpa_with_kv_cache_out(
@@ -215,11 +215,11 @@ Tensor& custom_sdpa_out_no_context(
215215
const int64_t start_pos,
216216
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
217217
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
218-
const optional<Tensor> attn_mask,
218+
const std::optional<Tensor> attn_mask,
219219
const double dropout_p,
220220
const bool is_causal,
221221
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
222-
const optional<double> scale,
222+
const std::optional<double> scale,
223223
Tensor& output) {
224224
executorch::aten::RuntimeContext context{};
225225
return torch::executor::native::custom_sdpa_out(
@@ -260,17 +260,17 @@ Tensor& custom_quantized_sdpa_out_no_context(
260260
const int64_t start_pos,
261261
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
262262
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
263-
const optional<Tensor> attn_mask,
263+
const std::optional<Tensor> attn_mask,
264264
const double dropout_p,
265265
const bool is_causal,
266266
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
267-
const optional<double> scale,
268-
const optional<Tensor> q_zero_points,
269-
const optional<Tensor> q_scales,
270-
const optional<Tensor> k_zero_points,
271-
const optional<Tensor> k_scales,
272-
const optional<Tensor> v_zero_points,
273-
const optional<Tensor> v_scales,
267+
const std::optional<double> scale,
268+
const std::optional<Tensor> q_zero_points,
269+
const std::optional<Tensor> q_scales,
270+
const std::optional<Tensor> k_zero_points,
271+
const std::optional<Tensor> k_scales,
272+
const std::optional<Tensor> v_zero_points,
273+
const std::optional<Tensor> v_scales,
274274
const bool is_seq_at_dim_2,
275275
Tensor& output) {
276276
executorch::aten::RuntimeContext context{};

extension/llm/custom_ops/op_sdpa_impl.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -547,14 +547,14 @@ void cpu_flash_attention(
547547
const Tensor& value,
548548
double dropout_p,
549549
bool is_causal,
550-
const optional<Tensor>& attn_mask,
551-
const optional<double>& scale,
552-
const optional<Tensor>& q_zero_points,
553-
const optional<Tensor>& q_scales,
554-
const optional<Tensor>& k_zero_points,
555-
const optional<Tensor>& k_scales,
556-
const optional<Tensor>& v_zero_points,
557-
const optional<Tensor>& v_scales,
550+
const std::optional<Tensor>& attn_mask,
551+
const std::optional<double>& scale,
552+
const std::optional<Tensor>& q_zero_points,
553+
const std::optional<Tensor>& q_scales,
554+
const std::optional<Tensor>& k_zero_points,
555+
const std::optional<Tensor>& k_scales,
556+
const std::optional<Tensor>& v_zero_points,
557+
const std::optional<Tensor>& v_scales,
558558
const SeqDim seq_dim = SeqDim::TWO,
559559
const int64_t start_pos = 0,
560560
const int64_t num_keys_for_causal_attention = -1) {

extension/llm/custom_ops/op_update_cache.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ bool validate_cache_params(
2626
const Tensor& quantized_cache,
2727
int64_t start_pos,
2828
int64_t seq_length,
29-
const optional<Tensor>& indices = nullopt) {
29+
const std::optional<Tensor>& indices = std::nullopt) {
3030
ET_CHECK_OR_RETURN_FALSE(
3131
quantized_cache.dim() == 4, "quantized cache must be a 4D tensor");
3232

@@ -94,7 +94,7 @@ Tensor& update_cache_impl(
9494
Tensor& cache,
9595
const int64_t start_pos,
9696
Tensor& output,
97-
const optional<Tensor>& indices = nullopt) {
97+
const std::optional<Tensor>& indices = std::nullopt) {
9898
(void)ctx;
9999

100100
ET_CHECK_MSG(

0 commit comments

Comments
 (0)