From df1a5fb113e97cf8d11cc6112b7ca5ce58fd2b54 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 2 Oct 2024 13:21:47 -0700 Subject: [PATCH] op_clamp: add downcasting tests & fix (#5798) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/5798 Reviewed By: swolchok Differential Revision: D63716405 fbshipit-source-id: 4987e9ad93f0b3f490432cf07ba19c2f26fc82e0 (cherry picked from commit 3aa6b14de291cd79264f4473464d521d7f8f4c72) --- kernels/portable/cpu/op_clamp.cpp | 16 +++++++---- kernels/test/op_clamp_test.cpp | 48 +++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 6 deletions(-) diff --git a/kernels/portable/cpu/op_clamp.cpp b/kernels/portable/cpu/op_clamp.cpp index ec34fa9bd35..924780b29ab 100644 --- a/kernels/portable/cpu/op_clamp.cpp +++ b/kernels/portable/cpu/op_clamp.cpp @@ -218,6 +218,10 @@ Tensor& clamp_tensor_out( ET_SWITCH_REALHB_TYPES(min_type, ctx, name, CTYPE_MIN, [&]() { ET_SWITCH_REALHB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() { ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() { + using CTYPE_MINMAX = typename torch::executor:: + promote_types::type; + using CTYPE = typename torch::executor:: + promote_types::type; apply_ternary_elementwise_fn< CTYPE_IN, CTYPE_MIN, @@ -227,16 +231,16 @@ Tensor& clamp_tensor_out( const CTYPE_IN val_in, const CTYPE_MIN val_min, const CTYPE_MAX val_max) { - CTYPE_OUT val_out = static_cast(val_in); + CTYPE val_out = static_cast(val_in); if (has_min) { - val_out = utils::max_override( - val_out, static_cast(val_min)); + val_out = + utils::max_override(val_out, static_cast(val_min)); } if (has_max) { - val_out = utils::min_override( - val_out, static_cast(val_max)); + val_out = + utils::min_override(val_out, static_cast(val_max)); } - return val_out; + return static_cast(val_out); }, in, min, diff --git a/kernels/test/op_clamp_test.cpp b/kernels/test/op_clamp_test.cpp index d9d45509084..533dfee7ae1 100644 --- a/kernels/test/op_clamp_test.cpp +++ b/kernels/test/op_clamp_test.cpp @@ -484,3 +484,51 @@ TEST_F(OpClampTensorOutTest, SmokeTest) { op_clamp_tensor_out(in, min, max, out); EXPECT_TENSOR_EQ(out, expected); } + +TEST_F(OpClampTensorOutTest, DowncastingSmokeTest) { + TensorFactory tf_in; + TensorFactory tf_min; + TensorFactory tf_max; + TensorFactory tf_out; + + Tensor in = tf_in.make({}, {5}); + Tensor min = tf_min.make({}, {-129}); + Tensor max = tf_max.make({}, {300}); + Tensor out = tf_out.zeros({}); + Tensor expected = tf_out.make({}, {5}); + + op_clamp_tensor_out(in, min, max, out); + EXPECT_TENSOR_EQ(out, expected); +} + +TEST_F(OpClampTensorOutTest, DowncastingSmokeTest2) { + TensorFactory tf_in; + TensorFactory tf_min; + TensorFactory tf_max; + TensorFactory tf_out; + + Tensor in = tf_in.make({}, {301}); + Tensor min = tf_min.make({}, {-129}); + Tensor max = tf_max.make({}, {300}); + Tensor out = tf_out.zeros({}); + Tensor expected = tf_out.make({}, {44}); + + op_clamp_tensor_out(in, min, max, out); + EXPECT_TENSOR_EQ(out, expected); +} + +TEST_F(OpClampTensorOutTest, DowncastingSmokeTest3) { + TensorFactory tf_in; + TensorFactory tf_min; + TensorFactory tf_max; + TensorFactory tf_out; + + Tensor in = tf_in.make({}, {45}); + Tensor min = tf_min.make({}, {-129}); + Tensor max = tf_max.make({}, {300}); + Tensor out = tf_out.zeros({}); + Tensor expected = tf_out.make({}, {45}); + + op_clamp_tensor_out(in, min, max, out); + EXPECT_TENSOR_EQ(out, expected); +}