Skip to content

Commit 1402991

Browse files
committed
serialize scales as bf16
1 parent fff7b3c commit 1402991

File tree

5 files changed

+91
-11
lines changed

5 files changed

+91
-11
lines changed

backends/xnnpack/operators/node_visitor.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -274,19 +274,46 @@ def get_per_channel_dtype(
274274

275275
return dtype
276276

277-
def get_quant_params(self, quant_params: QuantParams) -> XNNQuantParams:
277+
def get_quant_params(
278+
self, quant_params: QuantParams, xnn_graph: XNNGraph
279+
) -> XNNQuantParams:
278280
if quant_params.per_channel:
279281
scale = cast(torch.Tensor, quant_params.scale)
282+
buffer_idx = len(xnn_graph.constant_data)
283+
num_scales = scale.numel()
284+
285+
if quant_params.is_per_channel_group:
286+
scale = scale.to(torch.bfloat16)
287+
288+
num_bytes = scale.untyped_storage().nbytes()
289+
scale_array = ctypes.cast(
290+
scale.untyped_storage().data_ptr(),
291+
ctypes.POINTER(ctypes.c_char * num_bytes),
292+
).contents
293+
scale_name = hashlib.sha256(bytes(scale_array)).hexdigest()
294+
xnn_graph.constant_data.append(
295+
ConstantDataOffset(
296+
offset=UINT64_MAX, size=num_bytes, named_key=scale_name
297+
)
298+
)
299+
self._named_data_store.add_named_data(
300+
scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT
301+
)
302+
280303
if quant_params.is_per_channel_group:
281304
return PerChannelGroupQuant(
282-
scale=scale.flatten().tolist(),
305+
scale=[],
283306
channel_dim=quant_params.axis,
284307
group_size=quant_params.group_size,
308+
scale_buffer_idx=buffer_idx,
309+
num_scales=num_scales,
285310
)
286-
else: # per_channel quant
311+
else:
287312
return PerChannelQuant(
288-
scale=scale.tolist(),
313+
scale=[],
289314
channel_dim=quant_params.axis,
315+
scale_buffer_idx=buffer_idx,
316+
num_scales=num_scales,
290317
)
291318
elif quant_params.is_dynamic:
292319
# NB:
@@ -449,7 +476,7 @@ def define_tensor( # noqa: C901
449476
else XValue(
450477
xvalue_union=XNNQuantizedTensorValue(
451478
tensor_value=tvalue,
452-
quant_params=self.get_quant_params(quant_params),
479+
quant_params=self.get_quant_params(quant_params, xnn_graph),
453480
)
454481
)
455482
)

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -421,11 +421,32 @@ Error defineTensor(
421421
qparams->channel_dim(),
422422
dtype,
423423
zero_point);
424+
425+
const float* scale = qparams->scale()->data();
426+
427+
if (qparams->scale_buffer_idx() != 0) {
428+
// if scales are stored in named data, then retrieve it
429+
ConstantDataOffsetPtr scale_buffer_offset =
430+
flatbuffer_graph->constant_data()->Get(
431+
qparams->scale_buffer_idx());
432+
const std::string& data_name =
433+
scale_buffer_offset->named_key()->str();
434+
Result<FreeableBuffer> scale_buffer =
435+
named_data_map->get_data(data_name.c_str());
436+
ET_CHECK_OR_RETURN_ERROR(
437+
scale_buffer.ok(),
438+
Internal,
439+
"Failed to get constant data for key %s from named_data_map. Error code: %u",
440+
data_name.c_str(),
441+
static_cast<uint32_t>(scale_buffer.error()));
442+
scale = reinterpret_cast<const float*>(scale_buffer.get().data());
443+
freeable_buffers.push_back(std::move(scale_buffer.get()));
444+
}
424445
status = xnn_define_channelwise_quantized_tensor_value_v2(
425446
/*subgraph=*/subgraph_ptr,
426447
/*datatype=*/dtype,
427448
/*zero_point=*/zero_point,
428-
/*scale=*/qparams->scale()->data(),
449+
/*scale=*/scale,
429450
/*num_dims=*/tensor_value->num_dims(),
430451
/*channel_dim*/ qparams->channel_dim(),
431452
/*dims=*/dims_data.data(),
@@ -452,10 +473,24 @@ Error defineTensor(
452473

453474
// Block scales are preferably serialized as bf16 but can also be
454475
// serialized as fp32 for backwards compatability.
455-
if (qparams->scale_bf16() != nullptr) {
476+
if (qparams->scale_buffer_idx() != 0) {
477+
ConstantDataOffsetPtr scale_buffer_offset =
478+
flatbuffer_graph->constant_data()->Get(
479+
qparams->scale_buffer_idx());
480+
const std::string& data_name =
481+
scale_buffer_offset->named_key()->str();
482+
Result<FreeableBuffer> scale_buffer =
483+
named_data_map->get_data(data_name.c_str());
484+
ET_CHECK_OR_RETURN_ERROR(
485+
scale_buffer.ok(),
486+
Internal,
487+
"Failed to get constant data for key %s from named_data_map. Error code: %u",
488+
data_name.c_str(),
489+
static_cast<uint32_t>(scale_buffer.error()));
456490
scale_data =
457-
static_cast<const uint16_t*>(qparams->scale_bf16()->data());
458-
scale_numel = qparams->scale_bf16()->size();
491+
reinterpret_cast<const uint16_t*>(scale_buffer.get().data());
492+
freeable_buffers.push_back(std::move(scale_buffer.get()));
493+
scale_numel = qparams->num_scales();
459494
} else {
460495
// Read fp32 scales, convert to bf16.
461496
auto conv_buffer = static_cast<uint16_t*>(allocator.allocateTemporary(

backends/xnnpack/serialization/runtime_schema.fbs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ table Buffer {
4848
table PerChannelQuant {
4949
scale:[float];
5050
channel_dim:int;
51+
scale_buffer_idx: uint;
52+
num_scales: uint;
5153
}
5254

5355
table PerTokenDynamicQuant {
@@ -63,7 +65,9 @@ table PerChannelGroupQuant {
6365
scale:[float];
6466
channel_dim:int;
6567
group_size:int;
66-
scale_bf16:[ushort];
68+
scale_bf16:[ushort] (deprecated);
69+
scale_buffer_idx: uint;
70+
num_scales: uint;
6771
}
6872

6973
table XNNTensorValue {

backends/xnnpack/serialization/schema.fbs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,16 @@ table PerChannelGroupQuant {
4848
scale:[float];
4949
channel_dim:int;
5050
group_size:int;
51-
scale_bf16:[ushort];
51+
scale_bf16:[ushort] (deprecated);
52+
scale_buffer_idx: uint;
53+
num_scales: uint;
5254
}
5355

5456
table PerChannelQuant {
5557
scale:[float];
5658
channel_dim:int;
59+
scale_buffer_idx: uint;
60+
num_scales: uint;
5761
}
5862

5963
table PerTokenDynamicQuant {

backends/xnnpack/serialization/xnnpack_graph_schema.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,13 +419,23 @@ class XNNDatatype(IntEnum):
419419
class PerChannelQuant:
420420
scale: List[float]
421421
channel_dim: int
422+
scale_buffer_idx: int = -1
423+
num_scales: int = -1
424+
425+
426+
@dataclass
427+
class Buffer:
428+
storage: bytes
422429

423430

424431
@dataclass
425432
class PerChannelGroupQuant:
426433
scale: List[float]
427434
channel_dim: int
428435
group_size: int = 1
436+
scale_bf16: Optional[List[float]] = None
437+
scale_buffer_idx: int = -1
438+
num_scales: int = -1
429439

430440

431441
@dataclass

0 commit comments

Comments
 (0)