@@ -41,10 +41,56 @@ enum class ElementwiseOptimizedPath {
   kTreatAs1d,
   kBroadcast2dBy1d,
   kBroadcast2dBy1dReverseArguments,
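+  // Paths for broadcasting two N-d tensors whose sizes differ in exactly
+  // one non-trailing dim (see get_broadcast_dim below); the ReverseArguments
+  // variant marks that the second argument is the larger tensor.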
+  kBroadcastNdByNd,
+  kBroadcastNdByNdReverseArguments,
 };
 
 namespace internal {
-inline ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
+
+// Find the single broadcast dimension if it exists.
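+// For example, lhs.sizes() = [2, 3, 4, 5] and rhs.sizes() = [2, 1, 4, 5]
+// yields -3, i.e. dim 1 is the broadcast dim, expressed as a negative
+// offset from the end of sizes(); 0 is returned when no single broadcast
+// dim can be identified.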
+int32_t inline get_broadcast_dim(const Tensor& lhs, const Tensor& rhs) {
+  auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
+  auto lhs_end = lhs.sizes().end();
+
+  auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs.sizes());
+  auto rhs_end = rhs.sizes().end();
+
+  const auto lhs_size = lhs_end - lhs_begin;
+  const auto rhs_size = rhs_end - rhs_begin;
+
+  // We would like to also handle cases like the following, where the
+  // broadcast dim is the leading dim, but they are rejected here because
+  // stripping the leading 1s makes the remaining sizes differ:
+  //   [1, 3, 4, 5]
+  //   [2, 3, 4, 5]
+  if (lhs_size != rhs_size) {
+    return 0;
+  }
+
+  int32_t broadcast_dim = 0;
+  // Walking from the last dim backwards, check:
+  // 1. whether any dim value is 1 (it constitutes a broadcast dim),
+  // 2. whether more than one dim value is 1 (we cannot handle that), and
+  // 3. whether the non-1 dim values are equal.
+  lhs_end--;
+  rhs_end--;
+  while (lhs_end != lhs_begin) {
+    if (*lhs_end == 1 || *rhs_end == 1) {
+      // If more than one broadcast dim is found, return 0.
+      if (broadcast_dim != 0) {
+        return 0;
+      }
+      // Store a negative index, counted from the end of sizes().
+      broadcast_dim = lhs_end - lhs.sizes().end();
+    } else if (*lhs_end != *rhs_end) {
+      // If the non-1 dim values are not equal, return 0.
+      return 0;
+    }
+    lhs_end--;
+    rhs_end--;
+  }
+  return broadcast_dim;
+}
+
+inline ElementwiseOptimizedPath select_broadcast_optimized_path(
     const Tensor& lhs,
     const Tensor& rhs) {
   auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
@@ -63,6 +109,15 @@ inline ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
     return ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments;
   }
 
+  int32_t broadcast_dim = get_broadcast_dim(lhs, rhs);
+  // Right now we don't handle broadcast over the last dim, hence the check
+  // for broadcast_dim < -1 (a value of -1 would denote the last dim).
+  if (broadcast_dim < -1) {
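+    // If rhs carries the single 1-sized (broadcast) dim, rhs is the
+    // broadcast argument; otherwise lhs carries it and the ReverseArguments
+    // variant is selected.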
+    if (std::count_if(rhs_begin, rhs_end, [](Tensor::SizesType x) {
+          return x == 1;
+        }) == 1) {
+      return ElementwiseOptimizedPath::kBroadcastNdByNd;
+    } else {
+      return ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments;
+    }
+  }
   return ElementwiseOptimizedPath::kNone;
 }
 } // namespace internal
@@ -85,7 +140,22 @@ ElementwiseOptimizedPath inline select_optimized_path(
        internal::sizes_match_ignoring_leading_1s(a.sizes(), b.sizes())))) {
     return ElementwiseOptimizedPath::kTreatAs1d;
   }
-  return internal::select_broadcast_2d_by_1d_optimized_path(a, b);
+  return internal::select_broadcast_optimized_path(a, b);
+}
+
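+// Collapse a tensor's sizes around broadcast_dim (a non-negative dim index
+// here) into a 3-element shape:
+//   [product of dims before broadcast_dim, size at broadcast_dim,
+//    product of dims after broadcast_dim].
+// E.g. a.sizes() = [2, 3, 4, 5] with broadcast_dim = 1 gives [2, 3, 20].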
+std::array<int32_t, 3> inline get_normalized_tensor_size(
+    const Tensor& a,
+    const int32_t broadcast_dim) {
+  ET_CHECK_MSG(
+      a.dim() > broadcast_dim,
+      "Dim of tensor: %zd must be larger than broadcast_dim: %d",
+      a.dim(),
+      broadcast_dim);
+  std::array<int32_t, 3> normalized_tensor_size;
+  normalized_tensor_size[0] = 1;
+  normalized_tensor_size[1] = a.size(broadcast_dim);
+  normalized_tensor_size[2] = 1;
+  for (size_t i = 0; i < broadcast_dim; i++) {
+    normalized_tensor_size[0] = normalized_tensor_size[0] * a.size(i);
+  }
+  for (size_t i = broadcast_dim + 1; i < a.dim(); i++) {
+    normalized_tensor_size[2] = normalized_tensor_size[2] * a.size(i);
+  }
+  return normalized_tensor_size;
 }
 
 } // namespace executor