@@ -728,3 +728,59 @@ TEST_F(OpConvCorrectnessTest, InvalidOutputPadding) {
728728 groups,
729729 out));
730730}
731+
732+ TEST_F (OpConvCorrectnessTest, HalfTypeSmokeTest) {
733+ TensorFactory<ScalarType::Half> tf;
734+
735+ auto input = tf.make ({1 , 2 , 3 }, {1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 });
736+ auto weight = tf.make ({2 , 2 , 2 }, {0.5 , 0.5 , 0.5 , 0.5 , 1.0 , 1.0 , 1.0 , 1.0 });
737+ optional<Tensor> bias;
738+ auto expected = tf.make ({1 , 2 , 2 }, {6.0 , 8.0 , 12.0 , 16.0 });
739+ auto out = tf.zeros ({1 , 2 , 2 });
740+
741+ int64_t stride[1 ] = {1 };
742+ int64_t padding[1 ] = {0 };
743+ int64_t dilation[1 ] = {1 };
744+ int64_t output_padding[1 ] = {0 };
745+
746+ op_convolution_out (
747+ input,
748+ weight,
749+ bias,
750+ executorch::aten::ArrayRef<int64_t >{stride, 1 },
751+ executorch::aten::ArrayRef<int64_t >{padding, 1 },
752+ executorch::aten::ArrayRef<int64_t >{dilation, 1 },
753+ false ,
754+ executorch::aten::ArrayRef<int64_t >{output_padding, 1 },
755+ int64_t (1 ),
756+ out);
757+ EXPECT_TENSOR_CLOSE (out, expected);
758+ }
759+
760+ TEST_F (OpConvCorrectnessTest, BFloat16TypeSmokeTest) {
761+ TensorFactory<ScalarType::BFloat16> tf;
762+
763+ auto input = tf.make ({1 , 2 , 3 }, {1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 });
764+ auto weight = tf.make ({2 , 2 , 2 }, {0.5 , 0.5 , 0.5 , 0.5 , 1.0 , 1.0 , 1.0 , 1.0 });
765+ optional<Tensor> bias;
766+ auto expected = tf.make ({1 , 2 , 2 }, {6.0 , 8.0 , 12.0 , 16.0 });
767+ auto out = tf.zeros ({1 , 2 , 2 });
768+
769+ int64_t stride[1 ] = {1 };
770+ int64_t padding[1 ] = {0 };
771+ int64_t dilation[1 ] = {1 };
772+ int64_t output_padding[1 ] = {0 };
773+
774+ op_convolution_out (
775+ input,
776+ weight,
777+ bias,
778+ executorch::aten::ArrayRef<int64_t >{stride, 1 },
779+ executorch::aten::ArrayRef<int64_t >{padding, 1 },
780+ executorch::aten::ArrayRef<int64_t >{dilation, 1 },
781+ false ,
782+ executorch::aten::ArrayRef<int64_t >{output_padding, 1 },
783+ int64_t (1 ),
784+ out);
785+ EXPECT_TENSOR_CLOSE (out, expected);
786+ }
0 commit comments