Skip to content

Commit 881fadf

Browse files
author
morelos
committed
Update on "[ET-VK][Ops] dequantize_per_channel shaders and impl"
# Context We need to enable the core logic for dequantize_per_channel in the vulkan shader. This implements the shader itself and its cpp header. TODO: add more of a description regarding the operator # Changes This creates an extension of the existing files for dequantize_per_channel. Differential Revision: [D77746141](https://our.internmc.facebook.com/intern/diff/D77746141/) [ghstack-poisoned]
2 parents cca4176 + 59895c9 commit 881fadf

File tree

1 file changed

+36
-25
lines changed

1 file changed

+36
-25
lines changed

extension/aten_util/test/make_aten_functor_from_et_functor_test.cpp

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -424,10 +424,12 @@ TEST_F(MakeATenFunctorFromETFunctorTest, TestWrap_ArrayRefOptional) {
424424

425425
TEST_F(MakeATenFunctorFromETFunctorTest, TestConvert_ConstRefOptionals) {
426426
// Test const optional scalar conversion
427-
const std::optional<int64_t> const_optional_at_in = std::optional<int64_t>(42);
427+
const std::optional<int64_t> const_optional_at_in =
428+
std::optional<int64_t>(42);
428429
auto const_optional_et =
429-
type_convert<const std::optional<int64_t>, torch::executor::optional<int64_t>>(
430-
const_optional_at_in)
430+
type_convert<
431+
const std::optional<int64_t>,
432+
torch::executor::optional<int64_t>>(const_optional_at_in)
431433
.call();
432434
EXPECT_TRUE(const_optional_et.has_value());
433435
EXPECT_EQ(const_optional_et.value(), 42);
@@ -442,60 +444,69 @@ TEST_F(MakeATenFunctorFromETFunctorTest, TestConvert_ConstRefOptionals) {
442444
EXPECT_EQ(optional_et_from_ref.value(), 24);
443445

444446
// Test const optional scalar reference conversion
445-
const std::optional<int64_t> const_optional_at_ref_in = std::optional<int64_t>(84);
447+
const std::optional<int64_t> const_optional_at_ref_in =
448+
std::optional<int64_t>(84);
446449
auto const_optional_et_from_ref =
447-
type_convert<const std::optional<int64_t>&, torch::executor::optional<int64_t>>(
448-
const_optional_at_ref_in)
450+
type_convert<
451+
const std::optional<int64_t>&,
452+
torch::executor::optional<int64_t>>(const_optional_at_ref_in)
449453
.call();
450454
EXPECT_TRUE(const_optional_et_from_ref.has_value());
451455
EXPECT_EQ(const_optional_et_from_ref.value(), 84);
452456

453457
// Test const optional tensor conversion
454458
const std::optional<at::Tensor> const_optional_tensor_at_in =
455459
std::optional<at::Tensor>(torch::tensor({5}));
456-
auto const_optional_tensor_converter =
457-
type_convert<
458-
const std::optional<at::Tensor>,
459-
torch::executor::optional<torch::executor::Tensor>>(const_optional_tensor_at_in);
460+
auto const_optional_tensor_converter = type_convert<
461+
const std::optional<at::Tensor>,
462+
torch::executor::optional<torch::executor::Tensor>>(
463+
const_optional_tensor_at_in);
460464
auto const_optional_tensor_et = const_optional_tensor_converter.call();
461465
EXPECT_TRUE(const_optional_tensor_et.has_value());
462466
EXPECT_EQ(const_optional_tensor_et.value().const_data_ptr<int64_t>()[0], 5);
463467

464468
// Test optional tensor reference conversion
465469
std::optional<at::Tensor> optional_tensor_at_ref_in =
466470
std::optional<at::Tensor>(torch::tensor({7}));
467-
auto optional_tensor_converter_from_ref =
468-
type_convert<
469-
std::optional<at::Tensor>&,
470-
torch::executor::optional<torch::executor::Tensor>>(optional_tensor_at_ref_in);
471+
auto optional_tensor_converter_from_ref = type_convert<
472+
std::optional<at::Tensor>&,
473+
torch::executor::optional<torch::executor::Tensor>>(
474+
optional_tensor_at_ref_in);
471475
auto optional_tensor_et_from_ref = optional_tensor_converter_from_ref.call();
472476
EXPECT_TRUE(optional_tensor_et_from_ref.has_value());
473-
EXPECT_EQ(optional_tensor_et_from_ref.value().const_data_ptr<int64_t>()[0], 7);
477+
EXPECT_EQ(
478+
optional_tensor_et_from_ref.value().const_data_ptr<int64_t>()[0], 7);
474479

475480
// Test const optional tensor reference conversion
476481
const std::optional<at::Tensor> const_optional_tensor_at_ref_in =
477482
std::optional<at::Tensor>(torch::tensor({9}));
478-
auto const_optional_tensor_converter_from_ref =
479-
type_convert<
480-
const std::optional<at::Tensor>&,
481-
torch::executor::optional<torch::executor::Tensor>>(const_optional_tensor_at_ref_in);
482-
auto const_optional_tensor_et_from_ref = const_optional_tensor_converter_from_ref.call();
483+
auto const_optional_tensor_converter_from_ref = type_convert<
484+
const std::optional<at::Tensor>&,
485+
torch::executor::optional<torch::executor::Tensor>>(
486+
const_optional_tensor_at_ref_in);
487+
auto const_optional_tensor_et_from_ref =
488+
const_optional_tensor_converter_from_ref.call();
483489
EXPECT_TRUE(const_optional_tensor_et_from_ref.has_value());
484-
EXPECT_EQ(const_optional_tensor_et_from_ref.value().const_data_ptr<int64_t>()[0], 9);
490+
EXPECT_EQ(
491+
const_optional_tensor_et_from_ref.value().const_data_ptr<int64_t>()[0],
492+
9);
485493

486494
// Test empty const optional conversions
487495
const std::optional<int64_t> empty_const_optional_at_in = std::nullopt;
488496
auto empty_const_optional_et =
489-
type_convert<const std::optional<int64_t>, torch::executor::optional<int64_t>>(
490-
empty_const_optional_at_in)
497+
type_convert<
498+
const std::optional<int64_t>,
499+
torch::executor::optional<int64_t>>(empty_const_optional_at_in)
491500
.call();
492501
EXPECT_FALSE(empty_const_optional_et.has_value());
493502

494-
const std::optional<at::Tensor> empty_const_optional_tensor_at_in = std::nullopt;
503+
const std::optional<at::Tensor> empty_const_optional_tensor_at_in =
504+
std::nullopt;
495505
auto empty_const_optional_tensor_et =
496506
type_convert<
497507
const std::optional<at::Tensor>,
498-
torch::executor::optional<torch::executor::Tensor>>(empty_const_optional_tensor_at_in)
508+
torch::executor::optional<torch::executor::Tensor>>(
509+
empty_const_optional_tensor_at_in)
499510
.call();
500511
EXPECT_FALSE(empty_const_optional_tensor_et.has_value());
501512
}

0 commit comments

Comments
 (0)