
Commit 6f22ae7 (1 parent: 7b9de5b)

[WIP] extract common code for EP API adapter

17 files changed (+287 −270 lines)
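This excerpt shows seven of the changed files. They follow one pattern: constructors that previously took the concrete OpKernelInfo (or OpNodeProtoHelper) become templates over a KernelInfoType, so the same attribute-parsing code can presumably serve both the native kernel-info helper and the EP API adapter this WIP targets; concat.cc additionally loses ConcatBase::PrepareForCompute, whose new home is not part of this excerpt.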

onnxruntime/contrib_ops/cpu/bert/attention_base.h (9 additions, 9 deletions)

@@ -32,22 +32,22 @@ class AttentionBase {
                   int& past_sequence_length) const;
 
  protected:
-  AttentionBase(const OpKernelInfo& info, bool require_same_hidden_size) {
+  template <typename KernelInfoType>
+  AttentionBase(const KernelInfoType& info, bool require_same_hidden_size) {
     int64_t num_heads = 0;
     ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
     num_heads_ = static_cast<int>(num_heads);
 
-    is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
-    do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
-    rotary_embedding_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
-    mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
-    scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
-
-    if (!info.GetAttrs<int64_t>("qkv_hidden_sizes", qkv_hidden_sizes_).IsOK()) {
+    is_unidirectional_ = info.template GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
+    do_rotary_ = info.template GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
+    rotary_embedding_ = static_cast<int>(info.template GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
+    mask_filter_value_ = info.template GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
+    scale_ = info.template GetAttrOrDefault<float>("scale", 0.0f);
+    if (!info.template GetAttrs<int64_t>("qkv_hidden_sizes", qkv_hidden_sizes_).IsOK()) {
       qkv_hidden_sizes_.clear();
     }
 
-    past_present_share_buffer_ = info.GetAttrOrDefault<int64_t>("past_present_share_buffer", 0LL);
+    past_present_share_buffer_ = info.template GetAttrOrDefault<int64_t>("past_present_share_buffer", 0LL);
 
     require_same_hidden_size_ = require_same_hidden_size;
   }
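The recurring info.template GetAttrOrDefault<int64_t>(...) spelling is forced by the templating: once info has the dependent type KernelInfoType, the compiler cannot assume GetAttrOrDefault names a member template, so without the template keyword the "<" would parse as a less-than operator. A minimal self-contained illustration (MockInfo is a hypothetical stand-in, not an ORT class):

#include <cstdint>
#include <string>

// Hypothetical stand-in for a kernel-info type.
struct MockInfo {
  template <typename T>
  T GetAttrOrDefault(const std::string& /*name*/, T default_value) const {
    return default_value;  // pretend the attribute is absent
  }
};

template <typename KernelInfoType>
bool ReadUnidirectional(const KernelInfoType& info) {
  // Without ".template", "GetAttrOrDefault<int64_t>" would not parse here:
  // GetAttrOrDefault is a dependent name, so "<" would read as less-than.
  return info.template GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
}

int main() {
  MockInfo info;
  return ReadUnidirectional(info) ? 1 : 0;  // returns 0: the default 0 is not 1
}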

onnxruntime/core/providers/cpu/nn/conv_attributes.h (9 additions, 8 deletions)

@@ -18,9 +18,10 @@ namespace onnxruntime {
 struct ConvAttributes {
   using ConvPadVector = InlinedVector<int64_t, kTensorShapeSmallBufferElementsSize * 2>;
 
-  explicit ConvAttributes(const OpKernelInfo& info) {
+  template <typename KernelInfoType>
+  explicit ConvAttributes(const KernelInfoType& info) {
     std::string auto_pad_str;
-    auto status = info.GetAttr<std::string>("auto_pad", &auto_pad_str);
+    auto status = info.template GetAttr<std::string>("auto_pad", &auto_pad_str);
     if (status.IsOK()) {
       auto_pad = StringToAutoPadType(auto_pad_str);
     }
@@ -32,8 +33,8 @@ struct ConvAttributes {
       strides.resize(kernel_shape_.size(), 1);
     }
 
-    gsl::span<const int64_t> pads_span;
-    status = info.GetAttrsAsSpan("pads", pads_span);
+    std::vector<int64_t> pads_attr;
+    status = info.GetAttrs("pads", pads_attr);
     if (!status.IsOK()) {
       if (kernel_shape_specified) {
         // If pads are not explicitly provided, fill the container with all zeros
@@ -44,15 +45,15 @@ struct ConvAttributes {
       // Pads are explicitly provided, make sure that auto_pad is NOTSET
       ORT_ENFORCE(auto_pad == AutoPadType::NOTSET,
                   "A Conv/ConvTranspose node has both 'auto_pad' and 'pads' attributes");
-      pads.assign(pads_span.begin(), pads_span.end());
+      pads.assign(pads_attr.begin(), pads_attr.end());
     }
 
     status = info.GetAttrs("dilations", dilations);
     if (kernel_shape_specified && (!status.IsOK() || dilations.empty())) {
       dilations.resize(kernel_shape_.size(), 1);
     }
 
-    status = info.GetAttr<int64_t>("group", &group);
+    status = info.template GetAttr<int64_t>("group", &group);
     if (!status.IsOK()) {
       group = 1;
     }
@@ -61,9 +62,9 @@ struct ConvAttributes {
     // TODO: Re-enable when attributes values are guaranteed to be filled.
     // Make sure empty strides or dilations are defaulted to 1 if necessary
     std::string auto_pad_str;
-    ORT_ENFORCE(info.GetAttr<std::string>("auto_pad", &auto_pad_str).IsOK());
+    ORT_ENFORCE(info.template GetAttr<std::string>("auto_pad", &auto_pad_str).IsOK());
     auto_pad = StringToAutoPadType(auto_pad_str);
-    ORT_ENFORCE(info.GetAttr<int64_t>("group", &group).IsOK());
+    ORT_ENFORCE(info.template GetAttr<int64_t>("group", &group).IsOK());
     ORT_ENFORCE(info.GetAttrs("kernel_shape", kernel_shape_).IsOK());
     ORT_ENFORCE(info.GetAttrs("strides", strides).IsOK());
     ORT_ENFORCE(info.GetAttrs("pads", pads).IsOK());

onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h (2 additions, 1 deletion)

@@ -23,7 +23,8 @@
 namespace onnxruntime {
 
 struct ConvTransposeAttributes : public ConvAttributes {
-  explicit ConvTransposeAttributes(const OpKernelInfo& info)
+  template <typename KernelInfoType>
+  explicit ConvTransposeAttributes(const KernelInfoType& info)
       : ConvAttributes(info),
         output_padding(info.GetAttrsOrDefault("output_padding")),
         output_shape(info.GetAttrsOrDefault("output_shape")) {
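Because the base ConvAttributes constructor is now itself a template, the derived constructor can remain a template and simply forward info; both are instantiated for the same concrete info type at the call site. A minimal sketch with stand-in types (none of these are the real ORT classes):

#include <cstdint>

struct InfoStub {
  int64_t group = 2;
};

struct BaseAttrs {
  template <typename KernelInfoType>
  explicit BaseAttrs(const KernelInfoType& info) : group(info.group) {}
  int64_t group;
};

struct DerivedAttrs : BaseAttrs {
  // Forwarding to the base's templated constructor instantiates both
  // for the same KernelInfoType, mirroring ConvTransposeAttributes.
  template <typename KernelInfoType>
  explicit DerivedAttrs(const KernelInfoType& info) : BaseAttrs(info) {}
};

int main() {
  InfoStub info;
  DerivedAttrs attrs(info);
  return attrs.group == 2 ? 0 : 1;
}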

onnxruntime/core/providers/cpu/nn/pool_attributes.h (5 additions, 4 deletions)

@@ -24,7 +24,8 @@ struct PoolAttributes {
   // Shared providers don't know about OpNodeProtoHelper
   PoolAttributes(const OpKernelInfo& info,
 #else
-  PoolAttributes(const OpNodeProtoHelper<ProtoHelperNodeContext>& info,
+  template <typename KernelInfoType>
+  PoolAttributes(const KernelInfoType& info,
 #endif
                  const std::string& op_name, int start_version)
       : global_pooling(IsGlobalPooling(op_name)) {
@@ -37,7 +38,7 @@ struct PoolAttributes {
 
     std::string auto_padding;
     if (op_name != "MaxUnpool") {
-      ORT_ENFORCE(info.GetAttr<std::string>("auto_pad", &auto_padding).IsOK());
+      ORT_ENFORCE(info.template GetAttr<std::string>("auto_pad", &auto_padding).IsOK());
     }
     auto_pad = StringToAutoPadType(auto_padding);
 
@@ -49,7 +50,7 @@ struct PoolAttributes {
       strides.resize(kernel_shape.size(), 1);
     }
 
-    if (!info.GetAttr<int64_t>("ceil_mode", &ceil_mode).IsOK()) {
+    if (!info.template GetAttr<int64_t>("ceil_mode", &ceil_mode).IsOK()) {
       ceil_mode = 0;
     }
 
@@ -63,7 +64,7 @@ struct PoolAttributes {
 
     if (op_name == "AveragePool") {
       int64_t temp;
-      ORT_ENFORCE(info.GetAttr<int64_t>("count_include_pad", &temp).IsOK());
+      ORT_ENFORCE(info.template GetAttr<int64_t>("count_include_pad", &temp).IsOK());
       count_include_pad = (temp != 0);
     }
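The #else branch replaces the concrete OpNodeProtoHelper<ProtoHelperNodeContext> overload with a template, while the shared-provider branch keeps taking OpKernelInfo directly. A compilable sketch of that conditional-compilation split (SHARED_PROVIDER and the stub types are illustrative; the real guard macro is not shown in this hunk):

#include <string>

struct OpKernelInfoStub {
  std::string op = "MaxPool";
};

struct PoolAttrsSketch {
#if defined(SHARED_PROVIDER)
  // Shared-provider builds compile only against the concrete info type.
  explicit PoolAttrsSketch(const OpKernelInfoStub& info) : op_name(info.op) {}
#else
  // All other builds accept any info-like type, including an EP API adapter.
  template <typename KernelInfoType>
  explicit PoolAttrsSketch(const KernelInfoType& info) : op_name(info.op) {}
#endif
  std::string op_name;
};

int main() {
  OpKernelInfoStub info;
  return PoolAttrsSketch(info).op_name == "MaxPool" ? 0 : 1;
}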

onnxruntime/core/providers/cpu/nn/pool_base.h (5 additions, 3 deletions)

@@ -102,13 +102,15 @@ class LpPool {
 
 class PoolBase {
  private:
-  static int GetStartVersion(const OpKernelInfo& info) {
+  template <typename KernelInfoType>
+  static int GetStartVersion(const KernelInfoType& info) {
     return info.node().SinceVersion();
   }
 
  protected:
-  PoolBase(const OpKernelInfo& info)
-      : op_name_(info.GetKernelDef().OpName().rfind("QLinear", 0) != 0 ? info.GetKernelDef().OpName() : info.GetKernelDef().OpName().substr(7)),
+  template <typename KernelInfoType>
+  PoolBase(const KernelInfoType& info)
+      : op_name_(info.node().OpType().rfind("QLinear", 0) != 0 ? info.node().OpType() : info.node().OpType().substr(7)),
         pool_attrs_(info, op_name_, GetStartVersion(info)) {
   }
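Two things change here: the op name now comes from info.node().OpType() rather than info.GetKernelDef().OpName(), presumably because an EP API adapter can expose the node but not a KernelDef, and the rfind(prefix, 0) comparison is the idiomatic pre-C++20 starts-with test used to strip an optional "QLinear" prefix ("QLinear" is 7 characters). A standalone restatement of the initializer's logic:

#include <string>

// Mirrors the op_name_ initializer above: strip a leading "QLinear" if present.
// rfind(prefix, 0) == 0 is the idiomatic pre-C++20 starts_with check.
std::string BasePoolOpName(const std::string& op_type) {
  return op_type.rfind("QLinear", 0) != 0 ? op_type : op_type.substr(7);
}

int main() {
  // "QLinearAveragePool" -> "AveragePool"; "MaxPool" stays "MaxPool".
  return (BasePoolOpName("QLinearAveragePool") == "AveragePool" &&
          BasePoolOpName("MaxPool") == "MaxPool")
             ? 0
             : 1;
}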

onnxruntime/core/providers/cpu/reduction/reduction_kernel_base.h (6 additions, 5 deletions)

@@ -11,11 +11,12 @@ namespace onnxruntime {
 template <bool allow_multi_axes>
 class ReduceKernelBase {
  protected:
-  ReduceKernelBase(const OpKernelInfo& info, optional<int64_t> keepdims_override = {}) {
+  template <typename KernelInfoType>
+  ReduceKernelBase(const KernelInfoType& info, optional<int64_t> keepdims_override = {}) {
     if (allow_multi_axes) {
-      axes_ = ToShapeVector(info.GetAttrsOrDefault<int64_t>("axes"));
+      axes_ = ToShapeVector(info.template GetAttrsOrDefault<int64_t>("axes"));
     } else {
-      auto v = info.GetAttrOrDefault<int64_t>("axis", 0);
+      auto v = info.template GetAttrOrDefault<int64_t>("axis", 0);
       axes_.push_back(v);
     }
     int64_t keepdims = 1;
@@ -25,9 +26,9 @@ class ReduceKernelBase {
       ORT_ENFORCE(info.GetAttr("keepdims", &keepdims).IsOK());
     }
     keepdims_ = (keepdims == 1);
-    int64_t noop_with_empty_axes = info.GetAttrOrDefault<int64_t>("noop_with_empty_axes", 0);
+    int64_t noop_with_empty_axes = info.template GetAttrOrDefault<int64_t>("noop_with_empty_axes", 0);
     noop_with_empty_axes_ = (noop_with_empty_axes == 1);
-    int64_t select_last_index = info.GetAttrOrDefault<int64_t>("select_last_index", 0);
+    int64_t select_last_index = info.template GetAttrOrDefault<int64_t>("select_last_index", 0);
     select_last_index_ = (select_last_index != 0);
   }
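For context, keepdims_override lets a caller bypass the attribute lookup entirely; only when no override is supplied does the constructor require the keepdims attribute to be present (the ORT_ENFORCE visible at the top of the second hunk). A self-contained sketch of that control flow, with a hypothetical info type and a boolean return standing in for the enforce:

#include <cstdint>
#include <optional>

struct InfoStub {
  bool Get(int64_t& out) const {  // stand-in for info.GetAttr("keepdims", ...)
    out = 0;
    return true;
  }
};

template <typename KernelInfoType>
bool ResolveKeepDims(const KernelInfoType& info, std::optional<int64_t> keepdims_override) {
  int64_t keepdims = 1;
  if (keepdims_override.has_value()) {
    keepdims = *keepdims_override;  // caller-forced value wins
  } else if (!info.Get(keepdims)) {
    // the real constructor ORT_ENFORCEs success here
  }
  return keepdims == 1;  // matches keepdims_ = (keepdims == 1)
}

int main() {
  InfoStub info;
  bool from_attr = ResolveKeepDims(info, std::nullopt);  // false: the stub attribute says 0
  bool forced = ResolveKeepDims(info, int64_t{1});       // true: the override wins
  return (!from_attr && forced) ? 0 : 1;
}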

onnxruntime/core/providers/cpu/tensor/concat.cc (0 additions, 181 deletions)

@@ -49,187 +49,6 @@ using EnabledDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExec
     Concat, Input, 0);
 }  // namespace
 
-// this method will be shared between 'Concat' (CPU and GPU) and
-// 'ConcatFromSequence' ('concat' and 'stack' modes) to validate inputs
-Status ConcatBase::PrepareForCompute(OpKernelContext* ctx,
-                                     const InlinedTensorsVector& input_tensors,
-                                     Prepare& p) const {
-  size_t input_count = input_tensors.size();
-
-  // Must have atleast one input to concat
-  ORT_RETURN_IF_NOT(input_count >= 1, "Must have 1 or more inputs");
-
-  TensorShapeVector reference_dims;
-  size_t reference_rank = 0;
-
-  int reference_tensor_index = 0;
-
-  InlinedVector<int64_t, Prepare::kExpectedNumberOfInputs> input_tensor_sizes;
-  input_tensor_sizes.reserve(input_count);
-
-  bool all_inputs_are_empty = true;
-
-  for (size_t index = 0; index < input_count; ++index) {
-    const auto* input = input_tensors[index];
-    ORT_ENFORCE(input != nullptr, "input count mismatch");
-
-    // find the first tensor that isn't empty
-    // to be used as a reference for all
-    // downstream shape/rank validations of other inputs
-    const auto& shape = input->Shape();
-    const auto num_elements = shape.Size();
-    if (num_elements > 0) {
-      reference_dims = shape.AsShapeVector();
-      reference_rank = reference_dims.size();
-      reference_tensor_index = onnxruntime::narrow<int>(index);
-      input_tensor_sizes.push_back(num_elements);
-      all_inputs_are_empty = false;
-      break;
-    } else {
-      input_tensor_sizes.push_back(0);
-    }
-  }
-
-  if (all_inputs_are_empty) {
-    // Reference dim and reference rank can just come from the first input
-    // No shape/rank validations will be done (as all inputs are empty).
-    // But the rest of the execution flow (filling in the Prepare instance - p)
-    // can use this info.
-    reference_dims = input_tensors[0]->Shape().AsShapeVector();
-    reference_rank = reference_dims.size();
-  }
-
-  // Cannot concatenate scalars (but they can be stacked)
-  if (!is_stack_)
-    ORT_RETURN_IF_NOT(reference_rank > 0, "Cannot concatenate scalars");
-
-  // Handle and fix negative axis
-  // In 'stack' mode, the accepted range depends on the output rank (which is one more than the input rank)
-  p.axis = static_cast<uint64_t>(HandleNegativeAxis(axis_, onnxruntime::narrow<int64_t>(!is_stack_
-                                                                                            ? reference_rank
-                                                                                            : reference_rank + 1)));
-
-  // Ensure all of the non concatenated axes match each other
-  for (size_t index = static_cast<size_t>(reference_tensor_index) + 1; index < input_count; index++) {
-    const auto* input = input_tensors[index];
-    ORT_ENFORCE(input != nullptr, "input count mismatch");
-    const auto& input_shape = input->Shape();
-    const auto input_dims = input_shape.GetDims();
-
-    // Skip shape/rank validation for inputs that are empty.
-    // The ONNX spec states that all dim values along axes not concatentated on
-    // need to be the same for all inputs (empty inputs are not explicitly exempted).
-    // The model in GH issue 8020 has a bunch of Loop nodes all feeding into
-    // the 'Concat' node and one of these Loops tend to have an iteration
-    // count of 0 for some inputs. If the iteration count for a Loop is zero,
-    // we don't execute its subgraph (since the outputs are going to be empty anyway)
-    // and we send an "empty" tensor(s) downstream and use ONNX shape inferred shape
-    // to "compose" the shape for these empty tensor(s).
-    // If we encounter symbolic dims in the ONNX shape inferred shape, we place a '0'
-    // in that position and due to the "lossy" nature of this process, the inputs' shape
-    // validation for such empty inputs fail and hence we skip these validations for all
-    // empty inputs.
-    // This isn't too bad as we will never use empty inputs while concatenating anyway.
-    // We just loosen this check to unblock model in GH issue 8020 to complete processing.
-    if (input_shape.Size() == 0) {
-      input_tensor_sizes.push_back(0);
-    } else {
-      const size_t input_rank = input_dims.size();
-
-      ORT_ENFORCE(input_rank == reference_rank,
-                  "Ranks of input data are different, cannot concatenate them. expected rank: ",
-                  reference_rank, " got: ", input_rank);
-
-      // Ensure all the other (non-concat) axes match
-      int64_t tensor_size = 1;
-      for (size_t axis_index = 0; axis_index < reference_rank; ++axis_index) {
-        auto dim_value = input_dims[axis_index];
-        tensor_size *= dim_value;
-
-        // In 'concat' mode, the axis to be concatenated may be different
-        // But in 'stack' mode, all input shapes must be the same and must be validated
-        if (!is_stack_ && axis_index == p.axis)
-          continue;
-
-        ORT_RETURN_IF_NOT(dim_value == reference_dims[axis_index],
-                          "Non concat axis dimensions must match: Axis ",
-                          axis_index, " has mismatched dimensions of ", dim_value,
-                          " and ", reference_dims[axis_index]);
-      }
-
-      input_tensor_sizes.push_back(tensor_size);  // assign the computed size of the input tensor
-    }
-  }
-
-  // Calculate the shape of the output tensor
-  auto output_dims = reference_dims;
-
-  if (!is_stack_) {  // 'Concat' mode
-    // While concatenating, the rank of the output is the same as the input rank(s)
-
-    // Calculate the size of the concatenated axis
-    size_t concat_axis_size = 0;
-    for (size_t index = 0; index < input_count; index++) {
-      concat_axis_size += onnxruntime::narrow<size_t>(input_tensors[index]->Shape()[onnxruntime::narrow<size_t>(p.axis)]);
-    }
-
-    output_dims[onnxruntime::narrow<size_t>(p.axis)] = onnxruntime::narrow<int64_t>(concat_axis_size);
-  } else {  // 'Stack' mode
-    // While stacking, the rank of the output is one more than the input rank(s).
-    // Stacking may be thought of as adding an unit dimension (of value 1) in the input tensors,
-    // and concatenating them on thie new axis.
-    // The value in the corresponding axis of the output will be the number of inputs that are being stacked.
-    output_dims.insert(output_dims.begin() + p.axis, static_cast<int64_t>(input_count));
-  }
-
-  TensorShape output_shape(output_dims);
-
-  // Create output tensor
-  p.output_tensor = &(*ctx->Output(0, output_shape));
-
-  // Make note if output tensor is going to be empty
-  p.output_num_elements = output_shape.Size();
-
-  // No need to proceed further if output is going to be empty
-  if (p.output_num_elements == 0)
-    return Status::OK();
-
-  // The output_axis_pitch is the number of elements to add to move to the next split axis in the output.
-  // Can handle stacking as well.
-  p.output_axis_pitch = 1;
-  auto output_rank = !is_stack_ ? reference_rank : reference_rank + 1;
-  for (size_t i = output_rank; i-- > p.axis;) {
-    p.output_axis_pitch *= output_dims[i];
-  }
-
-  // Fill the 'Prepare' struct with available information
-  p.inputs.reserve(input_count);
-  for (size_t input_index = 0; input_index < input_count; input_index++) {
-    const Tensor* data_n_ptr = input_tensors[input_index];
-    auto& data_n = *data_n_ptr;
-
-    // Type sanity check (Make sure we are working on homogeneous types)
-    ORT_RETURN_IF_NOT(data_n.DataType() == p.output_tensor->DataType(), "Data type mismatch");
-
-    // The input_axis_pitch is the number of elements to add to move to the next split axis in the input
-    // Can handle stacking as well (as the "new dummy dimension" in the input is of unit value).
-    // TODO: Minor Optimization possibility: This input_axis_patch will be common across all inputs
-    // in 'ConcatFromSequence' (stack mode). They have to be computed for each input only while concatenating.
-    int64_t input_axis_pitch = 1;
-    const auto& data_dims = data_n.Shape().GetDims();
-    for (size_t i = reference_rank; i-- > p.axis;) {
-      input_axis_pitch *= data_dims[i];
-    }
-
-    p.inputs.push_back({&data_n, input_axis_pitch, input_tensor_sizes[input_index]});
-  }
-
-  // Make note if the input Tensors of type 'string'
-  p.is_string_type = p.inputs[0].tensor->IsDataTypeString();
-
-  return Status::OK();
-}
-
 namespace {
 TensorShapeVector StridesForStack(const TensorShapeVector& full_strides, uint64_t axis) {
   // if we are stacking, skip the dimension that will be stacked along in the output strides
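The deleted PrepareForCompute was the validation and shape-computation step shared by Concat (CPU and GPU) and ConcatFromSequence; this WIP commit only removes it here, and its new location is not visible in this excerpt. Its pitch computation, the number of elements one step along the concat axis spans in row-major storage, is worth a standalone restatement:

#include <cstdint>
#include <vector>

// Product of dims from 'axis' to the end: how many elements one step along
// 'axis' spans in row-major storage, as in the removed PrepareForCompute.
int64_t AxisPitch(const std::vector<int64_t>& dims, size_t axis) {
  int64_t pitch = 1;
  for (size_t i = dims.size(); i-- > axis;) {
    pitch *= dims[i];
  }
  return pitch;
}

int main() {
  // Output shape {2, 5, 4}, concatenating along axis 1: pitch = 5 * 4 = 20.
  return AxisPitch({2, 5, 4}, 1) == 20 ? 0 : 1;
}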
