Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 42 additions & 17 deletions onnxruntime/core/providers/cpu/quantization/conv_integer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ class ConvInteger : public OpKernel {
Status Compute(OpKernelContext* context) const override;

ConvAttributes conv_attrs_;

private:
template <typename XT, typename WT>
Status ComputeInner(OpKernelContext* context) const;
};

ONNX_OPERATOR_KERNEL_EX(
Expand All @@ -28,12 +32,15 @@ ONNX_OPERATOR_KERNEL_EX(
10,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType<uint8_t>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>())
.TypeConstraint("T1", {DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<int8_t>()})
.TypeConstraint("T2", {DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<int8_t>()})
.TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
ConvInteger);

Status ConvInteger::Compute(OpKernelContext* context) const {
template <typename XT, typename WT>
Status ConvInteger::ComputeInner(OpKernelContext* context) const {
const auto input_defs = Node().InputDefs();
size_t num_inputs = input_defs.size();
const auto* X = context->Input<Tensor>(0);
Expand All @@ -43,12 +50,12 @@ Status ConvInteger::Compute(OpKernelContext* context) const {
if (num_inputs >= 3 && input_defs[2]->Exists()) {
const auto* X_Zero_Point = context->Input<Tensor>(2);
ORT_ENFORCE(IsScalarOr1ElementVector(X_Zero_Point), "Must be a scalar or 1D tensor or size 1.");
input_offset = *(X_Zero_Point->Data<uint8_t>());
input_offset = *static_cast<const uint8_t*>(X_Zero_Point->DataRaw());
}
if (num_inputs >= 4 && input_defs[3]->Exists()) {
const auto* W_Zero_Point = context->Input<Tensor>(3);
ORT_ENFORCE(IsScalarOr1ElementVector(W_Zero_Point), "Non per-tensor quantization is not supported now.");
filter_offset = *(W_Zero_Point->Data<uint8_t>());
filter_offset = *static_cast<const uint8_t*>(W_Zero_Point->DataRaw());
}

const int64_t N = X->Shape()[0];
Expand Down Expand Up @@ -110,16 +117,16 @@ Status ConvInteger::Compute(OpKernelContext* context) const {

concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();

const auto* Xdata = X->Data<uint8_t>();
const auto* Wdata = W->Data<uint8_t>();
const auto* Xdata = static_cast<const uint8_t*>(X->DataRaw());
const auto* Wdata = static_cast<const uint8_t*>(W->DataRaw());
auto* Ydata = Y->MutableData<int32_t>();

for (int image_id = 0; image_id < N; ++image_id) {
for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) {
if (col_buffer_data != nullptr) {
if (kernel_rank == 2) {
math::Im2col<uint8_t, StorageOrder::NCHW>()(
Xdata,
math::Im2col<XT, StorageOrder::NCHW>()(
reinterpret_cast<const XT*>(Xdata),
C / conv_attrs_.group,
input_shape[0],
input_shape[1],
Expand All @@ -133,11 +140,11 @@ Status ConvInteger::Compute(OpKernelContext* context) const {
pads[3],
strides[0],
strides[1],
col_buffer_data,
input_offset);
reinterpret_cast<XT*>(col_buffer_data),
static_cast<XT>(input_offset));
} else {
math::Im2col<uint8_t, StorageOrder::NCHW>()(
Xdata,
math::Im2col<XT, StorageOrder::NCHW>()(
reinterpret_cast<const XT*>(Xdata),
input_shape.GetDims().data(),
output_shape.GetDims().data(),
kernel_dim,
Expand All @@ -146,35 +153,53 @@ Status ConvInteger::Compute(OpKernelContext* context) const {
dilations.data(),
pads.data(),
static_cast<int>(kernel_rank),
col_buffer_data,
reinterpret_cast<XT*>(col_buffer_data),
false,
input_offset);
static_cast<XT>(input_offset));
}
}

MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape;
gemm_shape.M = static_cast<size_t>(M / conv_attrs_.group);
gemm_shape.N = static_cast<size_t>(output_image_size);
gemm_shape.K = static_cast<size_t>(kernel_dim);
gemm_shape.AIsSigned = W->IsDataType<int8_t>();
gemm_shape.BIsSigned = X->IsDataType<int8_t>();

MLAS_GEMM_QUANT_DATA_PARAMS gemm_params;
gemm_params.A = Wdata + group_id * W_offset;
gemm_params.lda = static_cast<size_t>(kernel_dim);
gemm_params.ZeroPointA = filter_offset;
gemm_params.B = (col_buffer_data == nullptr) ? Xdata : col_buffer_data,
gemm_params.B = (col_buffer_data == nullptr) ? Xdata : col_buffer_data;
gemm_params.ldb = static_cast<size_t>(output_image_size);
gemm_params.ZeroPointB = &input_offset;
gemm_params.C = Ydata;
gemm_params.ldc = static_cast<size_t>(output_image_size);

MlasGemm(gemm_shape, gemm_params, thread_pool);

Xdata += X_offset;
Xdata = reinterpret_cast<const uint8_t*>(X_offset + reinterpret_cast<const XT*>(Xdata));
Ydata += Y_offset;
}
}

return Status::OK();
}

Status ConvInteger::Compute(OpKernelContext* context) const {
  // ConvInteger (opset 10) allows T1 (input) and T2 (weight) to be int8 or
  // uint8 independently. Inspect the runtime element types of both tensors
  // and dispatch to the ComputeInner instantiation matching that pairing.
  const bool x_is_signed = context->Input<Tensor>(0)->IsDataType<int8_t>();
  const bool w_is_signed = context->Input<Tensor>(1)->IsDataType<int8_t>();

  if (x_is_signed) {
    return w_is_signed ? ComputeInner<int8_t, int8_t>(context)
                       : ComputeInner<int8_t, uint8_t>(context);
  }
  return w_is_signed ? ComputeInner<uint8_t, int8_t>(context)
                     : ComputeInner<uint8_t, uint8_t>(context);
}

} // namespace onnxruntime
1 change: 1 addition & 0 deletions onnxruntime/core/util/math_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ void Im2col<T, StorageOrder::NCHW>::operator()(

template struct Im2col<float, StorageOrder::NCHW>;
template struct Im2col<uint8_t, StorageOrder::NCHW>;
template struct Im2col<int8_t, StorageOrder::NCHW>;

template <typename T>
void Im2col<T, StorageOrder::NHWC>::operator()(
Expand Down
Loading