@@ -41,62 +41,10 @@ enum class ElementwiseOptimizedPath {
   kTreatAs1d,
   kBroadcast2dBy1d,
   kBroadcast2dBy1dReverseArguments,
-  kBroadcastNdByNd,
-  kBroadcastNdByNdReverseArguments,
 };
 
 namespace internal {
-
-// Find the single broadcast dimension if it exists.
-// This path aims to handle broadcast of the following form
-// A = [a1, a2, ..., 1, ..., an]
-// B = [b1, b2, ..., bm, ..., bn]
-// OR
-// A = [a1, a2, ..., am, ..., an]
-// B = [b1, b2, ..., 1, ..., bn]
-int32_t inline get_broadcast_dim(const Tensor& lhs, const Tensor& rhs) {
-  auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
-  auto lhs_end = lhs.sizes().end();
-
-  auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs.sizes());
-  auto rhs_end = rhs.sizes().end();
-
-  const auto lhs_size = lhs_end - lhs_begin;
-  const auto rhs_size = rhs_end - rhs_begin;
-
-  // Following example is not handled at the moment
-  // [1, 3, 4, 5]
-  // [2, 3, 4, 5]
-  if (lhs_size != rhs_size) {
-    return 0;
-  }
-
-  int32_t broadcast_dim = 0;
-  // Check
-  // 1. if any dim value is 1 (it constitutes a broadcast dim)
-  // 2. If more than one dim value is 1 (we cannot handle)
-  // 3. If non-1 dim values are equal
-  lhs_end--;
-  rhs_end--;
-  while (lhs_end != lhs_begin) {
-    if (*lhs_end == 1 || *rhs_end == 1) {
-      // If more than one broadcast dim is found, return 0.
-      if (broadcast_dim != 0) {
-        return 0;
-      }
-      // negative index is used
-      broadcast_dim = lhs_end - lhs.sizes().end();
-    } else if (*lhs_end != *rhs_end) {
-      // If non-1 dim values are not equal, return 0.
-      return 0;
-    }
-    lhs_end--;
-    rhs_end--;
-  }
-  return broadcast_dim;
-}
-
-inline ElementwiseOptimizedPath select_broadcast_optimized_path(
+inline ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
     const Tensor& lhs,
     const Tensor& rhs) {
   auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
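For reference, the `get_broadcast_dim` helper deleted above can be exercised in isolation. Below is a minimal standalone sketch of its logic, assuming `std::vector<int64_t>` as a stand-in for `Tensor::sizes()`; `get_broadcast_dim_sketch` and the example shapes are invented for illustration. It returns the single broadcast dim as a negative index from the end, or 0 when the shapes fall outside the handled pattern.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the removed get_broadcast_dim, with std::vector<int64_t>
// standing in for Tensor::sizes().
int32_t get_broadcast_dim_sketch(
    const std::vector<int64_t>& lhs,
    const std::vector<int64_t>& rhs) {
  if (lhs.empty() || rhs.empty()) {
    return 0;
  }
  // Skip leading 1s, as arrayref_begin_ignoring_leading_1s does.
  size_t lhs_start = 0;
  size_t rhs_start = 0;
  while (lhs_start < lhs.size() && lhs[lhs_start] == 1) {
    lhs_start++;
  }
  while (rhs_start < rhs.size() && rhs[rhs_start] == 1) {
    rhs_start++;
  }
  // Effective ranks must match; e.g. [1, 3, 4, 5] vs [2, 3, 4, 5] is
  // rejected, mirroring the unhandled case noted in the original.
  if (lhs.size() - lhs_start != rhs.size() - rhs_start) {
    return 0;
  }
  int32_t broadcast_dim = 0;
  // Walk from the last dim toward (but not including) the first effective
  // dim, exactly like the original's pointer loop.
  for (size_t l = lhs.size() - 1, r = rhs.size() - 1; l > lhs_start;
       l--, r--) {
    if (lhs[l] == 1 || rhs[r] == 1) {
      if (broadcast_dim != 0) {
        return 0; // more than one broadcast dim: not handled
      }
      // Negative index, counted from the end of lhs.
      broadcast_dim =
          static_cast<int32_t>(l) - static_cast<int32_t>(lhs.size());
    } else if (lhs[l] != rhs[r]) {
      return 0; // mismatched non-1 dims: not a single-dim broadcast
    }
  }
  return broadcast_dim;
}

int main() {
  // Single broadcast at dim 1 of 4, i.e. -3 counted from the end.
  std::cout << get_broadcast_dim_sketch({2, 1, 4, 5}, {2, 3, 4, 5}) << "\n";
  // Two broadcast dims: not handled, returns 0.
  std::cout << get_broadcast_dim_sketch({2, 1, 4, 1}, {2, 3, 4, 5}) << "\n";
}
```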
@@ -115,17 +63,6 @@ inline ElementwiseOptimizedPath select_broadcast_optimized_path(
     return ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments;
   }
 
-  int32_t broadcast_dim = get_broadcast_dim(lhs, rhs);
-  // Right now we don't handle last dim broadcast
-  if (broadcast_dim < -1) {
-    if (std::count_if(rhs_begin, rhs_end, [](Tensor::SizesType x) {
-          return x == 1;
-        }) == 1) {
-      return ElementwiseOptimizedPath::kBroadcastNdByNd;
-    } else {
-      return ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments;
-    }
-  }
   return ElementwiseOptimizedPath::kNone;
 }
 } // namespace internal
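The branch deleted in this hunk turned a single non-last broadcast dim (a `get_broadcast_dim` result below -1) into one of the two Nd-by-Nd paths, using the side that carries the 1 to pick the argument order. A hedged sketch of just that decision, with `Path` and `select_nd_by_nd_sketch` as hypothetical stand-ins for the real enum and selection function:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

enum class Path { kNone, kNdByNd, kNdByNdReverse };

// Sketch of the deleted dispatch. broadcast_dim is the negative index from
// get_broadcast_dim; -1 (last dim) and 0 (no single broadcast dim) were
// deliberately left to the kNone fallback.
Path select_nd_by_nd_sketch(
    const std::vector<int64_t>& rhs_sizes,
    int32_t broadcast_dim) {
  if (broadcast_dim < -1) {
    // Exactly one 1 in rhs means rhs is the broadcast (smaller) side.
    const auto ones = std::count_if(
        rhs_sizes.begin(), rhs_sizes.end(), [](int64_t x) { return x == 1; });
    return ones == 1 ? Path::kNdByNd : Path::kNdByNdReverse;
  }
  return Path::kNone;
}

int main() {
  // lhs = [2, 3, 4, 5], rhs = [2, 1, 4, 5]: the 1 is in rhs, so the
  // non-reversed path is chosen.
  std::cout << (select_nd_by_nd_sketch({2, 1, 4, 5}, -3) == Path::kNdByNd)
            << "\n"; // 1
  // With the arguments swapped, rhs has no 1s: reversed path.
  std::cout << (select_nd_by_nd_sketch({2, 3, 4, 5}, -3) ==
                Path::kNdByNdReverse)
            << "\n"; // 1
}
```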
@@ -148,28 +85,7 @@ ElementwiseOptimizedPath inline select_optimized_path(
       internal::sizes_match_ignoring_leading_1s(a.sizes(), b.sizes())))) {
     return ElementwiseOptimizedPath::kTreatAs1d;
   }
-  return internal::select_broadcast_optimized_path(a, b);
-}
-
-std::array<int32_t, 3> inline get_normalized_tensor_size(
-    const Tensor& a,
-    const int32_t broadcast_dim) {
-  ET_CHECK_MSG(
-      a.dim() > broadcast_dim,
-      "Size of tensor: %zd, must be larger than broadcast_dim: %d",
-      a.dim(),
-      broadcast_dim);
-  std::array<int32_t, 3> normalized_tensor_size;
-  normalized_tensor_size[0] = 1;
-  normalized_tensor_size[1] = a.size(broadcast_dim);
-  normalized_tensor_size[2] = 1;
-  for (size_t i = 0; i < broadcast_dim; i++) {
-    normalized_tensor_size[0] *= a.size(i);
-  }
-  for (size_t i = broadcast_dim + 1; i < a.dim(); i++) {
-    normalized_tensor_size[2] *= a.size(i);
-  }
-  return normalized_tensor_size;
+  return internal::select_broadcast_2d_by_1d_optimized_path(a, b);
 }
 
 } // namespace executor
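Finally, the removed `get_normalized_tensor_size` collapsed an N-d shape into three dims, {outer, broadcast, inner}, so a kernel could iterate the broadcast dim over contiguous inner blocks. A standalone sketch, assuming the caller first converts the negative index from `get_broadcast_dim` into a non-negative one (the call site is not part of this diff); `normalize_sketch` is an invented name:

```cpp
#include <array>
#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the removed get_normalized_tensor_size: collapse sizes into
// {outer, broadcast, inner} around a non-negative broadcast_dim.
std::array<int32_t, 3> normalize_sketch(
    const std::vector<int64_t>& sizes,
    int32_t broadcast_dim) {
  std::array<int32_t, 3> out = {
      1, static_cast<int32_t>(sizes.at(broadcast_dim)), 1};
  for (int32_t i = 0; i < broadcast_dim; i++) {
    out[0] *= static_cast<int32_t>(sizes[i]); // dims before the broadcast dim
  }
  for (size_t i = broadcast_dim + 1; i < sizes.size(); i++) {
    out[2] *= static_cast<int32_t>(sizes[i]); // dims after the broadcast dim
  }
  return out;
}

int main() {
  // [2, 3, 4, 5] around broadcast_dim 1 collapses to {2, 3, 20}: 2 outer
  // slices, a broadcast extent of 3, and contiguous inner blocks of 20.
  auto n = normalize_sketch({2, 3, 4, 5}, 1);
  std::cout << n[0] << " " << n[1] << " " << n[2] << "\n"; // 2 3 20
}
```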