@@ -794,3 +794,112 @@ TEST_F(OpMulScalarOutTest, BFloat16SanityCheck) {
794794 // Check that it matches the expected output.
795795 EXPECT_TENSOR_CLOSE (out, tf.make (sizes, {2.6 , 4.2 , 9.2 , 16.4 }));
796796}
797+
798+ // Tests for broadcast handling fix: when tensor dimensions don't match,
799+ // the output should be resized to match the tensor with higher dimensionality
800+ TEST_F (OpMulOutTest, BroadcastDimensionMismatchFix) {
801+ TensorFactory<ScalarType::Float> tf;
802+
803+ // Test case: tensor a of size [6] and b of size [1, 1, 6]
804+ // Expected output should be [1, 1, 6], not [6]
805+ Tensor a = tf.make ({6 }, {1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 });
806+ Tensor b = tf.make ({1 , 1 , 6 }, {2.0 , 2.0 , 2.0 , 2.0 , 2.0 , 2.0 });
807+
808+ // Create output tensor with expected broadcast shape [1, 1, 6]
809+ Tensor out = tf.zeros ({1 , 1 , 6 });
810+
811+ // Call the mul function
812+ Tensor& result = op_mul_out (a, b, out);
813+
814+ // Verify the output shape is [1, 1, 6]
815+ EXPECT_EQ (result.dim (), 3 );
816+ EXPECT_EQ (result.size (0 ), 1 );
817+ EXPECT_EQ (result.size (1 ), 1 );
818+ EXPECT_EQ (result.size (2 ), 6 );
819+
820+ // Verify the values are correct (element-wise multiplication with
821+ // broadcasting)
822+ Tensor expected = tf.make ({1 , 1 , 6 }, {2.0 , 4.0 , 6.0 , 8.0 , 10.0 , 12.0 });
823+ EXPECT_TENSOR_CLOSE (result, expected);
824+ }
825+
826+ TEST_F (OpMulOutTest, BroadcastDimensionMismatchReversed) {
827+ TensorFactory<ScalarType::Float> tf;
828+
829+ // Test case: tensor a of size [1, 1, 6] and b of size [6]
830+ // Expected output should be [1, 1, 6]
831+ Tensor a = tf.make ({1 , 1 , 6 }, {1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 });
832+ Tensor b = tf.make ({6 }, {2.0 , 2.0 , 2.0 , 2.0 , 2.0 , 2.0 });
833+
834+ // Create output tensor with expected broadcast shape [1, 1, 6]
835+ Tensor out = tf.zeros ({1 , 1 , 6 });
836+
837+ // Call the mul function
838+ Tensor& result = op_mul_out (a, b, out);
839+
840+ // Verify the output shape is [1, 1, 6]
841+ EXPECT_EQ (result.dim (), 3 );
842+ EXPECT_EQ (result.size (0 ), 1 );
843+ EXPECT_EQ (result.size (1 ), 1 );
844+ EXPECT_EQ (result.size (2 ), 6 );
845+
846+ // Verify the values are correct (element-wise multiplication with
847+ // broadcasting)
848+ Tensor expected = tf.make ({1 , 1 , 6 }, {2.0 , 4.0 , 6.0 , 8.0 , 10.0 , 12.0 });
849+ EXPECT_TENSOR_CLOSE (result, expected);
850+ }
851+
852+ TEST_F (OpMulOutTest, BroadcastDimensionMismatchWithDifferentTypes) {
853+ // Test the same broadcast fix with different data types
854+ TensorFactory<ScalarType::Half> tf_half;
855+ TensorFactory<ScalarType::BFloat16> tf_bf16;
856+ TensorFactory<ScalarType::Int> tf_int;
857+
858+ // Test with Half precision
859+ {
860+ Tensor a = tf_half.make ({4 }, {1.0 , 2.0 , 3.0 , 4.0 });
861+ Tensor b = tf_half.make ({1 , 1 , 4 }, {2.0 , 2.0 , 2.0 , 2.0 });
862+ Tensor out = tf_half.zeros ({1 , 1 , 4 });
863+
864+ Tensor& result = op_mul_out (a, b, out);
865+ EXPECT_EQ (result.dim (), 3 );
866+ EXPECT_EQ (result.size (0 ), 1 );
867+ EXPECT_EQ (result.size (1 ), 1 );
868+ EXPECT_EQ (result.size (2 ), 4 );
869+
870+ Tensor expected = tf_half.make ({1 , 1 , 4 }, {2.0 , 4.0 , 6.0 , 8.0 });
871+ EXPECT_TENSOR_CLOSE (result, expected);
872+ }
873+
874+ // Test with BFloat16
875+ {
876+ Tensor a = tf_bf16.make ({4 }, {1.0 , 2.0 , 3.0 , 4.0 });
877+ Tensor b = tf_bf16.make ({1 , 1 , 4 }, {2.0 , 2.0 , 2.0 , 2.0 });
878+ Tensor out = tf_bf16.zeros ({1 , 1 , 4 });
879+
880+ Tensor& result = op_mul_out (a, b, out);
881+ EXPECT_EQ (result.dim (), 3 );
882+ EXPECT_EQ (result.size (0 ), 1 );
883+ EXPECT_EQ (result.size (1 ), 1 );
884+ EXPECT_EQ (result.size (2 ), 4 );
885+
886+ Tensor expected = tf_bf16.make ({1 , 1 , 4 }, {2.0 , 4.0 , 6.0 , 8.0 });
887+ EXPECT_TENSOR_CLOSE (result, expected);
888+ }
889+
890+ // Test with Int
891+ {
892+ Tensor a = tf_int.make ({4 }, {1 , 2 , 3 , 4 });
893+ Tensor b = tf_int.make ({1 , 1 , 4 }, {2 , 2 , 2 , 2 });
894+ Tensor out = tf_int.zeros ({1 , 1 , 4 });
895+
896+ Tensor& result = op_mul_out (a, b, out);
897+ EXPECT_EQ (result.dim (), 3 );
898+ EXPECT_EQ (result.size (0 ), 1 );
899+ EXPECT_EQ (result.size (1 ), 1 );
900+ EXPECT_EQ (result.size (2 ), 4 );
901+
902+ Tensor expected = tf_int.make ({1 , 1 , 4 }, {2 , 4 , 6 , 8 });
903+ EXPECT_TENSOR_EQ (result, expected);
904+ }
905+ }
0 commit comments