@@ -421,11 +421,32 @@ Error defineTensor(
421421 qparams->channel_dim (),
422422 dtype,
423423 zero_point);
424+
425+ const float * scale = qparams->scale ()->data ();
426+
427+ if (qparams->scale_buffer_idx () != 0 ) {
428+ // if scales are stored in named data, then retrieve it
429+ ConstantDataOffsetPtr scale_buffer_offset =
430+ flatbuffer_graph->constant_data ()->Get (
431+ qparams->scale_buffer_idx ());
432+ const std::string& data_name =
433+ scale_buffer_offset->named_key ()->str ();
434+ Result<FreeableBuffer> scale_buffer =
435+ named_data_map->get_data (data_name.c_str ());
436+ ET_CHECK_OR_RETURN_ERROR (
437+ scale_buffer.ok (),
438+ Internal,
439+ " Failed to get constant data for key %s from named_data_map. Error code: %u" ,
440+ data_name.c_str (),
441+ static_cast <uint32_t >(scale_buffer.error ()));
442+ scale = reinterpret_cast <const float *>(scale_buffer.get ().data ());
443+ freeable_buffers.push_back (std::move (scale_buffer.get ()));
444+ }
424445 status = xnn_define_channelwise_quantized_tensor_value_v2 (
425446 /* subgraph=*/ subgraph_ptr,
426447 /* datatype=*/ dtype,
427448 /* zero_point=*/ zero_point,
428- /* scale=*/ qparams-> scale ()-> data () ,
449+ /* scale=*/ scale,
429450 /* num_dims=*/ tensor_value->num_dims (),
430451 /* channel_dim*/ qparams->channel_dim (),
431452 /* dims=*/ dims_data.data (),
@@ -452,10 +473,24 @@ Error defineTensor(
452473
453474 // Block scales are preferably serialized as bf16 but can also be
454475 // serialized as fp32 for backwards compatability.
455- if (qparams->scale_bf16 () != nullptr ) {
476+ if (qparams->scale_buffer_idx () != 0 ) {
477+ ConstantDataOffsetPtr scale_buffer_offset =
478+ flatbuffer_graph->constant_data ()->Get (
479+ qparams->scale_buffer_idx ());
480+ const std::string& data_name =
481+ scale_buffer_offset->named_key ()->str ();
482+ Result<FreeableBuffer> scale_buffer =
483+ named_data_map->get_data (data_name.c_str ());
484+ ET_CHECK_OR_RETURN_ERROR (
485+ scale_buffer.ok (),
486+ Internal,
487+ " Failed to get constant data for key %s from named_data_map. Error code: %u" ,
488+ data_name.c_str (),
489+ static_cast <uint32_t >(scale_buffer.error ()));
456490 scale_data =
457- static_cast <const uint16_t *>(qparams->scale_bf16 ()->data ());
458- scale_numel = qparams->scale_bf16 ()->size ();
491+ reinterpret_cast <const uint16_t *>(scale_buffer.get ().data ());
492+ freeable_buffers.push_back (std::move (scale_buffer.get ()));
493+ scale_numel = qparams->num_scales ();
459494 } else {
460495 // Read fp32 scales, convert to bf16.
461496 auto conv_buffer = static_cast <uint16_t *>(allocator.allocateTemporary (
0 commit comments