@@ -156,3 +156,387 @@ void check_quantize_args(
156156 " actual quant_max: " ,
157157 quant_max);
158158}
159+ /*
160+ * Reference implementation of quantize_per_token
161+ */
162+ at::Tensor quantize_per_token_reference_impl (
163+ const at::Tensor& input,
164+ const at::Tensor& scale,
165+ const at::Tensor& zero_point,
166+ int64_t quant_min,
167+ int64_t quant_max,
168+ at::ScalarType dtype) {
169+ // Create output tensor with the target dtype
170+ at::Tensor out = at::empty_like (input, dtype);
171+
172+ // Calculate number of tokens
173+ int num_tokens = 1 ;
174+ for (int i = 0 ; i < input.dim () - 1 ; i++) {
175+ num_tokens *= input.size (i);
176+ }
177+
178+ // Verify that the number of tokens matches the size of scale and zero_point
179+ // tensors
180+ assert (num_tokens == scale.numel ());
181+ assert (num_tokens == zero_point.numel ());
182+
183+ // Reshape input to [num_tokens, last_dim]
184+ at::Tensor reshaped_input = input.reshape ({num_tokens, input.size (-1 )});
185+ at::Tensor reshaped_out = out.reshape ({num_tokens, input.size (-1 )});
186+
187+ // Quantize each token separately
188+ for (int token_idx = 0 ; token_idx < num_tokens; token_idx++) {
189+ // Use float for scale since Vulkan doesn't support double
190+ float token_scale = scale[token_idx].item <float >();
191+ // Use int for zero_point since Vulkan doesn't support int64_t
192+ int token_zero_point = zero_point[token_idx].item <int >();
193+
194+ float inv_scale = 1.0 / token_scale;
195+
196+ // Quantize the token
197+ for (int i = 0 ; i < input.size (-1 ); i++) {
198+ float value = reshaped_input[token_idx][i].item <float >();
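+      // Affine quantization: q = clamp(round(value / scale) + zero_point, quant_min, quant_max)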
+      int qvalue = token_zero_point + std::nearbyint(inv_scale * value);
+
+      qvalue = std::max<int64_t>(qvalue, quant_min);
+      qvalue = std::min<int64_t>(qvalue, quant_max);
+
+      if (dtype == at::kByte) {
+        reshaped_out[token_idx][i] = static_cast<uint8_t>(qvalue);
+      } else if (dtype == at::kChar) {
+        reshaped_out[token_idx][i] = static_cast<int8_t>(qvalue);
+      } else if (dtype == at::kShort) {
+        reshaped_out[token_idx][i] = static_cast<int16_t>(qvalue);
+      } else if (dtype == at::kInt) {
+        reshaped_out[token_idx][i] = static_cast<int32_t>(qvalue);
+      } else if (dtype == at::kLong) {
+        reshaped_out[token_idx][i] = static_cast<int64_t>(qvalue);
+      }
+    }
+  }
+
+  return out;
+}
+
+void test_vulkan_quantize_per_token_impl(
+    const std::vector<int>& input_sizes,
+    const std::vector<float>& scales,
+    const std::vector<int>& zero_points,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType in_dtype,
+    at::ScalarType dtype,
+    const vkcompute::utils::StorageType in_storage,
+    const vkcompute::utils::StorageType out_storage);
+
+// Wrapper function to test both buffer and texture storage types
+void test_vulkan_quantize_per_token(
+    const std::vector<int>& input_sizes,
+    const std::vector<float>& scales,
+    const std::vector<int>& zero_points,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType in_dtype = at::kFloat,
+    at::ScalarType dtype = at::kInt) {
+  // Test with buffer storage
+  test_vulkan_quantize_per_token_impl(
+      input_sizes,
+      scales,
+      zero_points,
+      quant_min,
+      quant_max,
+      in_dtype,
+      dtype,
+      vkcompute::utils::kBuffer,
+      vkcompute::utils::kBuffer);
+
+  // Test with texture storage
+  test_vulkan_quantize_per_token_impl(
+      input_sizes,
+      scales,
+      zero_points,
+      quant_min,
+      quant_max,
+      in_dtype,
+      dtype,
+      vkcompute::utils::kTexture3D,
+      vkcompute::utils::kTexture3D);
+}
+
+void test_reference_quantize_per_token(
+    const std::vector<int>& input_sizes,
+    const std::vector<float>& scales,
+    const std::vector<int>& zero_points,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType in_dtype = at::kFloat,
+    at::ScalarType dtype = at::kInt) {
+  check_quantize_args(quant_min, quant_max, dtype);
+  std::vector<int64_t> input_sizes_int64(
+      input_sizes.begin(), input_sizes.end());
+  at::Tensor input =
+      at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype));
+
+  // Fill with a simple pattern: values from 0 to 1 in steps
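+  // e.g. a {2, 3, 4} input has 24 elements, so the values are 0, 1/23, 2/23, ..., 1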
+  float step = 1.0 / (input.numel() - 1);
+  auto flat_input = input.flatten();
+  for (int i = 0; i < flat_input.numel(); i++) {
+    flat_input[i] = i * step;
+  }
+
+  // Reshape back to original dimensions
+  input = flat_input.reshape(input_sizes_int64);
+
+  // Calculate number of tokens
+  int num_tokens = 1;
+  for (int i = 0; i < input.dim() - 1; i++) {
+    num_tokens *= input.size(i);
+  }
+
+  // Verify that the number of tokens matches the size of scales and zero_points
+  ASSERT_EQ(num_tokens, scales.size());
+  ASSERT_EQ(num_tokens, zero_points.size());
+
+  // Create scale and zero_point tensors
+  at::Tensor scale_tensor =
+      at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble));
+  at::Tensor zero_point_tensor =
+      at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong));
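+  // Scales are double and zero_points are int64 here; the reference op
+  // consumes these tensors directly without conversion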
+
+  // Get reference output
+  at::Tensor reference_out = quantize_per_token_reference_impl(
+      input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype);
+
+  // Get implementation output
+  at::Tensor impl_out = torch::executor::native::quantize_per_token_aten(
+      input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype);
+
+  // Convert to int for consistent display regardless of underlying type
+  at::Tensor reference_int = reference_out.to(at::kInt);
+  at::Tensor impl_int = impl_out.to(at::kInt);
+
+  const bool output_correct = at::equal(reference_int, impl_int);
+  if (!output_correct) {
+    std::cout << "\n"
+              << " Failed with parameters: " << std::endl;
+    std::cout << " scale(s):";
+    for (size_t i = 0; i < scales.size(); i++) {
+      std::cout << " " << scales[i] << " ";
+    }
+    std::cout << " " << std::endl;
+    std::cout << " zero_point(s):";
+    for (size_t i = 0; i < zero_points.size(); i++) {
+      std::cout << " " << zero_points[i] << " ";
+    }
+    std::cout << " " << std::endl;
+    std::cout << " quant_min: " << quant_min << std::endl;
+    std::cout << " quant_max: " << quant_max << std::endl;
+
+    std::cout << " input:" << std::endl;
+    std::cout << input << std::endl;
+    std::cout << " reference:" << std::endl;
+    std::cout << reference_int << std::endl;
+    std::cout << " my_reference:" << std::endl;
+    std::cout << impl_out << std::endl;
+  }
+
+  ASSERT_TRUE(output_correct);
+}
+
+void test_vulkan_quantize_per_token_impl(
+    const std::vector<int>& input_sizes,
+    const std::vector<float>& scales,
+    const std::vector<int>& zero_points,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType in_dtype = at::kFloat,
+    at::ScalarType dtype = at::kInt,
+    const vkcompute::utils::StorageType in_storage =
+        vkcompute::utils::kTexture3D,
+    const vkcompute::utils::StorageType out_storage =
+        vkcompute::utils::kTexture3D) {
+  check_quantize_args(quant_min, quant_max, dtype);
+  int num_tokens = 1;
+  for (int i = 0; i < input_sizes.size() - 1; i++) {
+    num_tokens *= input_sizes[i];
+  }
+
+  ASSERT_EQ(num_tokens, scales.size());
+  ASSERT_EQ(num_tokens, zero_points.size());
+
+  // Create input tensor with random values
+  std::vector<int64_t> input_sizes_int64(
+      input_sizes.begin(), input_sizes.end());
+  at::Tensor input =
+      at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype));
+  at::Tensor scale_tensor =
+      at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble));
+  at::Tensor zero_point_tensor =
+      at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong));
+
+  // Get reference output to show what we would compare against
+  at::Tensor reference_out = torch::executor::native::quantize_per_token_aten(
+      input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype);
+
+  using namespace vkcompute;
+
+  GraphConfig config;
+  config.set_storage_type_override(in_storage);
+  ComputeGraph graph(config);
+
+  IOValueRef r_input = graph.add_input_tensor(
+      input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage);
+  IOValueRef r_scale = graph.add_input_tensor(
+      scale_tensor.sizes().vec(), vkapi::kFloat, in_storage);
+  IOValueRef r_zero_point = graph.add_input_tensor(
+      zero_point_tensor.sizes().vec(), vkapi::kInt, in_storage);
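+  // The scale and zero_point inputs are declared as float/int32 tensors because
+  // Vulkan has no fp64/int64 support; the CPU tensors are converted before being
+  // copied into staging below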
+
+  const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min);
+  const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max);
+
+  const ValueRef r_out = graph.add_tensor(
+      input.sizes().vec(), from_at_scalartype(dtype), out_storage);
+
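+  // Look up the registered Vulkan op and let it add its compute nodes to the graph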
+  VK_GET_OP_FN("quantize_per_token.default")
+  (graph,
+   {
+       r_input.value,
+       r_scale.value,
+       r_zero_point.value,
+       r_quant_min,
+       r_quant_max,
+       r_out,
+   });
+
+  ValueRef staging_out = graph.set_output_tensor(r_out);
+
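+  // Compile the graph: allocate resources, run prepacking, and record the
+  // execute command buffer before copying in the inputs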
+  graph.prepare();
+  graph.encode_prepack();
+  graph.prepack();
+  graph.encode_execute();
+
+  // Copy input data to GPU
+  graph.copy_into_staging(
+      r_input.staging, input.const_data_ptr(), input.numel());
+
+  // Convert scale tensor to float and copy to GPU
+  at::Tensor scale_float = scale_tensor.to(at::kFloat);
+  graph.copy_into_staging(
+      r_scale.staging, scale_float.const_data_ptr(), scale_float.numel());
+
+  // Convert zero_point tensor to int and copy to GPU
+  at::Tensor zero_point_int = zero_point_tensor.to(at::kInt);
+  graph.copy_into_staging(
+      r_zero_point.staging,
+      zero_point_int.const_data_ptr(),
+      zero_point_int.numel());
+
+  // Execute the graph
+  graph.execute();
+
+  // Copy output data back to CPU
+  at::Tensor vk_out = at::empty_like(reference_out).contiguous();
+  graph.copy_from_staging(
+      staging_out, vk_out.mutable_data_ptr(), vk_out.numel());
+
+  // Compare outputs
+  at::Tensor reference_int = reference_out.to(at::kInt);
+  at::Tensor vk_int = vk_out.to(at::kInt);
+
+  const bool output_correct = at::equal(reference_int, vk_int);
+  if (!output_correct) {
+    at::Tensor diffs = at::abs(reference_int - vk_int);
+
+    std::cout << "\n"
+              << " Failed with parameters: " << std::endl;
+    std::cout << " scale(s):";
+    for (size_t i = 0; i < scales.size(); i++) {
+      std::cout << " " << scales[i] << " ";
+    }
+    std::cout << " " << std::endl;
+    std::cout << " zero_point(s):";
+    for (size_t i = 0; i < zero_points.size(); i++) {
+      std::cout << " " << zero_points[i] << " ";
+    }
+    std::cout << " " << std::endl;
+    std::cout << " quant_min: " << quant_min << std::endl;
+    std::cout << " quant_max: " << quant_max << std::endl;
+    std::cout << " storage type: "
+              << (in_storage == vkcompute::utils::kBuffer ? " buffer"
+                                                          : " texture")
+              << std::endl;
+
+    std::cout << " input:" << std::endl;
+    std::cout << input << std::endl;
+    std::cout << " reference:" << std::endl;
+    std::cout << reference_int << std::endl;
+    std::cout << " vulkan:" << std::endl;
+    std::cout << vk_int << std::endl;
+  }
+
+  ASSERT_TRUE(output_correct);
+}
+
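+// Test cases covering the reference implementation path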
+TEST(
+    VulkanQuantizePerTensorTest,
+    test_reference_quantize_per_token_float_to_int8) {
+  std::vector<float> scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3};
+  std::vector<int> zero_points = {1, 2, 3, 0, -1, -2};
+
+  test_reference_quantize_per_token(
+      {2, 3, 4}, // input sizes (2*3=6 tokens)
+      scales,
+      zero_points,
+      -128, // quant_min
+      127, // quant_max
+      at::kFloat,
+      at::kChar);
+}
+
+TEST(
+    VulkanQuantizePerTensorTest,
+    test_reference_quantize_per_token_float_to_int32) {
+  std::vector<float> scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3};
+  std::vector<int> zero_points = {1, 2, 3, 0, -1, -2};
+
+  test_reference_quantize_per_token(
+      {2, 3, 4}, // input sizes (2*3=6 tokens)
+      scales,
+      zero_points,
+      std::numeric_limits<int32_t>::min(), // quant_min
+      std::numeric_limits<int32_t>::max(), // quant_max
+      at::kFloat,
+      at::kInt);
+}
+
+TEST(
+    VulkanQuantizePerTensorTest,
+    test_reference_quantize_per_token_half_to_int32) {
+  std::vector<float> scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3};
+  std::vector<int> zero_points = {1, 2, 3, 0, -1, -2};
+
+  test_reference_quantize_per_token(
+      {2, 3, 4}, // input sizes (2*3=6 tokens)
+      scales,
+      zero_points,
+      std::numeric_limits<int32_t>::min(), // quant_min
+      std::numeric_limits<int32_t>::max(), // quant_max
+      at::kHalf,
+      at::kInt);
+}
+
+TEST(
+    VulkanQuantizePerTensorTest,
+    test_reference_quantize_per_token_half_to_uint8) {
+  std::vector<float> scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3};
+  std::vector<int> zero_points = {1, 2, 3, 0, -1, -2};
+
+  test_reference_quantize_per_token(
+      {2, 3, 4}, // input sizes (2*3=6 tokens)
+      scales,
+      zero_points,
+      0, // quant_min
+      255, // quant_max
+      at::kHalf,
+      at::kByte);
+}