diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py index 8470184d808..b7d16b18bd1 100644 --- a/backends/xnnpack/operators/node_visitor.py +++ b/backends/xnnpack/operators/node_visitor.py @@ -274,19 +274,46 @@ def get_per_channel_dtype( return dtype - def get_quant_params(self, quant_params: QuantParams) -> XNNQuantParams: + def get_quant_params( + self, quant_params: QuantParams, xnn_graph: XNNGraph + ) -> XNNQuantParams: if quant_params.per_channel: scale = cast(torch.Tensor, quant_params.scale) + buffer_idx = len(xnn_graph.constant_data) + num_scales = scale.numel() + + if quant_params.is_per_channel_group: + scale = scale.to(torch.bfloat16) + + num_bytes = scale.untyped_storage().nbytes() + scale_array = ctypes.cast( + scale.untyped_storage().data_ptr(), + ctypes.POINTER(ctypes.c_char * num_bytes), + ).contents + scale_name = hashlib.sha256(bytes(scale_array)).hexdigest() + xnn_graph.constant_data.append( + ConstantDataOffset( + offset=UINT64_MAX, size=num_bytes, named_key=scale_name + ) + ) + self._named_data_store.add_named_data( + scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT + ) + if quant_params.is_per_channel_group: return PerChannelGroupQuant( - scale=scale.flatten().tolist(), + scale=[], channel_dim=quant_params.axis, group_size=quant_params.group_size, + scale_buffer_idx=buffer_idx, + num_scales=num_scales, ) - else: # per_channel quant + else: return PerChannelQuant( - scale=scale.tolist(), + scale=[], channel_dim=quant_params.axis, + scale_buffer_idx=buffer_idx, + num_scales=num_scales, ) elif quant_params.is_dynamic: # NB: @@ -449,7 +476,7 @@ def define_tensor( # noqa: C901 else XValue( xvalue_union=XNNQuantizedTensorValue( tensor_value=tvalue, - quant_params=self.get_quant_params(quant_params), + quant_params=self.get_quant_params(quant_params, xnn_graph), ) ) ) diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 9fd2c55bb83..56d0508bef0 
100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -421,11 +421,32 @@ Error defineTensor( qparams->channel_dim(), dtype, zero_point); + + const float* scale = qparams->scale()->data(); + + if (qparams->scale_buffer_idx() != 0) { + // if scales are stored in named data, then retrieve it + ConstantDataOffsetPtr scale_buffer_offset = + flatbuffer_graph->constant_data()->Get( + qparams->scale_buffer_idx()); + const std::string& data_name = + scale_buffer_offset->named_key()->str(); + Result<FreeableBuffer> scale_buffer = + named_data_map->get_data(data_name.c_str()); + ET_CHECK_OR_RETURN_ERROR( + scale_buffer.ok(), + Internal, + "Failed to get constant data for key %s from named_data_map. Error code: %u", + data_name.c_str(), + static_cast<uint32_t>(scale_buffer.error())); + scale = reinterpret_cast<const float*>(scale_buffer.get().data()); + freeable_buffers.push_back(std::move(scale_buffer.get())); + } status = xnn_define_channelwise_quantized_tensor_value_v2( /*subgraph=*/subgraph_ptr, /*datatype=*/dtype, /*zero_point=*/zero_point, - /*scale=*/qparams->scale()->data(), + /*scale=*/scale, /*num_dims=*/tensor_value->num_dims(), /*channel_dim*/ qparams->channel_dim(), /*dims=*/dims_data.data(), @@ -452,10 +473,24 @@ Error defineTensor( // Block scales are preferably serialized as bf16 but can also be // serialized as fp32 for backwards compatability. - if (qparams->scale_bf16() != nullptr) { + if (qparams->scale_buffer_idx() != 0) { + ConstantDataOffsetPtr scale_buffer_offset = + flatbuffer_graph->constant_data()->Get( + qparams->scale_buffer_idx()); + const std::string& data_name = + scale_buffer_offset->named_key()->str(); + Result<FreeableBuffer> scale_buffer = + named_data_map->get_data(data_name.c_str()); + ET_CHECK_OR_RETURN_ERROR( + scale_buffer.ok(), + Internal, + "Failed to get constant data for key %s from named_data_map. 
Error code: %u", + data_name.c_str(), + static_cast<uint32_t>(scale_buffer.error())); scale_data = - static_cast<const uint16_t*>(qparams->scale_bf16()->data()); - scale_numel = qparams->scale_bf16()->size(); + reinterpret_cast<const uint16_t*>(scale_buffer.get().data()); + freeable_buffers.push_back(std::move(scale_buffer.get())); + scale_numel = qparams->num_scales(); } else { // Read fp32 scales, convert to bf16. auto conv_buffer = static_cast<uint16_t*>(allocator.allocateTemporary( diff --git a/backends/xnnpack/serialization/runtime_schema.fbs index 79502ad4e51..d76c3c0807e 100644 --- a/backends/xnnpack/serialization/runtime_schema.fbs +++ b/backends/xnnpack/serialization/runtime_schema.fbs @@ -48,6 +48,8 @@ table Buffer { table PerChannelQuant { scale:[float]; channel_dim:int; + scale_buffer_idx: uint; + num_scales: uint; } table PerTokenDynamicQuant { @@ -63,7 +65,9 @@ table PerChannelGroupQuant { scale:[float]; channel_dim:int; group_size:int; - scale_bf16:[ushort]; + scale_bf16:[ushort] (deprecated); + scale_buffer_idx: uint; + num_scales: uint; } table XNNTensorValue { diff --git a/backends/xnnpack/serialization/schema.fbs index a231ed05c5d..356df663dfc 100644 --- a/backends/xnnpack/serialization/schema.fbs +++ b/backends/xnnpack/serialization/schema.fbs @@ -48,12 +48,16 @@ table PerChannelGroupQuant { scale:[float]; channel_dim:int; group_size:int; - scale_bf16:[ushort]; + scale_bf16:[ushort] (deprecated); + scale_buffer_idx: uint; + num_scales: uint; } table PerChannelQuant { scale:[float]; channel_dim:int; + scale_buffer_idx: uint; + num_scales: uint; } table PerTokenDynamicQuant { diff --git a/backends/xnnpack/serialization/xnnpack_graph_schema.py index 3a39fe98279..b8b4ea7f02f 100644 --- a/backends/xnnpack/serialization/xnnpack_graph_schema.py +++ b/backends/xnnpack/serialization/xnnpack_graph_schema.py @@ -425,6 +425,13 @@ class XNNDatatype(IntEnum): 
class PerChannelQuant: scale: List[float] channel_dim: int + scale_buffer_idx: int = -1 + num_scales: int = -1 + + +@dataclass +class Buffer: + storage: bytes @dataclass @@ -432,6 +439,9 @@ class PerChannelGroupQuant: scale: List[float] channel_dim: int group_size: int = 1 + scale_bf16: Optional[List[float]] = None + scale_buffer_idx: int = -1 + num_scales: int = -1 @dataclass