Skip to content

Commit 543cdb3

Browse files
authored
serialize scales as bf16 and serialize in Named Data Map (#11031)
XNNPACK currently uses BF16 scales for running GEMMs with groupwise quantized weights. Previously, we serialized scales as FP32 and then converted them to BF16 before passing them to XNNPACK. We can save both memory and file size by serializing the scales as BF16 in the first place. As an additional step, we move the serialization of scales for both channelwise and groupwise quantized weights into the Named Data Map. This also enables swapping scale data in the future, because scales are no longer tied to the XNNPACK payload and can instead be supplied through the .ptd file. cc @lucylq for the scale serialization ### Llama Experiments ``` -rw-r--r-- 1 maxren staff 1746392320 May 20 16:49 llama3_fp32_scales.pte -rw-r--r-- 1 maxren staff 1707798912 May 20 18:47 llama3_bf16_scales.pte ``` We see a ~40 MB reduction in model size.
1 parent 4b15029 commit 543cdb3

File tree

5 files changed

+91
-11
lines changed

5 files changed

+91
-11
lines changed

backends/xnnpack/operators/node_visitor.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -274,19 +274,46 @@ def get_per_channel_dtype(
274274

275275
return dtype
276276

277-
def get_quant_params(self, quant_params: QuantParams) -> XNNQuantParams:
277+
def get_quant_params(
278+
self, quant_params: QuantParams, xnn_graph: XNNGraph
279+
) -> XNNQuantParams:
278280
if quant_params.per_channel:
279281
scale = cast(torch.Tensor, quant_params.scale)
282+
buffer_idx = len(xnn_graph.constant_data)
283+
num_scales = scale.numel()
284+
285+
if quant_params.is_per_channel_group:
286+
scale = scale.to(torch.bfloat16)
287+
288+
num_bytes = scale.untyped_storage().nbytes()
289+
scale_array = ctypes.cast(
290+
scale.untyped_storage().data_ptr(),
291+
ctypes.POINTER(ctypes.c_char * num_bytes),
292+
).contents
293+
scale_name = hashlib.sha256(bytes(scale_array)).hexdigest()
294+
xnn_graph.constant_data.append(
295+
ConstantDataOffset(
296+
offset=UINT64_MAX, size=num_bytes, named_key=scale_name
297+
)
298+
)
299+
self._named_data_store.add_named_data(
300+
scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT
301+
)
302+
280303
if quant_params.is_per_channel_group:
281304
return PerChannelGroupQuant(
282-
scale=scale.flatten().tolist(),
305+
scale=[],
283306
channel_dim=quant_params.axis,
284307
group_size=quant_params.group_size,
308+
scale_buffer_idx=buffer_idx,
309+
num_scales=num_scales,
285310
)
286-
else: # per_channel quant
311+
else:
287312
return PerChannelQuant(
288-
scale=scale.tolist(),
313+
scale=[],
289314
channel_dim=quant_params.axis,
315+
scale_buffer_idx=buffer_idx,
316+
num_scales=num_scales,
290317
)
291318
elif quant_params.is_dynamic:
292319
# NB:
@@ -449,7 +476,7 @@ def define_tensor( # noqa: C901
449476
else XValue(
450477
xvalue_union=XNNQuantizedTensorValue(
451478
tensor_value=tvalue,
452-
quant_params=self.get_quant_params(quant_params),
479+
quant_params=self.get_quant_params(quant_params, xnn_graph),
453480
)
454481
)
455482
)

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -421,11 +421,32 @@ Error defineTensor(
421421
qparams->channel_dim(),
422422
dtype,
423423
zero_point);
424+
425+
const float* scale = qparams->scale()->data();
426+
427+
if (qparams->scale_buffer_idx() != 0) {
428+
// if scales are stored in named data, then retrieve it
429+
ConstantDataOffsetPtr scale_buffer_offset =
430+
flatbuffer_graph->constant_data()->Get(
431+
qparams->scale_buffer_idx());
432+
const std::string& data_name =
433+
scale_buffer_offset->named_key()->str();
434+
Result<FreeableBuffer> scale_buffer =
435+
named_data_map->get_data(data_name.c_str());
436+
ET_CHECK_OR_RETURN_ERROR(
437+
scale_buffer.ok(),
438+
Internal,
439+
"Failed to get constant data for key %s from named_data_map. Error code: %u",
440+
data_name.c_str(),
441+
static_cast<uint32_t>(scale_buffer.error()));
442+
scale = reinterpret_cast<const float*>(scale_buffer.get().data());
443+
freeable_buffers.push_back(std::move(scale_buffer.get()));
444+
}
424445
status = xnn_define_channelwise_quantized_tensor_value_v2(
425446
/*subgraph=*/subgraph_ptr,
426447
/*datatype=*/dtype,
427448
/*zero_point=*/zero_point,
428-
/*scale=*/qparams->scale()->data(),
449+
/*scale=*/scale,
429450
/*num_dims=*/tensor_value->num_dims(),
430451
/*channel_dim*/ qparams->channel_dim(),
431452
/*dims=*/dims_data.data(),
@@ -452,10 +473,24 @@ Error defineTensor(
452473

453474
// Block scales are preferably serialized as bf16 but can also be
454475
// serialized as fp32 for backwards compatability.
455-
if (qparams->scale_bf16() != nullptr) {
476+
if (qparams->scale_buffer_idx() != 0) {
477+
ConstantDataOffsetPtr scale_buffer_offset =
478+
flatbuffer_graph->constant_data()->Get(
479+
qparams->scale_buffer_idx());
480+
const std::string& data_name =
481+
scale_buffer_offset->named_key()->str();
482+
Result<FreeableBuffer> scale_buffer =
483+
named_data_map->get_data(data_name.c_str());
484+
ET_CHECK_OR_RETURN_ERROR(
485+
scale_buffer.ok(),
486+
Internal,
487+
"Failed to get constant data for key %s from named_data_map. Error code: %u",
488+
data_name.c_str(),
489+
static_cast<uint32_t>(scale_buffer.error()));
456490
scale_data =
457-
static_cast<const uint16_t*>(qparams->scale_bf16()->data());
458-
scale_numel = qparams->scale_bf16()->size();
491+
reinterpret_cast<const uint16_t*>(scale_buffer.get().data());
492+
freeable_buffers.push_back(std::move(scale_buffer.get()));
493+
scale_numel = qparams->num_scales();
459494
} else {
460495
// Read fp32 scales, convert to bf16.
461496
auto conv_buffer = static_cast<uint16_t*>(allocator.allocateTemporary(

backends/xnnpack/serialization/runtime_schema.fbs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ table Buffer {
4848
table PerChannelQuant {
4949
scale:[float];
5050
channel_dim:int;
51+
scale_buffer_idx: uint;
52+
num_scales: uint;
5153
}
5254

5355
table PerTokenDynamicQuant {
@@ -63,7 +65,9 @@ table PerChannelGroupQuant {
6365
scale:[float];
6466
channel_dim:int;
6567
group_size:int;
66-
scale_bf16:[ushort];
68+
scale_bf16:[ushort] (deprecated);
69+
scale_buffer_idx: uint;
70+
num_scales: uint;
6771
}
6872

6973
table XNNTensorValue {

backends/xnnpack/serialization/schema.fbs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,16 @@ table PerChannelGroupQuant {
4848
scale:[float];
4949
channel_dim:int;
5050
group_size:int;
51-
scale_bf16:[ushort];
51+
scale_bf16:[ushort] (deprecated);
52+
scale_buffer_idx: uint;
53+
num_scales: uint;
5254
}
5355

5456
table PerChannelQuant {
5557
scale:[float];
5658
channel_dim:int;
59+
scale_buffer_idx: uint;
60+
num_scales: uint;
5761
}
5862

5963
table PerTokenDynamicQuant {

backends/xnnpack/serialization/xnnpack_graph_schema.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,13 +425,23 @@ class XNNDatatype(IntEnum):
425425
class PerChannelQuant:
426426
scale: List[float]
427427
channel_dim: int
428+
scale_buffer_idx: int = -1
429+
num_scales: int = -1
430+
431+
432+
@dataclass
433+
class Buffer:
434+
storage: bytes
428435

429436

430437
@dataclass
431438
class PerChannelGroupQuant:
432439
scale: List[float]
433440
channel_dim: int
434441
group_size: int = 1
442+
scale_bf16: Optional[List[float]] = None
443+
scale_buffer_idx: int = -1
444+
num_scales: int = -1
435445

436446

437447
@dataclass

0 commit comments

Comments
 (0)