@@ -421,11 +421,28 @@ Error defineTensor(
421421 qparams->channel_dim (),
422422 dtype,
423423 zero_point);
424+
425+ const float * scale = qparams->scale ()->data ();
426+
427+ if (qparams->scale_buffer_idx () != 0 ) {
428+ // if scales are stored in named data, then retrieve it
429+ ConstantDataOffsetPtr scale_buffer_offset = flatbuffer_graph->constant_data ()->Get (qparams->scale_buffer_idx ());
430+ const std::string& data_name = scale_buffer_offset->named_key ()->str ();
431+ Result<FreeableBuffer> scale_buffer = named_data_map->get_data (data_name.c_str ());
432+ ET_CHECK_OR_RETURN_ERROR (
433+ scale_buffer.ok (),
434+ Internal,
435+ " Failed to get constant data for key %s from named_data_map. Error code: %u" ,
436+ data_name.c_str (),
437+ static_cast <uint32_t >(scale_buffer.error ()));
438+ scale = reinterpret_cast <const float *>(scale_buffer.get ().data ());
439+ freeable_buffers.push_back (std::move (scale_buffer.get ()));
440+ }
424441 status = xnn_define_channelwise_quantized_tensor_value_v2 (
425442 /* subgraph=*/ subgraph_ptr,
426443 /* datatype=*/ dtype,
427444 /* zero_point=*/ zero_point,
428- /* scale=*/ qparams-> scale ()-> data () ,
445+ /* scale=*/ scale,
429446 /* num_dims=*/ tensor_value->num_dims (),
430447 /* channel_dim*/ qparams->channel_dim (),
431448 /* dims=*/ dims_data.data (),
@@ -452,10 +469,19 @@ Error defineTensor(
452469
453470 // Block scales are preferably serialized as bf16 but can also be
454471 // serialized as fp32 for backwards compatibility.
455- if (qparams->scale_bf16 () != nullptr ) {
456- scale_data =
457- static_cast <const uint16_t *>(qparams->scale_bf16 ()->data ());
458- scale_numel = qparams->scale_bf16 ()->size ();
472+ if (qparams->scale_buffer_idx () != 0 ) {
473+ ConstantDataOffsetPtr scale_buffer_offset = flatbuffer_graph->constant_data ()->Get (qparams->scale_buffer_idx ());
474+ const std::string& data_name = scale_buffer_offset->named_key ()->str ();
475+ Result<FreeableBuffer> scale_buffer = named_data_map->get_data (data_name.c_str ());
476+ ET_CHECK_OR_RETURN_ERROR (
477+ scale_buffer.ok (),
478+ Internal,
479+ " Failed to get constant data for key %s from named_data_map. Error code: %u" ,
480+ data_name.c_str (),
481+ static_cast <uint32_t >(scale_buffer.error ()));
482+ scale_data = reinterpret_cast <const uint16_t *>(scale_buffer.get ().data ());
483+ freeable_buffers.push_back (std::move (scale_buffer.get ()));
484+ scale_numel = qparams->num_scales ();
459485 } else {
460486 // Read fp32 scales, convert to bf16.
461487 auto conv_buffer = static_cast <uint16_t *>(allocator.allocateTemporary (
0 commit comments