@@ -49,187 +49,6 @@ using EnabledDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExec
4949 Concat, Input, 0 );
5050} // namespace
5151
52- // this method will be shared between 'Concat' (CPU and GPU) and
53- // 'ConcatFromSequence' ('concat' and 'stack' modes) to validate inputs
54- Status ConcatBase::PrepareForCompute (OpKernelContext* ctx,
55- const InlinedTensorsVector& input_tensors,
56- Prepare& p) const {
57- size_t input_count = input_tensors.size ();
58-
59- // Must have atleast one input to concat
60- ORT_RETURN_IF_NOT (input_count >= 1 , " Must have 1 or more inputs" );
61-
62- TensorShapeVector reference_dims;
63- size_t reference_rank = 0 ;
64-
65- int reference_tensor_index = 0 ;
66-
67- InlinedVector<int64_t , Prepare::kExpectedNumberOfInputs > input_tensor_sizes;
68- input_tensor_sizes.reserve (input_count);
69-
70- bool all_inputs_are_empty = true ;
71-
72- for (size_t index = 0 ; index < input_count; ++index) {
73- const auto * input = input_tensors[index];
74- ORT_ENFORCE (input != nullptr , " input count mismatch" );
75-
76- // find the first tensor that isn't empty
77- // to be used as a reference for all
78- // downstream shape/rank validations of other inputs
79- const auto & shape = input->Shape ();
80- const auto num_elements = shape.Size ();
81- if (num_elements > 0 ) {
82- reference_dims = shape.AsShapeVector ();
83- reference_rank = reference_dims.size ();
84- reference_tensor_index = onnxruntime::narrow<int >(index);
85- input_tensor_sizes.push_back (num_elements);
86- all_inputs_are_empty = false ;
87- break ;
88- } else {
89- input_tensor_sizes.push_back (0 );
90- }
91- }
92-
93- if (all_inputs_are_empty) {
94- // Reference dim and reference rank can just come from the first input
95- // No shape/rank validations will be done (as all inputs are empty).
96- // But the rest of the execution flow (filling in the Prepare instance - p)
97- // can use this info.
98- reference_dims = input_tensors[0 ]->Shape ().AsShapeVector ();
99- reference_rank = reference_dims.size ();
100- }
101-
102- // Cannot concatenate scalars (but they can be stacked)
103- if (!is_stack_)
104- ORT_RETURN_IF_NOT (reference_rank > 0 , " Cannot concatenate scalars" );
105-
106- // Handle and fix negative axis
107- // In 'stack' mode, the accepted range depends on the output rank (which is one more than the input rank)
108- p.axis = static_cast <uint64_t >(HandleNegativeAxis (axis_, onnxruntime::narrow<int64_t >(!is_stack_
109- ? reference_rank
110- : reference_rank + 1 )));
111-
112- // Ensure all of the non concatenated axes match each other
113- for (size_t index = static_cast <size_t >(reference_tensor_index) + 1 ; index < input_count; index++) {
114- const auto * input = input_tensors[index];
115- ORT_ENFORCE (input != nullptr , " input count mismatch" );
116- const auto & input_shape = input->Shape ();
117- const auto input_dims = input_shape.GetDims ();
118-
119- // Skip shape/rank validation for inputs that are empty.
120- // The ONNX spec states that all dim values along axes not concatentated on
121- // need to be the same for all inputs (empty inputs are not explicitly exempted).
122- // The model in GH issue 8020 has a bunch of Loop nodes all feeding into
123- // the 'Concat' node and one of these Loops tend to have an iteration
124- // count of 0 for some inputs. If the iteration count for a Loop is zero,
125- // we don't execute its subgraph (since the outputs are going to be empty anyway)
126- // and we send an "empty" tensor(s) downstream and use ONNX shape inferred shape
127- // to "compose" the shape for these empty tensor(s).
128- // If we encounter symbolic dims in the ONNX shape inferred shape, we place a '0'
129- // in that position and due to the "lossy" nature of this process, the inputs' shape
130- // validation for such empty inputs fail and hence we skip these validations for all
131- // empty inputs.
132- // This isn't too bad as we will never use empty inputs while concatenating anyway.
133- // We just loosen this check to unblock model in GH issue 8020 to complete processing.
134- if (input_shape.Size () == 0 ) {
135- input_tensor_sizes.push_back (0 );
136- } else {
137- const size_t input_rank = input_dims.size ();
138-
139- ORT_ENFORCE (input_rank == reference_rank,
140- " Ranks of input data are different, cannot concatenate them. expected rank: " ,
141- reference_rank, " got: " , input_rank);
142-
143- // Ensure all the other (non-concat) axes match
144- int64_t tensor_size = 1 ;
145- for (size_t axis_index = 0 ; axis_index < reference_rank; ++axis_index) {
146- auto dim_value = input_dims[axis_index];
147- tensor_size *= dim_value;
148-
149- // In 'concat' mode, the axis to be concatenated may be different
150- // But in 'stack' mode, all input shapes must be the same and must be validated
151- if (!is_stack_ && axis_index == p.axis )
152- continue ;
153-
154- ORT_RETURN_IF_NOT (dim_value == reference_dims[axis_index],
155- " Non concat axis dimensions must match: Axis " ,
156- axis_index, " has mismatched dimensions of " , dim_value,
157- " and " , reference_dims[axis_index]);
158- }
159-
160- input_tensor_sizes.push_back (tensor_size); // assign the computed size of the input tensor
161- }
162- }
163-
164- // Calculate the shape of the output tensor
165- auto output_dims = reference_dims;
166-
167- if (!is_stack_) { // 'Concat' mode
168- // While concatenating, the rank of the output is the same as the input rank(s)
169-
170- // Calculate the size of the concatenated axis
171- size_t concat_axis_size = 0 ;
172- for (size_t index = 0 ; index < input_count; index++) {
173- concat_axis_size += onnxruntime::narrow<size_t >(input_tensors[index]->Shape ()[onnxruntime::narrow<size_t >(p.axis )]);
174- }
175-
176- output_dims[onnxruntime::narrow<size_t >(p.axis )] = onnxruntime::narrow<int64_t >(concat_axis_size);
177- } else { // 'Stack' mode
178- // While stacking, the rank of the output is one more than the input rank(s).
179- // Stacking may be thought of as adding an unit dimension (of value 1) in the input tensors,
180- // and concatenating them on thie new axis.
181- // The value in the corresponding axis of the output will be the number of inputs that are being stacked.
182- output_dims.insert (output_dims.begin () + p.axis , static_cast <int64_t >(input_count));
183- }
184-
185- TensorShape output_shape (output_dims);
186-
187- // Create output tensor
188- p.output_tensor = &(*ctx->Output (0 , output_shape));
189-
190- // Make note if output tensor is going to be empty
191- p.output_num_elements = output_shape.Size ();
192-
193- // No need to proceed further if output is going to be empty
194- if (p.output_num_elements == 0 )
195- return Status::OK ();
196-
197- // The output_axis_pitch is the number of elements to add to move to the next split axis in the output.
198- // Can handle stacking as well.
199- p.output_axis_pitch = 1 ;
200- auto output_rank = !is_stack_ ? reference_rank : reference_rank + 1 ;
201- for (size_t i = output_rank; i-- > p.axis ;) {
202- p.output_axis_pitch *= output_dims[i];
203- }
204-
205- // Fill the 'Prepare' struct with available information
206- p.inputs .reserve (input_count);
207- for (size_t input_index = 0 ; input_index < input_count; input_index++) {
208- const Tensor* data_n_ptr = input_tensors[input_index];
209- auto & data_n = *data_n_ptr;
210-
211- // Type sanity check (Make sure we are working on homogeneous types)
212- ORT_RETURN_IF_NOT (data_n.DataType () == p.output_tensor ->DataType (), " Data type mismatch" );
213-
214- // The input_axis_pitch is the number of elements to add to move to the next split axis in the input
215- // Can handle stacking as well (as the "new dummy dimension" in the input is of unit value).
216- // TODO: Minor Optimization possibility: This input_axis_patch will be common across all inputs
217- // in 'ConcatFromSequence' (stack mode). They have to be computed for each input only while concatenating.
218- int64_t input_axis_pitch = 1 ;
219- const auto & data_dims = data_n.Shape ().GetDims ();
220- for (size_t i = reference_rank; i-- > p.axis ;) {
221- input_axis_pitch *= data_dims[i];
222- }
223-
224- p.inputs .push_back ({&data_n, input_axis_pitch, input_tensor_sizes[input_index]});
225- }
226-
227- // Make note if the input Tensors of type 'string'
228- p.is_string_type = p.inputs [0 ].tensor ->IsDataTypeString ();
229-
230- return Status::OK ();
231- }
232-
23352namespace {
23453TensorShapeVector StridesForStack (const TensorShapeVector& full_strides, uint64_t axis) {
23554 // if we are stacking, skip the dimension that will be stacked along in the output strides
0 commit comments