@@ -114,7 +114,7 @@ Status GemmOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
if (1 == input_trans_flag.at(input_i)) {
ORT_RETURN_IF_ERROR(quantize_param.HandleTranspose<size_t>(std::vector<size_t>({1, 0})));
ORT_RETURN_IF_ERROR(
utils::TwoDimensionTranspose(qnn_model_wrapper, input_shape, *input_tensor, unpacked_tensor));
utils::TwoDimensionTranspose(qnn_model_wrapper, input_shape, *input_tensor, unpacked_tensor, logger));
} else {
ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor));
}
@@ -315,7 +315,8 @@ Status MatMulOpBuilder::ProcessInputsForQnnFullyConnected(QnnModelWrapper& qnn_m
ORT_RETURN_IF_ERROR(utils::TwoDimensionTranspose(qnn_model_wrapper,
original_shape_copy, // Will be modified to new shape (unnecessary)
*input_info_1.initializer_tensor,
unpacked_tensor));
unpacked_tensor,
logger));
} else {
ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_info_1.initializer_tensor, unpacked_tensor));
}
@@ -22,6 +22,7 @@ static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& act_dql_node_unit,
const NodeUnit& gemm_node_unit,
const NodeUnit& output_ql_node_unit,
const logging::Logger& logger,
bool validate);

std::unique_ptr<IQnnNodeGroup> LowPowerBlockQuantizedGemmFusion::TryFusion(
@@ -115,6 +116,7 @@ std::unique_ptr<IQnnNodeGroup> LowPowerBlockQuantizedGemmFusion::TryFusion(
*p_act_dql_node_unit,
gemm_node_unit,
*p_output_ql_node_unit,
logger,
true);
!status.IsOK()) {
return nullptr;
@@ -143,13 +145,11 @@ LowPowerBlockQuantizedGemmFusion::LowPowerBlockQuantizedGemmFusion(const NodeUni
}

Status LowPowerBlockQuantizedGemmFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const {
ORT_UNUSED_PARAMETER(logger);
return CreateOrValidateOnQnn(qmw, *node_units_[0], *node_units_[1], *node_units_[2], *node_units_[3], *node_units_[4], *node_units_[5], true);
return CreateOrValidateOnQnn(qmw, *node_units_[0], *node_units_[1], *node_units_[2], *node_units_[3], *node_units_[4], *node_units_[5], logger, true);
}

Status LowPowerBlockQuantizedGemmFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const {
ORT_UNUSED_PARAMETER(logger);
return CreateOrValidateOnQnn(qmw, *node_units_[0], *node_units_[1], *node_units_[2], *node_units_[3], *node_units_[4], *node_units_[5], false);
return CreateOrValidateOnQnn(qmw, *node_units_[0], *node_units_[1], *node_units_[2], *node_units_[3], *node_units_[4], *node_units_[5], logger, false);
}

gsl::span<const NodeUnit* const> LowPowerBlockQuantizedGemmFusion::GetNodeUnits() const {
@@ -164,13 +164,15 @@ Status UnpackWeightTensorData(const QnnModelWrapper& qnn_model_wrapper,
const onnx::TensorProto* weight_tensor_proto,
std::vector<uint32_t>& weight_shape,
int64_t input_channel_axis,
std::vector<uint8_t>& unpacked_tensor) {
std::vector<uint8_t>& unpacked_tensor,
const logging::Logger& logger,
bool validate) {
ORT_RETURN_IF_NOT(weight_tensor_proto != nullptr, "Weight tensor proto is null");

if (input_channel_axis == 0) {
// Transpose to keep output_channel at index 0;
// This is needed for proper LPBQ encoding where output channels must be at dimension 0
return utils::TwoDimensionTranspose(qnn_model_wrapper, weight_shape, *weight_tensor_proto, unpacked_tensor);
return utils::TwoDimensionTranspose(qnn_model_wrapper, weight_shape, *weight_tensor_proto, unpacked_tensor, logger, validate);
} else {
// No transpose needed, just unpack the initializer data
return qnn_model_wrapper.UnpackInitializerData(*weight_tensor_proto, unpacked_tensor);
@@ -184,6 +186,7 @@ Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& act_dql_node_unit,
const NodeUnit& gemm_node_unit,
const NodeUnit& output_ql_node_unit,
const logging::Logger& logger,
bool validate) {
assert(scale_dql_node_unit.OpType() == "DequantizeLinear" &&
w_ql_node_unit.OpType() == "QuantizeLinear" &&
@@ -255,7 +258,7 @@ Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper,
std::vector<uint8_t> unpacked_tensor;
Qnn_DataType_t weight_data_type = is_int4_type ? QNN_DATATYPE_SFIXED_POINT_4 : QNN_DATATYPE_SFIXED_POINT_8;
const auto& weight_tensor_proto = qnn_model_wrapper.GetConstantTensor(weight_tensor_name);
ORT_RETURN_IF_ERROR(UnpackWeightTensorData(qnn_model_wrapper, weight_tensor_proto, weight_shape, input_channel_axis, unpacked_tensor));
ORT_RETURN_IF_ERROR(UnpackWeightTensorData(qnn_model_wrapper, weight_tensor_proto, weight_shape, input_channel_axis, unpacked_tensor, logger, validate));

// Quantize weight tensor
size_t weight_elements = unpacked_tensor.size() / sizeof(float);
@@ -166,14 +166,15 @@ Status UnpackWeightTensorData(const QnnModelWrapper& qnn_model_wrapper,
const onnx::TensorProto* weight_tensor_proto,
std::vector<uint32_t>& weight_shape,
int64_t& input_channel_axis,
std::vector<uint8_t>& unpacked_tensor) {
std::vector<uint8_t>& unpacked_tensor,
const logging::Logger& logger) {
ORT_RETURN_IF_NOT(weight_tensor_proto != nullptr, "Weight tensor proto is null");

if (input_channel_axis == 0) {
// Transpose to keep output_channel at index 0;
// The current logic that quantizes with LPBQ encodings requires out_channels at index 0
input_channel_axis = weight_shape.size() - 1;
return utils::TwoDimensionTranspose(qnn_model_wrapper, weight_shape, *weight_tensor_proto, unpacked_tensor);
return utils::TwoDimensionTranspose(qnn_model_wrapper, weight_shape, *weight_tensor_proto, unpacked_tensor, logger);
} else {
// No transpose needed, just unpack the initializer data
return qnn_model_wrapper.UnpackInitializerData(*weight_tensor_proto, unpacked_tensor);
@@ -273,7 +274,7 @@ Status ProcessLPBQWeight(QnnModelWrapper& qnn_model_wrapper,
std::vector<uint8_t> unpacked_tensor;
const auto& weight_tensor_proto = qnn_model_wrapper.GetConstantTensor(weight_tensor_name);
// if input_channel_axis = 0, UnpackWeightTensorData will transpose and keep output_channel at 0
ORT_RETURN_IF_ERROR(UnpackWeightTensorData(qnn_model_wrapper, weight_tensor_proto, weight_shape, input_channel_axis, unpacked_tensor));
ORT_RETURN_IF_ERROR(UnpackWeightTensorData(qnn_model_wrapper, weight_tensor_proto, weight_shape, input_channel_axis, unpacked_tensor, logger));

// Quantize weight tensor
size_t weight_elements = unpacked_tensor.size() / sizeof(float);
@@ -66,7 +66,7 @@ bool CheckShape(const Node& reshape_node) {
}

Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& reshape_node_unit,
const NodeUnit& gemm_node_unit, bool validate) {
const NodeUnit& gemm_node_unit, const logging::Logger& logger, bool validate) {
assert(reshape_node_unit.OpType() == "Reshape" && gemm_node_unit.OpType() == "Gemm");
const auto& node_name = utils::GetUniqueName(gemm_node_unit);
const NodeUnitIODef& input_def = reshape_node_unit.Inputs()[0];
@@ -91,7 +91,7 @@ Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, const NodeUnit&
ORT_RETURN_IF_ERROR(utils::GetQnnDataType(false, weight_def.node_arg.TypeAsProto(), data_type));
const auto& weight_tensor_proto = qnn_model_wrapper.GetConstantTensor(weight_tensor_name);
ORT_RETURN_IF_ERROR(
utils::TwoDimensionTranspose(qnn_model_wrapper, weight_shape, *weight_tensor_proto, unpacked_tensor));
utils::TwoDimensionTranspose(qnn_model_wrapper, weight_shape, *weight_tensor_proto, unpacked_tensor, logger, validate));
QnnTensorWrapper weight_tensor(weight_tensor_name, tensor_type, data_type, QnnQuantParamsWrapper(),
std::move(weight_shape), std::move(unpacked_tensor));
if (has_bias) {
@@ -169,12 +169,12 @@ ReshapeGemmFusion::ReshapeGemmFusion(const NodeUnit& reshape_node_unit, const No
node_units_[1] = &gemm_node_unit;
}

Status ReshapeGemmFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& /*logger*/) const {
return CreateOrValidateOnQnn(qmw, *node_units_[0], *node_units_[1], true);
Status ReshapeGemmFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const {
return CreateOrValidateOnQnn(qmw, *node_units_[0], *node_units_[1], logger, true);
}

Status ReshapeGemmFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& /*logger*/) const {
return CreateOrValidateOnQnn(qmw, *node_units_[0], *node_units_[1], false);
Status ReshapeGemmFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const {
return CreateOrValidateOnQnn(qmw, *node_units_[0], *node_units_[1], logger, false);
}

gsl::span<const NodeUnit* const> ReshapeGemmFusion::GetNodeUnits() const {
26 changes: 20 additions & 6 deletions onnxruntime/core/providers/qnn/builder/qnn_utils.cc
@@ -1255,10 +1255,13 @@ static Status TransposeDataRank5(const TensorShape& input_shape,
return Status::OK();
}

// Set the use_dummy_tensor flag when performing only QNN op validation and no real tensor data is required.
Status TwoDimensionTranspose(const QnnModelWrapper& qnn_model_wrapper,
std::vector<uint32_t>& data_shape,
const onnx::TensorProto& initializer,
std::vector<uint8_t>& transposed_data) {
std::vector<uint8_t>& transposed_data,
const logging::Logger& logger,
bool use_dummy_tensor) {
ORT_RETURN_IF_NOT(data_shape.size() == 2, "Expected shape of rank 2");

std::array<size_t, 2> perm = {1, 0};
@@ -1271,12 +1274,23 @@ Status TwoDimensionTranspose(const QnnModelWrapper& qnn_model_wrapper,

std::vector<uint8_t> input_buffer;
ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(initializer, input_buffer));
transposed_data.resize(input_buffer.size());
transposed_data.resize(input_buffer.size(), 0);

if (use_dummy_tensor) { // Only shape & dtype validation are needed, no need for real tensor
LOGS(logger, VERBOSE) << "Only shape and dtype validation are required, so we can use dummy tensor to avoid heavy memcpy.";
data_shape = std::move(output_shape); // Update parameter with final transposed shape
return Status::OK();
}

// Actual tensor content is required.
const size_t rows = data_shape[0];
const size_t cols = data_shape[1];
const size_t output_cols = output_shape[1];

for (size_t row = 0; row < data_shape[0]; row++) {
for (size_t col = 0; col < data_shape[1]; col++) {
const size_t src_elem_index = (row * data_shape[1] + col);
const size_t dst_elem_index = (col * output_shape[1] + row);
for (size_t row = 0; row < rows; row++) {
for (size_t col = 0; col < cols; col++) {
const size_t src_elem_index = (row * cols + col);
const size_t dst_elem_index = (col * output_cols + row);
const size_t src_byte_index = src_elem_index * elem_byte_size;
const size_t dst_byte_index = dst_elem_index * elem_byte_size;
assert(src_byte_index < input_buffer.size());
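For reference, the index math in the loop above can be checked in isolation. Below is a minimal standalone sketch, assuming a row-major input buffer of rows x cols elements with elem_byte_size bytes per element; it mirrors the transpose indexing shown in this hunk but is not the ORT helper itself.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Byte-wise 2-D transpose sketch: input is rows x cols, output is cols x rows.
// Illustrative only; mirrors the src/dst index computation in TwoDimensionTranspose.
std::vector<uint8_t> Transpose2D(const std::vector<uint8_t>& input,
                                 size_t rows, size_t cols, size_t elem_byte_size) {
  assert(input.size() == rows * cols * elem_byte_size);
  std::vector<uint8_t> output(input.size(), 0);
  const size_t output_cols = rows;  // transposed shape is {cols, rows}
  for (size_t row = 0; row < rows; ++row) {
    for (size_t col = 0; col < cols; ++col) {
      const size_t src_elem_index = row * cols + col;
      const size_t dst_elem_index = col * output_cols + row;
      std::memcpy(output.data() + dst_elem_index * elem_byte_size,
                  input.data() + src_elem_index * elem_byte_size,
                  elem_byte_size);
    }
  }
  return output;
}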
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/qnn/builder/qnn_utils.h
@@ -431,7 +431,9 @@ Status TransposeFromCnhwToHwcn(std::vector<int64_t>&& input_shape_dims,
Status TwoDimensionTranspose(const QnnModelWrapper& qnn_model_wrapper,
std::vector<uint32_t>& data_shape,
const onnx::TensorProto& initializer,
std::vector<uint8_t>& transposed_data);
std::vector<uint8_t>& transposed_data,
const logging::Logger& logger,
bool use_dummy_tensor = false);

Status InsertConvertOp(QnnModelWrapper& qnn_model_wrapper,
const std::string& convert_input_name,
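Because use_dummy_tensor defaults to false, existing callers compile and behave as before, while validation-only paths can opt in to skip the element-wise copy. A hypothetical call site is sketched below; the names weight_tensor_proto and do_op_validation, and the shape values, are illustrative and not taken from this diff.

// Hypothetical validation-aware caller: when do_op_validation is true, the
// transposed shape is still computed, but transposed_data is left zero-filled
// (no per-element copy is performed).
std::vector<uint32_t> weight_shape = {64, 128};  // rank-2 initializer shape; updated in place to the transposed shape
std::vector<uint8_t> transposed_data;
ORT_RETURN_IF_ERROR(utils::TwoDimensionTranspose(qnn_model_wrapper,
                                                 weight_shape,
                                                 weight_tensor_proto,  // onnx::TensorProto initializer
                                                 transposed_data,
                                                 logger,
                                                 /*use_dummy_tensor=*/do_op_validation));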