@@ -315,6 +315,12 @@ void test_vulkan_quantize_per_tensor(
315315      vkcompute::utils::kBuffer ,
316316      vkcompute::utils::kBuffer );
317317
318+   //  If the in_dtype is a double, convert to float for texture implementation
319+   //  since they don't support 64bit as inputs
320+   if  (in_dtype == at::kDouble ) {
321+     in_dtype = at::kFloat ;
322+   }
323+ 
318324  //  Test with texture storage
319325  test_vulkan_quantize_per_tensor_impl (
320326      input_sizes,
@@ -349,6 +355,12 @@ void test_vulkan_quantize_per_token(
349355      vkcompute::utils::kBuffer ,
350356      vkcompute::utils::kBuffer );
351357
358+   //  If the in_dtype is a double, convert to float for texture implementation
359+   //  since they don't support 64bit as inputs
360+   if  (in_dtype == at::kDouble ) {
361+     in_dtype = at::kFloat ;
362+   }
363+ 
352364  //  Test with texture storage
353365  test_vulkan_quantize_per_token_impl (
354366      input_sizes,
@@ -655,6 +667,24 @@ TEST(
655667      at::kChar ); //  output dtype
656668}
657669
670+ TEST (
671+     VulkanQuantizePerTensorTest,
672+     test_vulkan_quantize_per_tensor_double_to_int8) {
673+   if  (!vkcompute::api::context ()
674+            ->adapter_ptr ()
675+            ->has_full_int8_buffers_support ()) {
676+     GTEST_SKIP ();
677+   }
678+   test_vulkan_quantize_per_tensor (
679+       {2 , 3 }, //  input sizes
680+       0.01 , //  scale
681+       1 , //  zero_point
682+       -128 , //  quant_min
683+       127 , //  quant_max
684+       at::kDouble , //  input dtype
685+       at::kChar ); //  output dtype
686+ }
687+ 
658688void  test_reference_quantize_per_token (
659689    const  std::vector<int >& input_sizes,
660690    const  std::vector<float >& pre_scales,
@@ -1075,3 +1105,24 @@ TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_half_to_int8) {
10751105      at::kHalf , //  input dtype
10761106      at::kChar ); //  output dtype
10771107}
1108+ 
1109+ TEST (
1110+     VulkanQuantizePerTensorTest,
1111+     test_vulkan_quantize_per_token_double_to_int8) {
1112+   if  (!vkcompute::api::context ()
1113+            ->adapter_ptr ()
1114+            ->has_full_int8_buffers_support ()) {
1115+     GTEST_SKIP ();
1116+   }
1117+   std::vector<float > scales = {0.1 , 0.2 };
1118+   std::vector<int > zero_points = {0 , 5 };
1119+ 
1120+   test_vulkan_quantize_per_token (
1121+       {2 , 2 }, //  input sizes (2*2=4 tokens)
1122+       scales,
1123+       zero_points,
1124+       -128 , //  quant_min
1125+       127 , //  quant_max
1126+       at::kDouble , //  input dtype
1127+       at::kChar ); //  output dtype
1128+ }
0 commit comments