@@ -156,6 +156,56 @@ void check_quantize_args(
156156 " actual quant_max: " ,
157157 quant_max);
158158}
159+
160+ //
161+ // Reference Implementation
162+ //
163+
164+ /*
165+ * Reference implementation of quantize_per_tensor
166+ */
167+ at::Tensor quantize_per_tensor_reference_impl (
168+ const at::Tensor& input,
169+ double scale,
170+ int64_t zero_point,
171+ int64_t quant_min,
172+ int64_t quant_max,
173+ at::ScalarType dtype) {
174+ // Create output tensor with the target dtype
175+ at::Tensor out = at::empty_like (input, dtype);
176+
177+ // Quantize the input tensor
178+ float inv_scale = 1.0 / scale;
179+
180+ // Iterate through the tensor and quantize each element
181+ at::Tensor float_input = input.to (at::kFloat );
182+ at::Tensor float_values = float_input.flatten ();
183+
184+ auto out_flat = out.flatten ();
185+
186+ for (int i = 0 ; i < float_values.numel (); i++) {
187+ float value = float_values[i].item <float >();
188+ int64_t qvalue = zero_point + std::nearbyint (inv_scale * value);
189+
190+ qvalue = std::max<int64_t >(qvalue, quant_min);
191+ qvalue = std::min<int64_t >(qvalue, quant_max);
192+
193+ if (dtype == at::kByte ) {
194+ out_flat[i] = static_cast <uint8_t >(qvalue);
195+ } else if (dtype == at::kChar ) {
196+ out_flat[i] = static_cast <int8_t >(qvalue);
197+ } else if (dtype == at::kShort ) {
198+ out_flat[i] = static_cast <int16_t >(qvalue);
199+ } else if (dtype == at::kInt ) {
200+ out_flat[i] = static_cast <int32_t >(qvalue);
201+ } else if (dtype == at::kLong ) {
202+ out_flat[i] = static_cast <int64_t >(qvalue);
203+ }
204+ }
205+
206+ return out.reshape (input.sizes ());
207+ }
208+
159209/*
160210 * Reference implementation of quantize_per_token
161211 */
@@ -218,6 +268,18 @@ at::Tensor quantize_per_token_reference_impl(
218268 return out;
219269}
220270
271+ // Forward declaration of implementation functions
272+ void test_vulkan_quantize_per_tensor_impl (
273+ const std::vector<int >& input_sizes,
274+ float scale,
275+ int zero_point,
276+ int64_t quant_min,
277+ int64_t quant_max,
278+ at::ScalarType in_dtype,
279+ at::ScalarType dtype,
280+ const vkcompute::utils::StorageType in_storage,
281+ const vkcompute::utils::StorageType out_storage);
282+
221283void test_vulkan_quantize_per_token_impl (
222284 const std::vector<int >& input_sizes,
223285 const std::vector<float >& scales,
@@ -229,6 +291,40 @@ void test_vulkan_quantize_per_token_impl(
229291 const vkcompute::utils::StorageType in_storage,
230292 const vkcompute::utils::StorageType out_storage);
231293
294+ // Wrapper function to test both buffer and texture storage types
295+ void test_vulkan_quantize_per_tensor (
296+ const std::vector<int >& input_sizes,
297+ float scale,
298+ int zero_point,
299+ int64_t quant_min,
300+ int64_t quant_max,
301+ at::ScalarType in_dtype = at::kFloat ,
302+ at::ScalarType dtype = at::kInt ) {
303+ // Test with buffer storage
304+ test_vulkan_quantize_per_tensor_impl (
305+ input_sizes,
306+ scale,
307+ zero_point,
308+ quant_min,
309+ quant_max,
310+ in_dtype,
311+ dtype,
312+ vkcompute::utils::kBuffer ,
313+ vkcompute::utils::kBuffer );
314+
315+ // Test with texture storage
316+ test_vulkan_quantize_per_tensor_impl (
317+ input_sizes,
318+ scale,
319+ zero_point,
320+ quant_min,
321+ quant_max,
322+ in_dtype,
323+ dtype,
324+ vkcompute::utils::kTexture3D ,
325+ vkcompute::utils::kTexture3D );
326+ }
327+
232328// Wrapper function to test both buffer and texture storage types
233329void test_vulkan_quantize_per_token (
234330 const std::vector<int >& input_sizes,
@@ -263,6 +359,211 @@ void test_vulkan_quantize_per_token(
263359 vkcompute::utils::kTexture3D );
264360}
265361
362+ void test_reference_quantize_per_tensor (
363+ const std::vector<int >& input_sizes,
364+ float scale,
365+ int zero_point,
366+ int64_t quant_min,
367+ int64_t quant_max,
368+ at::ScalarType in_dtype = at::kFloat ,
369+ at::ScalarType dtype = at::kInt ) {
370+ check_quantize_args (quant_min, quant_max, dtype);
371+ std::vector<int64_t > input_sizes_int64 (
372+ input_sizes.begin (), input_sizes.end ());
373+ at::Tensor input =
374+ at::zeros (input_sizes_int64, at::device (at::kCPU ).dtype (in_dtype));
375+
376+ // Fill with a simple pattern: values from 0 to 1 in steps
377+ float step = 1 .0f / (input.numel () - 1 );
378+ auto flat_input = input.flatten ();
379+ for (int i = 0 ; i < flat_input.numel (); i++) {
380+ flat_input[i] = i * step;
381+ }
382+
383+ // Reshape back to original dimensions
384+ input = flat_input.reshape (input_sizes_int64);
385+
386+ // Get reference output
387+ at::Tensor reference_out = quantize_per_tensor_reference_impl (
388+ input, scale, zero_point, quant_min, quant_max, dtype);
389+
390+ // Get implementation output
391+ at::Tensor impl_out = torch::executor::native::quantize_per_tensor_aten (
392+ input, scale, zero_point, quant_min, quant_max, dtype);
393+
394+ // Convert to int for consistent display regardless of underlying type
395+ at::Tensor reference_int = reference_out.to (at::kInt );
396+ at::Tensor impl_int = impl_out.to (at::kInt );
397+
398+ const bool output_correct = at::equal (reference_int, impl_int);
399+ if (!output_correct) {
400+ at::Tensor diffs = at::abs (reference_int - impl_int);
401+
402+ std::cout << " \n "
403+ << " Failed with parameters: " << std::endl;
404+ std::cout << " scale: " << scale << std::endl;
405+ std::cout << " zero_point: " << zero_point << std::endl;
406+ std::cout << " quant_min: " << quant_min << std::endl;
407+ std::cout << " quant_max: " << quant_max << std::endl;
408+
409+ std::cout << " input:" << std::endl;
410+ std::cout << input << std::endl;
411+ std::cout << " reference:" << std::endl;
412+ std::cout << reference_int << std::endl;
413+ std::cout << " my_reference:" << std::endl;
414+ std::cout << impl_int << std::endl;
415+ }
416+
417+ ASSERT_TRUE (output_correct);
418+ }
419+
420+ void test_vulkan_quantize_per_tensor_impl (
421+ const std::vector<int >& input_sizes,
422+ float scale,
423+ int zero_point,
424+ int64_t quant_min,
425+ int64_t quant_max,
426+ at::ScalarType in_dtype = at::kFloat ,
427+ at::ScalarType dtype = at::kInt ,
428+ const vkcompute::utils::StorageType in_storage =
429+ vkcompute::utils::kTexture3D ,
430+ const vkcompute::utils::StorageType out_storage =
431+ vkcompute::utils::kTexture3D ) {
432+ check_quantize_args (quant_min, quant_max, dtype);
433+ std::vector<int64_t > input_sizes_int64 (
434+ input_sizes.begin (), input_sizes.end ());
435+ at::Tensor input =
436+ at::rand (input_sizes_int64, at::device (at::kCPU ).dtype (in_dtype));
437+
438+ // Get reference output
439+ at::Tensor reference_out = torch::executor::native::quantize_per_tensor_aten (
440+ input, scale, zero_point, quant_min, quant_max, dtype);
441+
442+ // Build Vulkan quantize_per_tensor graph
443+ using namespace vkcompute ;
444+
445+ GraphConfig config;
446+ config.set_storage_type_override (in_storage);
447+ ComputeGraph graph (config);
448+
449+ IOValueRef r_input = graph.add_input_tensor (
450+ input.sizes ().vec (), from_at_scalartype (input.scalar_type ()), in_storage);
451+
452+ const ValueRef r_scale = graph.add_scalar <double >(scale);
453+ const ValueRef r_zero_point = graph.add_scalar <int64_t >(zero_point);
454+ const ValueRef r_quant_min = graph.add_scalar <int64_t >(quant_min);
455+ const ValueRef r_quant_max = graph.add_scalar <int64_t >(quant_max);
456+
457+ const ValueRef r_out = graph.add_tensor (
458+ input.sizes ().vec (), from_at_scalartype (dtype), out_storage);
459+
460+ VK_GET_OP_FN (" quantize_per_tensor.default" )
461+ (graph,
462+ {
463+ r_input.value ,
464+ r_scale,
465+ r_zero_point,
466+ r_quant_min,
467+ r_quant_max,
468+ r_out,
469+ });
470+
471+ ValueRef staging_out = graph.set_output_tensor (r_out);
472+
473+ graph.prepare ();
474+ graph.encode_prepack ();
475+ graph.prepack ();
476+ graph.encode_execute ();
477+
478+ // Run Vulkan quantize_per_tensor
479+ graph.copy_into_staging (
480+ r_input.staging , input.const_data_ptr (), input.numel ());
481+
482+ graph.execute ();
483+
484+ at::Tensor vk_out = at::empty_like (reference_out).contiguous ();
485+ graph.copy_from_staging (
486+ staging_out, vk_out.mutable_data_ptr (), vk_out.numel ());
487+
488+ // Compare outputs
489+ // For quantized types, we need to compare the actual integer values
490+ at::Tensor reference_int = reference_out.to (at::kInt );
491+ at::Tensor vk_int = vk_out.to (at::kInt );
492+
493+ const bool output_correct = at::equal (reference_int, vk_int);
494+ if (!output_correct) {
495+ at::Tensor diffs = at::abs (reference_int - vk_int);
496+
497+ std::cout << " \n "
498+ << " Failed with parameters: " << std::endl;
499+ std::cout << " scale: " << scale << std::endl;
500+ std::cout << " zero_point: " << zero_point << std::endl;
501+ std::cout << " quant_min: " << quant_min << std::endl;
502+ std::cout << " quant_max: " << quant_max << std::endl;
503+
504+ std::cout << " input:" << std::endl;
505+ std::cout << input << std::endl;
506+ std::cout << " reference:" << std::endl;
507+ std::cout << reference_int << std::endl;
508+ std::cout << " vulkan:" << std::endl;
509+ std::cout << vk_int << std::endl;
510+ }
511+
512+ ASSERT_TRUE (output_correct);
513+ }
514+
515+ TEST (
516+ VulkanQuantizePerTensorTest,
517+ test_reference_quantize_per_tensor_float_to_int8) {
518+ test_reference_quantize_per_tensor (
519+ {2 , 3 , 4 }, // input sizes
520+ 0.1 , // scale
521+ 0 , // zero_point
522+ -128 , // quant_min
523+ 127 , // quant_max
524+ at::kFloat ,
525+ at::kChar );
526+ }
527+
528+ TEST (
529+ VulkanQuantizePerTensorTest,
530+ test_reference_quantize_per_tensor_float_to_int32) {
531+ test_reference_quantize_per_tensor (
532+ {2 , 3 , 4 }, // input sizes
533+ 0.04 , // scale
534+ 5 , // zero_point
535+ std::numeric_limits<int32_t >::min (), // quant_min
536+ std::numeric_limits<int32_t >::max (), // quant_max
537+ at::kFloat ,
538+ at::kInt );
539+ }
540+
541+ TEST (
542+ VulkanQuantizePerTensorTest,
543+ test_reference_quantize_per_tensor_half_to_uint8) {
544+ test_reference_quantize_per_tensor (
545+ {2 , 3 , 4 }, // input sizes
546+ 0.2 , // scale
547+ 2 , // zero_point
548+ 0 , // quant_min
549+ 255 , // quant_max
550+ at::kHalf ,
551+ at::kByte );
552+ }
553+
554+ TEST (
555+ VulkanQuantizePerTensorTest,
556+ test_reference_quantize_per_tensor_half_to_int32) {
557+ test_reference_quantize_per_tensor (
558+ {2 , 3 , 4 }, // input sizes
559+ 0.01 , // scale
560+ 1 , // zero_point
561+ std::numeric_limits<int32_t >::min (), // quant_min
562+ std::numeric_limits<int32_t >::max (), // quant_max
563+ at::kHalf ,
564+ at::kInt );
565+ }
566+
266567void test_reference_quantize_per_token (
267568 const std::vector<int >& input_sizes,
268569 const std::vector<float >& scales,
0 commit comments