Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions backends/xnnpack/operators/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,19 +274,46 @@ def get_per_channel_dtype(

return dtype

def get_quant_params(self, quant_params: QuantParams) -> XNNQuantParams:
def get_quant_params(
self, quant_params: QuantParams, xnn_graph: XNNGraph
) -> XNNQuantParams:
if quant_params.per_channel:
scale = cast(torch.Tensor, quant_params.scale)
buffer_idx = len(xnn_graph.constant_data)
num_scales = scale.numel()

if quant_params.is_per_channel_group:
scale = scale.to(torch.bfloat16)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be a flag (default=bf16), which lines up with the QB4W XNNPACK flag to use bf16. We can error out at AoT for fp32, given we can't run that yet.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what you mean — won't almost all scales that come to us be fp32?


num_bytes = scale.untyped_storage().nbytes()
scale_array = ctypes.cast(
scale.untyped_storage().data_ptr(),
ctypes.POINTER(ctypes.c_char * num_bytes),
).contents
scale_name = hashlib.sha256(bytes(scale_array)).hexdigest()
xnn_graph.constant_data.append(
ConstantDataOffset(
offset=UINT64_MAX, size=num_bytes, named_key=scale_name
)
)
self._named_data_store.add_named_data(
scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT
)

if quant_params.is_per_channel_group:
return PerChannelGroupQuant(
scale=scale.flatten().tolist(),
scale=[],
channel_dim=quant_params.axis,
group_size=quant_params.group_size,
scale_buffer_idx=buffer_idx,
num_scales=num_scales,
)
else: # per_channel quant
else:
return PerChannelQuant(
scale=scale.tolist(),
scale=[],
channel_dim=quant_params.axis,
scale_buffer_idx=buffer_idx,
num_scales=num_scales,
)
elif quant_params.is_dynamic:
# NB:
Expand Down Expand Up @@ -449,7 +476,7 @@ def define_tensor( # noqa: C901
else XValue(
xvalue_union=XNNQuantizedTensorValue(
tensor_value=tvalue,
quant_params=self.get_quant_params(quant_params),
quant_params=self.get_quant_params(quant_params, xnn_graph),
)
)
)
Expand Down
43 changes: 39 additions & 4 deletions backends/xnnpack/runtime/XNNCompiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,11 +421,32 @@ Error defineTensor(
qparams->channel_dim(),
dtype,
zero_point);

const float* scale = qparams->scale()->data();

if (qparams->scale_buffer_idx() != 0) {
// if scales are stored in named data, then retrieve it
ConstantDataOffsetPtr scale_buffer_offset =
flatbuffer_graph->constant_data()->Get(
qparams->scale_buffer_idx());
const std::string& data_name =
scale_buffer_offset->named_key()->str();
Result<FreeableBuffer> scale_buffer =
named_data_map->get_data(data_name.c_str());
ET_CHECK_OR_RETURN_ERROR(
scale_buffer.ok(),
Internal,
"Failed to get constant data for key %s from named_data_map. Error code: %u",
data_name.c_str(),
static_cast<uint32_t>(scale_buffer.error()));
scale = reinterpret_cast<const float*>(scale_buffer.get().data());
freeable_buffers.push_back(std::move(scale_buffer.get()));
}
status = xnn_define_channelwise_quantized_tensor_value_v2(
/*subgraph=*/subgraph_ptr,
/*datatype=*/dtype,
/*zero_point=*/zero_point,
/*scale=*/qparams->scale()->data(),
/*scale=*/scale,
/*num_dims=*/tensor_value->num_dims(),
/*channel_dim*/ qparams->channel_dim(),
/*dims=*/dims_data.data(),
Expand All @@ -452,10 +473,24 @@ Error defineTensor(

// Block scales are preferably serialized as bf16 but can also be
// serialized as fp32 for backwards compatibility.
if (qparams->scale_bf16() != nullptr) {
if (qparams->scale_buffer_idx() != 0) {
ConstantDataOffsetPtr scale_buffer_offset =
flatbuffer_graph->constant_data()->Get(
qparams->scale_buffer_idx());
const std::string& data_name =
scale_buffer_offset->named_key()->str();
Result<FreeableBuffer> scale_buffer =
named_data_map->get_data(data_name.c_str());
ET_CHECK_OR_RETURN_ERROR(
scale_buffer.ok(),
Internal,
"Failed to get constant data for key %s from named_data_map. Error code: %u",
data_name.c_str(),
static_cast<uint32_t>(scale_buffer.error()));
scale_data =
static_cast<const uint16_t*>(qparams->scale_bf16()->data());
scale_numel = qparams->scale_bf16()->size();
reinterpret_cast<const uint16_t*>(scale_buffer.get().data());
freeable_buffers.push_back(std::move(scale_buffer.get()));
scale_numel = qparams->num_scales();
} else {
// Read fp32 scales, convert to bf16.
auto conv_buffer = static_cast<uint16_t*>(allocator.allocateTemporary(
Expand Down
6 changes: 5 additions & 1 deletion backends/xnnpack/serialization/runtime_schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ table Buffer {
table PerChannelQuant {
scale:[float];
channel_dim:int;
scale_buffer_idx: uint;
num_scales: uint;
}

table PerTokenDynamicQuant {
Expand All @@ -63,7 +65,9 @@ table PerChannelGroupQuant {
scale:[float];
channel_dim:int;
group_size:int;
scale_bf16:[ushort];
scale_bf16:[ushort] (deprecated);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious — why mark this as deprecated but not `scale:[float]`, if we are moving to NDM for everything?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because this field was actually never used — we never added the export path to serialize into it. `scale:[float]` is still used by older versions.

scale_buffer_idx: uint;
num_scales: uint;
}

table XNNTensorValue {
Expand Down
6 changes: 5 additions & 1 deletion backends/xnnpack/serialization/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,16 @@ table PerChannelGroupQuant {
scale:[float];
channel_dim:int;
group_size:int;
scale_bf16:[ushort];
scale_bf16:[ushort] (deprecated);
scale_buffer_idx: uint;
num_scales: uint;
}

table PerChannelQuant {
scale:[float];
channel_dim:int;
scale_buffer_idx: uint;
num_scales: uint;
}

table PerTokenDynamicQuant {
Expand Down
10 changes: 10 additions & 0 deletions backends/xnnpack/serialization/xnnpack_graph_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,13 +425,23 @@ class XNNDatatype(IntEnum):
class PerChannelQuant:
scale: List[float]
channel_dim: int
scale_buffer_idx: int = -1
num_scales: int = -1


@dataclass
class Buffer:
storage: bytes


@dataclass
class PerChannelGroupQuant:
scale: List[float]
channel_dim: int
group_size: int = 1
scale_bf16: Optional[List[float]] = None
scale_buffer_idx: int = -1
num_scales: int = -1


@dataclass
Expand Down
Loading