Skip to content

Commit eadfcdd

Browse files
author
morelos
committed
[ET-VK][Ops] dequantize_per_tensor.default test setup
Creates a dequantize_per_tensor testing framework along with a reference implementation for testing.

Differential Revision: [D76267054](https://our.internmc.facebook.com/intern/diff/D76267054/)

[ghstack-poisoned]
1 parent 57b311a commit eadfcdd

File tree

1 file changed

+325
-0
lines changed

1 file changed

+325
-0
lines changed

backends/vulkan/test/op_tests/dequantize_test.cpp

Lines changed: 325 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,3 +294,328 @@ void check_dequantize_args(
294294
")");
295295
}
296296
}
297+
298+
//
299+
// Reference Implementation
300+
//
301+
302+
/*
303+
* Reference implementation of dequantize_per_tensor
304+
*/
305+
at::Tensor dequantize_per_tensor_reference_impl(
306+
const at::Tensor& input,
307+
double scale,
308+
int64_t zero_point,
309+
int64_t quant_min,
310+
int64_t quant_max,
311+
at::ScalarType dtype,
312+
at::ScalarType out_dtype) {
313+
// Create output tensor with the target dtype
314+
at::Tensor out = at::empty_like(input, out_dtype);
315+
316+
// Dequantize the input tensor
317+
at::Tensor int_input = input.to(at::kInt);
318+
at::Tensor flat_input = int_input.flatten();
319+
at::Tensor flat_out = out.flatten();
320+
321+
for (int i = 0; i < flat_input.numel(); i++) {
322+
int64_t qvalue = flat_input[i].item<int64_t>();
323+
float value = static_cast<float>((qvalue - zero_point) * scale);
324+
325+
if (out_dtype == at::kFloat) {
326+
flat_out[i] = value;
327+
} else if (out_dtype == at::kDouble) {
328+
flat_out[i] = static_cast<double>(value);
329+
}
330+
}
331+
332+
return out.reshape(input.sizes());
333+
}
334+
335+
// Forward declaration of implementation functions.
// Runs dequantize_per_tensor through the Vulkan compute graph with the given
// input/output storage types and compares the result against the ATen
// reference implementation (defined below the wrapper that calls it).
void test_vulkan_dequantize_per_tensor_impl(
    const std::vector<int>& input_sizes,
    float scale,
    int zero_point,
    int64_t quant_min,
    int64_t quant_max,
    at::ScalarType dtype,
    at::ScalarType out_dtype,
    const vkcompute::utils::StorageType in_storage,
    const vkcompute::utils::StorageType out_storage);
346+
347+
// Wrapper function to test both buffer and texture storage types
348+
void test_vulkan_dequantize_per_tensor(
349+
const std::vector<int>& input_sizes,
350+
float scale,
351+
int zero_point,
352+
int64_t quant_min,
353+
int64_t quant_max,
354+
at::ScalarType dtype,
355+
at::ScalarType out_dtype) {
356+
// Test with buffer storage
357+
test_vulkan_dequantize_per_tensor_impl(
358+
input_sizes,
359+
scale,
360+
zero_point,
361+
quant_min,
362+
quant_max,
363+
dtype,
364+
out_dtype,
365+
vkcompute::utils::kBuffer,
366+
vkcompute::utils::kBuffer);
367+
368+
// Test with texture storage
369+
test_vulkan_dequantize_per_tensor_impl(
370+
input_sizes,
371+
scale,
372+
zero_point,
373+
quant_min,
374+
quant_max,
375+
dtype,
376+
out_dtype,
377+
vkcompute::utils::kTexture3D,
378+
vkcompute::utils::kTexture3D);
379+
}
380+
381+
void test_reference_dequantize_per_tensor(
382+
const std::vector<int>& input_sizes,
383+
float scale,
384+
int zero_point,
385+
int64_t quant_min,
386+
int64_t quant_max,
387+
at::ScalarType dtype,
388+
at::ScalarType out_dtype) {
389+
check_dequantize_args(quant_min, quant_max, dtype, out_dtype);
390+
std::vector<int64_t> input_sizes_int64(
391+
input_sizes.begin(), input_sizes.end());
392+
393+
// Create a quantized input tensor with values from quant_min to quant_max
394+
at::Tensor input;
395+
if (dtype == at::kByte) {
396+
input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte));
397+
} else if (dtype == at::kChar) {
398+
input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar));
399+
} else if (dtype == at::kShort) {
400+
input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort));
401+
} else if (dtype == at::kInt) {
402+
input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt));
403+
} else {
404+
input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong));
405+
}
406+
407+
// Fill with a simple pattern: values from quant_min to quant_max in steps
408+
float step = 1.0f;
409+
if (input.numel() > 1) {
410+
step = static_cast<float>(quant_max - quant_min) / (input.numel() - 1);
411+
}
412+
413+
auto flat_input = input.flatten();
414+
for (int i = 0; i < flat_input.numel(); i++) {
415+
int64_t qvalue = quant_min + i * step;
416+
if (dtype == at::kByte) {
417+
flat_input[i] = static_cast<uint8_t>(qvalue);
418+
} else if (dtype == at::kChar) {
419+
flat_input[i] = static_cast<int8_t>(qvalue);
420+
} else if (dtype == at::kShort) {
421+
flat_input[i] = static_cast<int16_t>(qvalue);
422+
} else if (dtype == at::kInt) {
423+
flat_input[i] = static_cast<int32_t>(qvalue);
424+
} else if (dtype == at::kLong) {
425+
flat_input[i] = static_cast<int64_t>(qvalue);
426+
}
427+
}
428+
429+
// Reshape back to original dimensions
430+
input = flat_input.reshape(input_sizes_int64);
431+
432+
// Get reference output
433+
at::Tensor reference_out = dequantize_per_tensor_reference_impl(
434+
input, scale, zero_point, quant_min, quant_max, dtype, out_dtype);
435+
436+
// Get implementation output
437+
at::Tensor impl_out = torch::executor::native::dequantize_per_tensor_aten(
438+
input, scale, zero_point, quant_min, quant_max, dtype, out_dtype);
439+
440+
// Compare outputs
441+
const bool output_correct = at::allclose(reference_out, impl_out, 1e-5, 1e-5);
442+
if (!output_correct) {
443+
std::cout << "\n"
444+
<< "Failed with parameters: " << std::endl;
445+
std::cout << " scale: " << scale << std::endl;
446+
std::cout << " zero_point: " << zero_point << std::endl;
447+
std::cout << " quant_min: " << quant_min << std::endl;
448+
std::cout << " quant_max: " << quant_max << std::endl;
449+
450+
std::cout << "input:" << std::endl;
451+
std::cout << input << std::endl;
452+
std::cout << "reference:" << std::endl;
453+
std::cout << reference_out << std::endl;
454+
std::cout << "implementation:" << std::endl;
455+
std::cout << impl_out << std::endl;
456+
}
457+
458+
ASSERT_TRUE(output_correct);
459+
}
460+
461+
// Runs dequantize_per_tensor through the Vulkan compute graph with the given
// storage types and compares against the ATen operator output.
// NOTE(review): the graph prepare/prepack/encode/execute sequence below is
// order-sensitive; keep the call order intact when modifying.
void test_vulkan_dequantize_per_tensor_impl(
    const std::vector<int>& input_sizes,
    float scale,
    int zero_point,
    int64_t quant_min,
    int64_t quant_max,
    at::ScalarType dtype,
    at::ScalarType out_dtype,
    const vkcompute::utils::StorageType in_storage,
    const vkcompute::utils::StorageType out_storage) {
  check_dequantize_args(quant_min, quant_max, dtype, out_dtype);
  std::vector<int64_t> input_sizes_int64(
      input_sizes.begin(), input_sizes.end());

  // Create a quantized input tensor with values from quant_min to quant_max
  // (dtypes other than Byte/Char/Short/Int fall through to kLong).
  at::Tensor input;
  if (dtype == at::kByte) {
    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte));
  } else if (dtype == at::kChar) {
    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar));
  } else if (dtype == at::kShort) {
    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort));
  } else if (dtype == at::kInt) {
    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt));
  } else {
    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong));
  }

  // Fill with a simple pattern: values from quant_min to quant_max in steps
  float step = 1.0f;
  if (input.numel() > 1) {
    step = static_cast<float>(quant_max - quant_min) / (input.numel() - 1);
  }

  auto flat_input = input.flatten();
  for (int i = 0; i < flat_input.numel(); i++) {
    // quant_min + i * step is evaluated in float and truncated to int64.
    int64_t qvalue = quant_min + i * step;
    if (dtype == at::kByte) {
      flat_input[i] = static_cast<uint8_t>(qvalue);
    } else if (dtype == at::kChar) {
      flat_input[i] = static_cast<int8_t>(qvalue);
    } else if (dtype == at::kShort) {
      flat_input[i] = static_cast<int16_t>(qvalue);
    } else if (dtype == at::kInt) {
      flat_input[i] = static_cast<int32_t>(qvalue);
    } else if (dtype == at::kLong) {
      flat_input[i] = static_cast<int64_t>(qvalue);
    }
  }

  // Reshape back to original dimensions
  input = flat_input.reshape(input_sizes_int64);

  // Get reference output from the ATen operator (ground truth for this test).
  at::Tensor reference_out = torch::executor::native::dequantize_per_tensor_aten(
      input, scale, zero_point, quant_min, quant_max, dtype, out_dtype);

  // Build Vulkan dequantize_per_tensor graph
  using namespace vkcompute;

  GraphConfig config;
  config.set_storage_type_override(in_storage);
  ComputeGraph graph(config);

  IOValueRef r_input = graph.add_input_tensor(
      input.sizes().vec(), from_at_scalartype(dtype), in_storage);

  // Scalar operands of the op, registered as graph values.
  const ValueRef r_scale = graph.add_scalar<double>(scale);
  const ValueRef r_zero_point = graph.add_scalar<int64_t>(zero_point);
  const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min);
  const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max);

  const ValueRef r_out = graph.add_tensor(
      input.sizes().vec(), from_at_scalartype(out_dtype), out_storage);

  // Look up the registered op and record it into the graph. Argument order
  // must match the operator schema: input, scale, zero_point, quant_min,
  // quant_max, output.
  VK_GET_OP_FN("dequantize_per_tensor.default")
  (graph,
   {
       r_input.value,
       r_scale,
       r_zero_point,
       r_quant_min,
       r_quant_max,
       r_out,
   });

  ValueRef staging_out = graph.set_output_tensor(r_out);

  // Finalize the graph: allocate resources, then record prepack and execute
  // command buffers (prepack runs once before execution).
  graph.prepare();
  graph.encode_prepack();
  graph.prepack();
  graph.encode_execute();

  // Run Vulkan dequantize_per_tensor
  graph.copy_into_staging(
      r_input.staging, input.const_data_ptr(), input.numel());

  graph.execute();

  // Read the result back into a contiguous host tensor shaped like the
  // reference output.
  at::Tensor vk_out = at::empty_like(reference_out).contiguous();
  graph.copy_from_staging(
      staging_out, vk_out.mutable_data_ptr(), vk_out.numel());

  // Compare outputs; dump the failing configuration before asserting.
  const bool output_correct = at::allclose(reference_out, vk_out, 1e-5, 1e-5);
  if (!output_correct) {
    std::cout << "\n"
              << "Failed with parameters: " << std::endl;
    std::cout << "  scale: " << scale << std::endl;
    std::cout << "  zero_point: " << zero_point << std::endl;
    std::cout << "  quant_min: " << quant_min << std::endl;
    std::cout << "  quant_max: " << quant_max << std::endl;
    std::cout << "  storage type: "
              << (in_storage == vkcompute::utils::kBuffer ? "buffer"
                                                          : "texture")
              << std::endl;

    std::cout << "input:" << std::endl;
    std::cout << input << std::endl;
    std::cout << "reference:" << std::endl;
    std::cout << reference_out << std::endl;
    std::cout << "vulkan:" << std::endl;
    std::cout << vk_out << std::endl;
  }

  ASSERT_TRUE(output_correct);
}
588+
589+
// Test cases for dequantize_per_tensor
// Full uint8 range [0, 255] dequantized to float with a nonzero zero_point.
TEST(VulkanDequantizePerTensorTest, test_reference_dequantize_per_tensor_uint8_to_float) {
  test_reference_dequantize_per_tensor(
      {2, 3, 4}, // input sizes
      0.1, // scale
      5, // zero_point
      0, // quant_min
      255, // quant_max
      at::kByte, // input dtype
      at::kFloat); // output dtype
}
600+
601+
// Full int8 range [-128, 127] dequantized to float with a zero zero_point.
TEST(VulkanDequantizePerTensorTest, test_reference_dequantize_per_tensor_int8_to_float) {
  test_reference_dequantize_per_tensor(
      {3, 4, 5}, // input sizes
      0.05, // scale
      0, // zero_point
      -128, // quant_min
      127, // quant_max
      at::kChar, // input dtype
      at::kFloat); // output dtype
}
611+
612+
// Full int16 range [-32768, 32767] dequantized to float with a small scale
// and negative zero_point.
TEST(VulkanDequantizePerTensorTest, test_reference_dequantize_per_tensor_int16_to_float) {
  test_reference_dequantize_per_tensor(
      {2, 2, 3}, // input sizes
      0.001, // scale
      -10, // zero_point
      -32768, // quant_min
      32767, // quant_max
      at::kShort, // input dtype
      at::kFloat); // output dtype
}

0 commit comments

Comments
 (0)