 namespace vkcompute {
 namespace api {
 
+/*
+ * For PackedInt8 memory layouts, ensure that the scalar type used for the
+ * tensor is kInt8x4. Otherwise, return the original scalar type.
+ */
+vkapi::ScalarType get_effective_scalar_type(
+    const vkapi::ScalarType dtype,
+    const utils::GPUMemoryLayout memory_layout) {
+  vkapi::ScalarType effective_dtype = dtype;
+  if (utils::is_packed_int8_layout(memory_layout)) {
+    VK_CHECK_COND(dtype == vkapi::kInt8x4 || dtype == vkapi::kChar);
+    effective_dtype = vkapi::kInt8x4;
+  }
+  return effective_dtype;
+}
+
 /*
  * Used to infer the sizes of a tensor that would correspond to a given
  * VulkanImage.
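Note: for reference, a minimal standalone sketch of the dtype coercion introduced above. The enum and helper names here are stand-ins for the real vkapi/utils definitions, which are not part of this diff:

#include <cassert>

enum class ScalarType { Char, Int8x4, Float };
enum class MemoryLayout { WidthPacked, PackedInt8_4W4C, PackedInt8_4H4W };

bool is_packed_int8_layout(MemoryLayout layout) {
  return layout == MemoryLayout::PackedInt8_4W4C ||
         layout == MemoryLayout::PackedInt8_4H4W;
}

ScalarType get_effective_scalar_type(ScalarType dtype, MemoryLayout layout) {
  // For packed int8 layouts only int8 data is valid; kChar is promoted to
  // the packed kInt8x4 representation. Other dtypes pass through untouched.
  if (is_packed_int8_layout(layout)) {
    assert(dtype == ScalarType::Int8x4 || dtype == ScalarType::Char);
    return ScalarType::Int8x4;
  }
  return dtype;
}

int main() {
  assert(get_effective_scalar_type(ScalarType::Char,
                                   MemoryLayout::PackedInt8_4W4C) ==
         ScalarType::Int8x4);
  assert(get_effective_scalar_type(ScalarType::Float,
                                   MemoryLayout::WidthPacked) ==
         ScalarType::Float);
}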
@@ -187,6 +202,7 @@ std::vector<int64_t> calculate_padded_sizes(
 
 utils::uvec3 calculate_image_extents(
     const std::vector<int64_t>& padded_sizes,
+    const utils::GPUMemoryLayout memory_layout,
     const std::vector<int64_t>& axis_map,
     const int32_t packed_dim) {
   utils::uvec3 extents({1, 1, 1});
@@ -205,6 +221,28 @@ utils::uvec3 calculate_image_extents(
     extents[axis] = utils::safe_downcast<uint32_t>(padded_sizes.at(dim));
   }
 
+  // For "regular" tensor dtypes, 4 elements along the packed dim are packed
+  // into one texel (4-component vectorized type). However, for packed int8
+  // memory layouts, an additional level of packing is employed where 4 int8
+  // elements are packed into one int32, and then 4 int32 are packed into each
+  // ivec4 texel.
+  if (utils::is_packed_int8_layout(memory_layout)) {
+    // Each int in the ivec4 contains 4 channels. The overall ivec4 contains
+    // data for a 1Hx4Wx4C block of the input tensor.
+    if (memory_layout == utils::kPackedInt8_4W4C) {
+      VK_CHECK_COND(packed_dim == 2);
+      extents[axis_map.at(0)] = utils::div_up(extents[axis_map.at(0)], 4u);
+    }
+    // Each int in the ivec4 contains 4 elements along the width dim. The
+    // overall ivec4 contains data for a 4Hx4W block of the input tensor.
+    else if (memory_layout == utils::kPackedInt8_4H4W) {
+      VK_CHECK_COND(packed_dim == 0);
+      extents[axis_map.at(1)] = utils::div_up(extents[axis_map.at(1)], 4u);
+    } else {
+      VK_THROW("Unhandled packed int8 memory layout!");
+    }
+  }
+
   // axis_map[3] indicates the WHCN index of the dimension used for batch
   // concatenation. Thus a double lookup is required to determine the image axis
   // used for batch concatenation.
@@ -215,6 +253,7 @@ utils::uvec3 calculate_image_extents(
 
   VK_CHECK_COND(extents[axis_map.at(packed_dim)] % 4 == 0);
   extents[axis_map.at(packed_dim)] /= 4;
+
   return extents;
 }
 
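Note: to sanity-check the extents math in the new packed int8 branch, here is a self-contained sketch assuming an identity axis map and an illustrative 8(W) x 3(H) x 8(C) int8 tensor with the kPackedInt8_4W4C layout (the sizes and the local div_up helper are assumptions, not from this diff):

#include <cstdint>
#include <iostream>

uint32_t div_up(uint32_t n, uint32_t d) { return (n + d - 1) / d; }

int main() {
  // Padded WHCN sizes: W=8, H=3, C=8 (channels already padded to a multiple
  // of 4 by calculate_padded_sizes). Identity axis map: W->x, H->y, C->z.
  uint32_t extents[3] = {8, 3, 8};
  // 4W4C packing: 4 int8 elements along W go into each int32 of a texel, so
  // the width axis shrinks by a factor of 4 (rounded up).
  extents[0] = div_up(extents[0], 4);  // -> {2, 3, 8}
  // The packed dim (channels) then shrinks by 4, as for any packed layout,
  // since each ivec4 texel spans 4 channels.
  extents[2] /= 4;  // -> {2, 3, 2}
  std::cout << extents[0] << "x" << extents[1] << "x" << extents[2] << "\n";
  // 12 texels * 16 bytes each = 192 bytes, matching the 8*3*8 = 192 int8
  // elements of the tensor: each texel holds one 1Hx4Wx4C block.
}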
@@ -247,35 +286,72 @@ utils::uvec3 calculate_logical_limits(
  */
 utils::uvec3 calculate_logical_limits(
     const std::vector<int64_t>& sizes,
+    const utils::GPUMemoryLayout memory_layout,
     const std::vector<int64_t>& axis_map,
     const int32_t packed_dim) {
   return calculate_logical_limits(
       calculate_image_extents(
-          calculate_padded_sizes(sizes, packed_dim), axis_map, packed_dim),
+          calculate_padded_sizes(sizes, packed_dim),
+          memory_layout,
+          axis_map,
+          packed_dim),
       axis_map);
 }
 
 int64_t calculate_gpu_buffer_numel(
+    const std::vector<int64_t>& sizes,
+    const utils::GPUMemoryLayout memory_layout,
+    const vkapi::ScalarType dtype) {
+  size_t numel;
+
+  // Mirrors the logic in calculate_image_extents for packed int8 memory layouts
+  if (dtype == vkapi::kInt8x4) {
+    VK_CHECK_COND(utils::is_packed_int8_layout(memory_layout));
+    std::vector<int64_t> blocks_in_dim =
+        flip_and_unsqueeze<int64_t>(sizes, kTensorSizes, 0);
+    // Each ivec4 contains data for a 1Hx4Wx4C block of the input
+    if (memory_layout == utils::kPackedInt8_4W4C) {
+      blocks_in_dim[0] = utils::div_up_4(blocks_in_dim[0]);
+      blocks_in_dim[2] = utils::div_up_4(blocks_in_dim[2]);
+    }
+    // Each ivec4 contains data for a 4Hx4W block of the input
+    else if (memory_layout == utils::kPackedInt8_4H4W) {
+      blocks_in_dim[0] = utils::div_up_4(blocks_in_dim[0]);
+      blocks_in_dim[1] = utils::div_up_4(blocks_in_dim[1]);
+    } else {
+      VK_THROW("Unhandled packed int8 memory layout!");
+    }
+    // Each block is represented as an ivec4, and the base dtype of the buffer
+    // is int. Therefore, need to multiply the number of blocks by 4 to obtain
+    // the number of int elements in the data buffer.
+    numel = utils::multiply_integers(blocks_in_dim) * 4;
+  }
+  // Case for "regular" dtypes/memory layouts
+  else {
+    numel = utils::multiply_integers(sizes);
+
+    // For 8-bit types, align to the next multiple of 4. For devices that do not
+    // support 8-bit storage buffers, the tensor data will be interpreted as an
+    // array of int32 instead.
+    if (vkapi::element_size(dtype) == 1) {
+      numel = utils::align_up_4(numel);
+    }
+  }
+  return numel;
+}
+
+int64_t calculate_staging_or_gpu_buffer_numel(
     Context* const context,
     const std::vector<int64_t>& sizes,
     const utils::uvec3 image_extents,
     const utils::StorageType storage_type,
+    const utils::GPUMemoryLayout memory_layout,
     const vkapi::ScalarType dtype) {
   // For texture backed tensors, simply multiply the total number of texels by 4
   if (storage_type != utils::kBuffer) {
     return image_extents[0] * image_extents[1] * image_extents[2] * 4;
   }
-  const bool is_int8 = dtype == vkapi::kChar;
-  const bool int8_supported =
-      context->adapter_ptr()->has_full_int8_buffers_support();
-  const size_t numel = utils::multiply_integers(sizes);
-  // For int8 tensors, if the device does not support int8 buffers, then int32
-  // is used instead to represent the buffer data. Therefore the number of
-  // elements in the buffer is aligned to the next multiple of 4.
-  if (is_int8 && !int8_supported) {
-    return utils::align_up_4(numel);
-  }
-  return numel;
+  return calculate_gpu_buffer_numel(sizes, memory_layout, dtype);
 }
 
 template <typename T, typename = std::enable_if_t<std::is_integral<T>::value>>
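Note: the buffer sizing mirrors the texture math. A small sketch of the kInt8x4 branch of calculate_gpu_buffer_numel for the same illustrative 8(W) x 3(H) x 8(C) tensor with kPackedInt8_4W4C; blocks_in_dim is in WHCN order, as produced by flip_and_unsqueeze, and the local div_up_4 stands in for utils::div_up_4:

#include <cstdint>

int64_t div_up_4(int64_t n) { return (n + 3) / 4; }

int main() {
  int64_t blocks_in_dim[3] = {8, 3, 8};           // {W, H, C}
  blocks_in_dim[0] = div_up_4(blocks_in_dim[0]);  // W: 8 -> 2 blocks
  blocks_in_dim[2] = div_up_4(blocks_in_dim[2]);  // C: 8 -> 2 blocks
  // 2 * 3 * 2 = 12 blocks; each block is stored as one ivec4, i.e. 4 int32
  // elements, so the buffer holds 48 ints (192 bytes) -- the same footprint
  // as the 2x3x2 texture computed in the sketch above.
  int64_t numel = blocks_in_dim[0] * blocks_in_dim[1] * blocks_in_dim[2] * 4;
  return numel == 48 ? 0 : 1;
}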
@@ -332,10 +408,12 @@ vkapi::VulkanImage allocate_image(
     Context* const context_ptr,
     utils::uvec3& image_extents,
     const utils::StorageType storage_type,
-    const VkFormat image_format,
+    const vkapi::ScalarType dtype,
     const bool allocate_memory) {
   vkapi::Adapter* adapter_ptr = context_ptr->adapter_ptr();
 
+  const VkFormat image_format = vkcompute::vkapi::to_vkformat(dtype);
+
   vkapi::ImageSampler::Properties sampler_props{
       VK_FILTER_NEAREST,
       VK_SAMPLER_MIPMAP_MODE_NEAREST,
@@ -420,6 +498,7 @@ vkapi::VulkanBuffer allocate_buffer(
 vTensorStorage::vTensorStorage(
     Context* const context,
     const utils::StorageType storage_type,
+    const utils::GPUMemoryLayout memory_layout,
     const std::vector<int64_t>& axis_map,
     const int32_t packed_dim,
     const std::vector<int64_t>& sizes,
@@ -429,20 +508,22 @@ vTensorStorage::vTensorStorage(
       storage_type_{storage_type},
       image_extents_(calculate_image_extents(
           calculate_padded_sizes(sizes, packed_dim),
+          memory_layout,
           axis_map,
           packed_dim)),
-      buffer_length_{calculate_gpu_buffer_numel(
+      buffer_length_{calculate_staging_or_gpu_buffer_numel(
           context_,
           sizes,
           image_extents_,
           storage_type,
+          memory_layout,
           dtype)},
       buffer_offset_{0},
       image_(allocate_image(
           context_,
           image_extents_,
           storage_type_,
-          to_vkformat(dtype),
+          dtype,
           allocate_memory)),
       buffer_(allocate_buffer(
           context_,
@@ -553,7 +634,7 @@ vTensor::vTensor(
     const utils::GPUMemoryLayout memory_layout,
     const bool allocate_memory,
     const utils::AxisMapLayout axis_map_layout)
-    : dtype_(dtype),
+    : dtype_(get_effective_scalar_type(dtype, memory_layout)),
       // Calculate tensor metadata
       sizes_(sizes.begin(), sizes.end()),
       packed_dim_(utils::to_packed_dim<int32_t>(memory_layout)),
@@ -576,6 +657,7 @@ vTensor::vTensor(
       storage_(std::make_shared<vTensorStorage>(
           context,
           storage_type,
+          memory_layout,
           axis_map_,
           packed_dim_,
           sizes,
@@ -785,6 +867,16 @@ vkapi::VulkanBuffer& vTensor::buffer(
 }
 
 utils::GPUMemoryLayout vTensor::estimate_memory_layout() const {
+  if (dtype_ == vkapi::kInt8x4) {
+    switch (packed_dim_) {
+      case WHCN::kChannelsDim:
+        return utils::kPackedInt8_4W4C;
+      case WHCN::kWidthDim:
+        return utils::kPackedInt8_4H4W;
+      default:
+        VK_THROW("Invalid packed dim for Tensor with kInt8x4 type");
+    }
+  }
   switch (packed_dim_) {
     case WHCN::kWidthDim:
       return utils::kWidthPacked;
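Note: since the memory layout itself is not stored on the tensor, for kInt8x4 it has to be reconstructed from packed_dim_ alone, inverting utils::to_packed_dim. A minimal sketch of that round trip, with stand-in enums and the usual WHCN convention of width=0, channels=2 assumed (consistent with the VK_CHECK_COND values in calculate_image_extents above):

#include <cassert>

constexpr int kWidthDim = 0;     // WHCN::kWidthDim (assumed)
constexpr int kChannelsDim = 2;  // WHCN::kChannelsDim (assumed)

enum class Layout { PackedInt8_4W4C, PackedInt8_4H4W };

// 4W4C packs along channels; 4H4W packs along width.
int to_packed_dim(Layout l) {
  return l == Layout::PackedInt8_4W4C ? kChannelsDim : kWidthDim;
}

// Mirrors the new kInt8x4 branch in estimate_memory_layout.
Layout estimate_layout(int packed_dim) {
  return packed_dim == kChannelsDim ? Layout::PackedInt8_4W4C
                                    : Layout::PackedInt8_4H4W;
}

int main() {
  // estimate_layout inverts to_packed_dim for both packed int8 layouts.
  assert(estimate_layout(to_packed_dim(Layout::PackedInt8_4W4C)) ==
         Layout::PackedInt8_4W4C);
  assert(estimate_layout(to_packed_dim(Layout::PackedInt8_4H4W)) ==
         Layout::PackedInt8_4H4W);
}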
@@ -914,8 +1006,8 @@ void vTensor::update_metadata() {
         flip_and_unsqueeze_ivec4(dim_order_, kTensorDimOrder, numel_);
     uniform_data_->strides_v =
         flip_and_unsqueeze_ivec4(strides_, kTensorStrides, numel_);
-    uniform_data_->logical_limits.limits =
-        calculate_logical_limits(sizes_, axis_map_, packed_dim_);
+    uniform_data_->logical_limits.limits = calculate_logical_limits(
+        sizes_, estimate_memory_layout(), axis_map_, packed_dim_);
 
     if (sizes_uniform_offset_ != kUniformOffsetUnset) {
       uniforms_.update(uniform_data_->sizes_v, sizes_uniform_offset_);
@@ -942,11 +1034,15 @@ void vTensor::update_metadata() {
 }
 
 void vTensor::check_sizes(const std::vector<int64_t>& sizes) const {
+  utils::GPUMemoryLayout est_memory_layout = estimate_memory_layout();
   if (storage_type() != utils::kBuffer) {
     // For texture storage check that the current texture is large enough for
     // the new sizes of the tensor.
     utils::uvec3 virtual_extents = calculate_image_extents(
-        calculate_padded_sizes(sizes_, packed_dim_), axis_map_, packed_dim_);
+        calculate_padded_sizes(sizes_, packed_dim_),
+        est_memory_layout,
+        axis_map_,
+        packed_dim_);
 
     bool valid_resize = virtual_extents[0] <= storage_->image_extents_[0];
     valid_resize =
@@ -958,9 +1054,10 @@ void vTensor::check_sizes(const std::vector<int64_t>& sizes) const {
         valid_resize,
         "tensor sizes requires a larger texture than the current one.");
   } else {
-    // For buffer storage check that the current buffer is large enough for the
-    // new sizes of the tensor.
-    int64_t numel = utils::multiply_integers(sizes);
+    // For buffer storage check that the current buffer is large enough for
+    // the new sizes of the tensor.
+    int64_t numel =
+        calculate_gpu_buffer_numel(sizes_, est_memory_layout, dtype_);
     bool valid_resize =
         numel + storage_->buffer_offset_ <= storage_->buffer_length_;
     VK_CHECK_COND(