@@ -232,3 +232,277 @@ vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
232232 " )" );
233233 }
234234}
235+
236+ //
237+ // Reference Implementation
238+ //
239+
240+ /*
241+ * Reference implementation of choose_qparams_tensor
242+ */
243+ std::tuple<at::Tensor, at::Tensor> choose_qparams_tensor_reference_impl (
244+ const at::Tensor& input,
245+ int64_t quant_min,
246+ int64_t quant_max) {
247+ // Create output tensors
248+ at::Tensor scale_out = at::empty ({}, at::device (at::kCPU ).dtype (at::kDouble ));
249+ at::Tensor zero_point_out =
250+ at::empty ({}, at::device (at::kCPU ).dtype (at::kLong ));
251+
252+ // Find min and max values in the input tensor
253+ float min_val = input.min ().item <float >();
254+ float max_val = input.max ().item <float >();
255+
256+ // Extend the [min, max] interval to ensure it contains 0
257+ min_val = std::min (min_val, 0 .f );
258+ max_val = std::max (max_val, 0 .f );
259+
260+ // Calculate scale
261+ double scale =
262+ (static_cast <double >(max_val) - min_val) / (quant_max - quant_min);
263+
264+ // Handle small scale
265+ constexpr float SMALL_SCALE_THRESHOLD = 6 .1e-5f ;
266+ if (float (scale) == 0 .0f || std::isinf (1 .0f / float (scale))) {
267+ scale = 0.1 ;
268+ }
269+
270+ if (scale < SMALL_SCALE_THRESHOLD) {
271+ float org_scale = scale;
272+ scale = SMALL_SCALE_THRESHOLD;
273+ // Adjust min and max based on new scale
274+ if (min_val == 0 .0f ) {
275+ max_val = SMALL_SCALE_THRESHOLD * (quant_max - quant_min);
276+ } else if (max_val == 0 .0f ) {
277+ min_val = -SMALL_SCALE_THRESHOLD * (quant_max - quant_min);
278+ } else {
279+ float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
280+ min_val *= amplifier;
281+ max_val *= amplifier;
282+ }
283+ }
284+
285+ // Calculate zero point
286+ double zero_point_from_min = quant_min - min_val / static_cast <double >(scale);
287+ double zero_point_from_max = quant_max - max_val / static_cast <double >(scale);
288+ double zero_point_from_min_error =
289+ std::abs (quant_min) - std::abs (min_val / static_cast <double >(scale));
290+ double zero_point_from_max_error =
291+ std::abs (quant_max) - std::abs (max_val / static_cast <double >(scale));
292+ double initial_zero_point =
293+ zero_point_from_min_error < zero_point_from_max_error
294+ ? zero_point_from_min
295+ : zero_point_from_max;
296+
297+ // Nudge zero point to be an integer
298+ int64_t nudged_zero_point = 0 ;
299+ if (initial_zero_point < quant_min) {
300+ nudged_zero_point = quant_min;
301+ } else if (initial_zero_point > quant_max) {
302+ nudged_zero_point = quant_max;
303+ } else {
304+ nudged_zero_point = std::nearbyint (static_cast <float >(initial_zero_point));
305+ }
306+
307+ // Set output values - use item_mutable() for scalar tensors
308+ scale_out.fill_ (scale);
309+ zero_point_out.fill_ (nudged_zero_point);
310+
311+ return std::make_tuple (scale_out, zero_point_out);
312+ }
313+
314+ // Forward declaration of implementation functions
315+ void test_vulkan_choose_qparams_tensor_impl (
316+ const std::vector<int >& input_sizes,
317+ int64_t quant_min,
318+ int64_t quant_max,
319+ at::ScalarType dtype,
320+ const vkcompute::utils::StorageType in_storage,
321+ const vkcompute::utils::StorageType out_storage);
322+
323+ // Wrapper function to test both buffer and texture storage types
324+ void test_vulkan_choose_qparams_tensor (
325+ const std::vector<int >& input_sizes,
326+ int64_t quant_min,
327+ int64_t quant_max,
328+ at::ScalarType dtype) {
329+ // Test with buffer storage
330+ test_vulkan_choose_qparams_tensor_impl (
331+ input_sizes,
332+ quant_min,
333+ quant_max,
334+ dtype,
335+ vkcompute::utils::kBuffer ,
336+ vkcompute::utils::kBuffer );
337+
338+ // Test with texture storage
339+ test_vulkan_choose_qparams_tensor_impl (
340+ input_sizes,
341+ quant_min,
342+ quant_max,
343+ dtype,
344+ vkcompute::utils::kTexture3D ,
345+ vkcompute::utils::kTexture3D );
346+ }
347+
348+ void test_reference_choose_qparams_tensor (
349+ const std::vector<int >& input_sizes,
350+ int64_t quant_min,
351+ int64_t quant_max,
352+ at::ScalarType dtype) {
353+ std::vector<int64_t > input_sizes_int64 (
354+ input_sizes.begin (), input_sizes.end ());
355+ at::Tensor input =
356+ at::rand (input_sizes_int64, at::device (at::kCPU ).dtype (at::kFloat ));
357+
358+ // Get reference output
359+ auto [reference_scale, reference_zero_point] =
360+ choose_qparams_tensor_reference_impl (input, quant_min, quant_max);
361+
362+ // Get implementation output
363+ auto [impl_scale, impl_zero_point] =
364+ torch::executor::native::choose_qparams_tensor_aten (
365+ input, quant_min, quant_max, dtype);
366+
367+ // Compare outputs
368+ const bool scale_correct = at::allclose (reference_scale, impl_scale);
369+ const bool zero_point_correct =
370+ at::equal (reference_zero_point, impl_zero_point);
371+
372+ if (!scale_correct || !zero_point_correct) {
373+ std::cout << " \n "
374+ << " Failed with parameters: " << std::endl;
375+ std::cout << " quant_min: " << quant_min << std::endl;
376+ std::cout << " quant_max: " << quant_max << std::endl;
377+
378+ std::cout << " input:" << std::endl;
379+ std::cout << input << std::endl;
380+ std::cout << " reference scale:" << std::endl;
381+ std::cout << reference_scale << std::endl;
382+ std::cout << " implementation scale:" << std::endl;
383+ std::cout << impl_scale << std::endl;
384+ std::cout << " reference zero_point:" << std::endl;
385+ std::cout << reference_zero_point << std::endl;
386+ std::cout << " implementation zero_point:" << std::endl;
387+ std::cout << impl_zero_point << std::endl;
388+ }
389+
390+ ASSERT_TRUE (scale_correct && zero_point_correct);
391+ }
392+
393+ void test_vulkan_choose_qparams_tensor_impl (
394+ const std::vector<int >& input_sizes,
395+ int64_t quant_min,
396+ int64_t quant_max,
397+ at::ScalarType dtype,
398+ const vkcompute::utils::StorageType in_storage,
399+ const vkcompute::utils::StorageType out_storage) {
400+ std::vector<int64_t > input_sizes_int64 (
401+ input_sizes.begin (), input_sizes.end ());
402+ at::Tensor input =
403+ at::rand (input_sizes_int64, at::device (at::kCPU ).dtype (at::kFloat ));
404+
405+ // Get reference output
406+ auto [reference_scale, reference_zero_point] =
407+ torch::executor::native::choose_qparams_tensor_aten (
408+ input, quant_min, quant_max, dtype);
409+
410+ // Build Vulkan choose_qparams_tensor graph
411+ using namespace vkcompute ;
412+
413+ GraphConfig config;
414+ config.set_storage_type_override (in_storage);
415+ ComputeGraph graph (config);
416+
417+ IOValueRef r_input = graph.add_input_tensor (
418+ input.sizes ().vec (), from_at_scalartype (input.scalar_type ()), in_storage);
419+
420+ const ValueRef r_quant_min = graph.add_scalar <int64_t >(quant_min);
421+ const ValueRef r_quant_max = graph.add_scalar <int64_t >(quant_max);
422+
423+ // Output tensors
424+ const ValueRef r_scale = graph.add_tensor ({}, vkapi::kFloat , out_storage);
425+ const ValueRef r_zero_point = graph.add_tensor ({}, vkapi::kInt , out_storage);
426+
427+ VK_GET_OP_FN (" choose_qparams.tensor" )
428+ (graph,
429+ {
430+ r_input.value ,
431+ r_quant_min,
432+ r_quant_max,
433+ r_scale,
434+ r_zero_point,
435+ });
436+
437+ ValueRef staging_scale = graph.set_output_tensor (r_scale);
438+ ValueRef staging_zero_point = graph.set_output_tensor (r_zero_point);
439+
440+ graph.prepare ();
441+ graph.encode_prepack ();
442+ graph.prepack ();
443+ graph.encode_execute ();
444+
445+ // Run Vulkan choose_qparams_tensor
446+ graph.copy_into_staging (
447+ r_input.staging , input.const_data_ptr (), input.numel ());
448+
449+ graph.execute ();
450+
451+ // Create output tensors to hold the results - use types that match GPU output
452+ at::Tensor vk_scale =
453+ at::empty ({}, at::device (at::kCPU ).dtype (at::kFloat )).contiguous ();
454+ at::Tensor vk_zero_point =
455+ at::empty ({}, at::device (at::kCPU ).dtype (at::kInt )).contiguous ();
456+
457+ // Copy results from GPU to CPU
458+ graph.copy_from_staging (
459+ staging_scale, vk_scale.mutable_data_ptr (), vk_scale.numel ());
460+ graph.copy_from_staging (
461+ staging_zero_point,
462+ vk_zero_point.mutable_data_ptr (),
463+ vk_zero_point.numel ());
464+
465+ // Convert reference values to match Vulkan output types for comparison
466+ at::Tensor reference_scale_float = reference_scale.to (at::kFloat );
467+ at::Tensor reference_zero_point_int = reference_zero_point.to (at::kInt );
468+
469+ // Compare outputs
470+ const bool scale_correct = at::allclose (reference_scale_float, vk_scale);
471+ const bool zero_point_correct =
472+ at::equal (reference_zero_point_int, vk_zero_point);
473+
474+ if (!scale_correct || !zero_point_correct) {
475+ std::cout << " \n "
476+ << " Failed with parameters: " << std::endl;
477+ std::cout << " quant_min: " << quant_min << std::endl;
478+ std::cout << " quant_max: " << quant_max << std::endl;
479+ std::cout << " storage type: "
480+ << (in_storage == vkcompute::utils::kBuffer ? " buffer"
481+ : " texture" )
482+ << std::endl;
483+
484+ // make sure that there arent a ton of elements in the input tensor
485+ if (input.numel () < 100 ) {
486+ std::cout << " input:" << std::endl;
487+ std::cout << input << " \n " << std::endl;
488+ std::cout << " reference scale:" << std::endl;
489+ std::cout << reference_scale << std::endl;
490+ std::cout << " vulkan scale:" << std::endl;
491+ std::cout << vk_scale << " \n " << std::endl;
492+ std::cout << " reference zero_point:" << std::endl;
493+ std::cout << reference_zero_point << std::endl;
494+ std::cout << " vulkan zero_point:" << std::endl;
495+ std::cout << vk_zero_point << std::endl;
496+ }
497+ }
498+
499+ ASSERT_TRUE (scale_correct && zero_point_correct);
500+ }
501+
502+ TEST (VulkanChooseQparamsTest, test_reference_choose_qparams_tensor_int8) {
503+ test_reference_choose_qparams_tensor (
504+ {2 , 3 , 4 }, // input sizes
505+ -128 , // quant_min
506+ 127 , // quant_max
507+ at::kChar );
508+ }
0 commit comments