@@ -41,10 +41,62 @@ enum class ElementwiseOptimizedPath {
   kTreatAs1d,
   kBroadcast2dBy1d,
   kBroadcast2dBy1dReverseArguments,
+  kBroadcastNdByNd,
+  kBroadcastNdByNdReverseArguments,
 };
 
 namespace internal {
-inline ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
+
+// Find the single broadcast dimension if it exists.
+// This path aims to handle broadcasts of the following form:
+// A = [a1, a2, ..., 1, ..., an]
+// B = [b1, b2, ..., bm, ..., bn]
+// OR
+// A = [a1, a2, ..., am, ..., an]
+// B = [b1, b2, ..., 1, ..., bn]
+int32_t inline get_broadcast_dim(const Tensor& lhs, const Tensor& rhs) {
+  auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
+  auto lhs_end = lhs.sizes().end();
+
+  auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs.sizes());
+  auto rhs_end = rhs.sizes().end();
+
+  const auto lhs_size = lhs_end - lhs_begin;
+  const auto rhs_size = rhs_end - rhs_begin;
+
+  // The following example is not handled at the moment:
+  // [1, 3, 4, 5]
+  // [2, 3, 4, 5]
+  if (lhs_size != rhs_size) {
+    return 0;
+  }
+
+  int32_t broadcast_dim = 0;
+  // Check:
+  // 1. if any dim value is 1 (it constitutes a broadcast dim)
+  // 2. if more than one dim value is 1 (we cannot handle that)
+  // 3. if the non-1 dim values are equal
+  lhs_end--;
+  rhs_end--;
+  while (lhs_end != lhs_begin) {
+    if (*lhs_end == 1 || *rhs_end == 1) {
+      // If more than one broadcast dim is found, return 0.
+      if (broadcast_dim != 0) {
+        return 0;
+      }
+      // A negative index is used.
+      broadcast_dim = lhs_end - lhs.sizes().end();
+    } else if (*lhs_end != *rhs_end) {
+      // If the non-1 dim values are not equal, return 0.
+      return 0;
+    }
+    lhs_end--;
+    rhs_end--;
+  }
+  return broadcast_dim;
+}
+
+inline ElementwiseOptimizedPath select_broadcast_optimized_path(
     const Tensor& lhs,
     const Tensor& rhs) {
   auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
@@ -63,6 +115,17 @@ inline ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
     return ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments;
   }
 
+  int32_t broadcast_dim = get_broadcast_dim(lhs, rhs);
+  // Right now we don't handle last-dim broadcast.
+  if (broadcast_dim < -1) {
+    if (std::count_if(rhs_begin, rhs_end, [](Tensor::SizesType x) {
+          return x == 1;
+        }) == 1) {
+      return ElementwiseOptimizedPath::kBroadcastNdByNd;
+    } else {
+      return ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments;
+    }
+  }
   return ElementwiseOptimizedPath::kNone;
 }
 } // namespace internal
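
For reference, here is a minimal standalone sketch of the broadcast-dim detection that `get_broadcast_dim()` and `select_broadcast_optimized_path()` implement above. It works on plain shape vectors rather than ExecuTorch `Tensor`s (so leading 1s are assumed to be stripped already); the function name `find_broadcast_dim` and the example shapes are purely illustrative and not part of the change:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Sketch of the get_broadcast_dim() logic on plain shape vectors.
// Returns the single size-1 dim as a negative index from the end,
// or 0 when the ranks differ, more than one candidate dim is found,
// or the non-1 dims do not match.
int32_t find_broadcast_dim(
    const std::vector<int64_t>& lhs,
    const std::vector<int64_t>& rhs) {
  if (lhs.size() != rhs.size()) {
    return 0; // e.g. [1, 3, 4, 5] vs. [2, 3, 4, 5] is not handled
  }
  int32_t broadcast_dim = 0;
  const int64_t n = static_cast<int64_t>(lhs.size());
  // Scan from the last dim down to (but not including) the first dim,
  // mirroring the pointer walk in the kernel.
  for (int64_t i = n - 1; i > 0; i--) {
    if (lhs[i] == 1 || rhs[i] == 1) {
      if (broadcast_dim != 0) {
        return 0; // more than one broadcast dim: bail out
      }
      broadcast_dim = static_cast<int32_t>(i - n); // negative index
    } else if (lhs[i] != rhs[i]) {
      return 0; // non-1 dims must match
    }
  }
  return broadcast_dim;
}

int main() {
  // [2, 3, 4, 5] * [2, 1, 4, 5]: the size-1 dim sits at negative index -3.
  printf("%d\n", find_broadcast_dim({2, 3, 4, 5}, {2, 1, 4, 5})); // -3
  // [2, 3, 4, 5] * [2, 3, 4, 1]: last-dim broadcast returns -1, which the
  // selector above deliberately rejects (it requires broadcast_dim < -1).
  printf("%d\n", find_broadcast_dim({2, 3, 4, 5}, {2, 3, 4, 1})); // -1
  // Two size-1 dims: not handled, returns 0.
  printf("%d\n", find_broadcast_dim({2, 1, 4, 1}, {2, 3, 4, 5})); // 0
  return 0;
}
```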
@@ -85,7 +148,28 @@ ElementwiseOptimizedPath inline select_optimized_path(
         internal::sizes_match_ignoring_leading_1s(a.sizes(), b.sizes())))) {
     return ElementwiseOptimizedPath::kTreatAs1d;
   }
-  return internal::select_broadcast_2d_by_1d_optimized_path(a, b);
+  return internal::select_broadcast_optimized_path(a, b);
+}
+
+std::array<int32_t, 3> inline get_normalized_tensor_size(
+    const Tensor& a,
+    const int32_t broadcast_dim) {
+  ET_CHECK_MSG(
+      a.dim() > broadcast_dim,
+      "Size of tensor: %zd, must be larger than broadcast_dim: %d",
+      a.dim(),
+      broadcast_dim);
+  std::array<int32_t, 3> normalized_tensor_size;
+  normalized_tensor_size[0] = 1;
+  normalized_tensor_size[1] = a.size(broadcast_dim);
+  normalized_tensor_size[2] = 1;
+  for (size_t i = 0; i < broadcast_dim; i++) {
+    normalized_tensor_size[0] *= a.size(i);
+  }
+  for (size_t i = broadcast_dim + 1; i < a.dim(); i++) {
+    normalized_tensor_size[2] *= a.size(i);
+  }
+  return normalized_tensor_size;
 }
 
 } // namespace executor
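
Similarly, a sketch of what `get_normalized_tensor_size()` computes: an N-d shape collapsed to `[outer, size(broadcast_dim), inner]`, presumably so the NdByNd path can operate on a 3-d view of both tensors. This version again uses plain vectors instead of `Tensor`, takes `broadcast_dim` as a non-negative dim index (the negative index from `get_broadcast_dim` would have to be converted by the caller), and the name `normalize_size` and the shapes are illustrative only:

```cpp
#include <array>
#include <cstdint>
#include <cstdio>
#include <vector>

// Collapse an N-d shape into [outer, size(broadcast_dim), inner], where
// "outer" is the product of all dims before broadcast_dim and "inner" the
// product of all dims after it.
std::array<int32_t, 3> normalize_size(
    const std::vector<int64_t>& sizes,
    int32_t broadcast_dim) {
  std::array<int32_t, 3> out = {
      1, static_cast<int32_t>(sizes[broadcast_dim]), 1};
  for (int32_t i = 0; i < broadcast_dim; i++) {
    out[0] *= static_cast<int32_t>(sizes[i]);
  }
  for (size_t i = broadcast_dim + 1; i < sizes.size(); i++) {
    out[2] *= static_cast<int32_t>(sizes[i]);
  }
  return out;
}

int main() {
  // [2, 3, 4, 5] broadcast against [2, 1, 4, 5] (broadcast_dim == 1) is seen
  // by the kernel as a [2, 3, 20] tensor broadcast against [2, 1, 20].
  auto a = normalize_size({2, 3, 4, 5}, 1);
  auto b = normalize_size({2, 1, 4, 5}, 1);
  printf("[%d, %d, %d]\n", a[0], a[1], a[2]); // [2, 3, 20]
  printf("[%d, %d, %d]\n", b[0], b[1], b[2]); // [2, 1, 20]
  return 0;
}
```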