@@ -76,27 +76,45 @@ Tensor& add_out(
   int inp2_shape[kTensorDimensionLimit];
   int out_shape[kTensorDimensionLimit];
 
-  /* input shapes and output shapes */
-  for (auto i = 0; i < a_size.size(); i++) {
-    inp1_shape[i] = a_size[i];
-  }
-
-  for (auto i = 0; i < b_size.size(); i++) {
-    inp2_shape[i] = b_size[i];
-  }
-
-  for (auto i = 0; i < out_size.size(); i++) {
-    out_shape[i] = out_size[i];
-  }
-
   /* find broadcast*/
   const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
   const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
   const bool broadcast = (a_is_broadcasted || b_is_broadcasted);
 
   int max_dim = a.dim() > b.dim() ? a.dim() : b.dim();
+  max_dim = out.dim() > max_dim ? out.dim() : max_dim;
 
-  if (compute_type == ScalarType::Int) {
+  bool optimized = 1;
+
+  if ((a.dim() == 0) || (b.dim() == 0)) {
+    optimized = 0;
+  }
+
+  if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) {
+    optimized = 0;
+  }
+
+  for (int i = 0; i < max_dim; i++) {
+    out_shape[i] = 1;
+    inp1_shape[i] = 1;
+    inp2_shape[i] = 1;
+  }
+
+  int offset_out = max_dim - out.dim();
+  int offset_inp1 = max_dim - a.dim();
+  int offset_inp2 = max_dim - b.dim();
+
+  for (int i = 0; i < out.dim(); i++) {
+    out_shape[i + offset_out] = out.size(i);
+  }
+  for (int i = 0; i < a.dim(); i++) {
+    inp1_shape[i + offset_inp1] = a.size(i);
+  }
+  for (int i = 0; i < b.dim(); i++) {
+    inp2_shape[i + offset_inp2] = b.size(i);
+  }
+
+  if ((compute_type == ScalarType::Int) && (optimized)) {
     const int* const inp1_data = a.const_data_ptr<int>();
     const int* const inp2_data = b.const_data_ptr<int>();
     int* const out_data = out.mutable_data_ptr<int>();
@@ -117,7 +135,7 @@ Tensor& add_out(
       xa_nn_elm_add_32x32_32(
           out_data, inp1_data, inp2_data, alpha_val, out.numel());
     }
-  } else if (compute_type == ScalarType::Float) {
+  } else if ((compute_type == ScalarType::Float) && (optimized)) {
     const float* const inp1_data = a.const_data_ptr<float>();
     const float* const inp2_data = b.const_data_ptr<float>();
     float* const out_data = out.mutable_data_ptr<float>();
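For readers skimming the diff: the added `+` block right-aligns each operand's sizes into fixed-size shape arrays padded with leading 1s, so tensors of different rank can be compared dimension by dimension before the broadcasted kernel call. Below is a minimal standalone sketch of that padding step, assuming a hypothetical `pad_shape` helper and an illustrative value for `kTensorDimensionLimit` (the real limit comes from the kernel headers, and the diff works on raw arrays rather than this helper).

#include <vector>

// Illustrative only; the actual limit is defined by the kernel headers.
constexpr int kTensorDimensionLimit = 16;

// Right-align `sizes` into shape[0..max_dim), filling the leading slots with 1.
// For example, a rank-2 tensor {3, 4} aligned to max_dim == 4 becomes {1, 1, 3, 4}.
void pad_shape(const std::vector<int>& sizes, int max_dim, int* shape) {
  for (int i = 0; i < max_dim; ++i) {
    shape[i] = 1;
  }
  const int offset = max_dim - static_cast<int>(sizes.size());
  for (int i = 0; i < static_cast<int>(sizes.size()); ++i) {
    shape[i + offset] = sizes[i];
  }
}

int main() {
  int inp1_shape[kTensorDimensionLimit];
  int inp2_shape[kTensorDimensionLimit];
  pad_shape({3, 4}, 4, inp1_shape);        // -> {1, 1, 3, 4}
  pad_shape({2, 1, 3, 4}, 4, inp2_shape);  // -> {2, 1, 3, 4}
  return 0;
}

Whether the optimized vendor path is then taken depends on the `optimized` flag set in the diff, which rules out rank-0 operands and broadcast cases whose rank exceeds kTensorDimensionLimit.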