@@ -114,3 +114,277 @@ std::tuple<at::Tensor, at::Tensor> choose_qparams_per_token_asymmetric_aten(
} // namespace native
} // namespace executor
} // namespace torch

//
// Reference Implementation
//

/*
 * Reference implementation of choose_qparams_tensor
 */
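// Worked example: with int8 bounds (quant_min = -128, quant_max = 127) and an
// input spanning [-1.0f, 1.0f], the scale computed below is 2.0 / 255
// ~ 0.00784, and the nudged zero point is 0 (std::nearbyint rounds the
// half-way value -0.5 to the nearest even integer).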
std::tuple<at::Tensor, at::Tensor> choose_qparams_tensor_reference_impl(
    const at::Tensor& input,
    int64_t quant_min,
    int64_t quant_max) {
  // Create output tensors
  at::Tensor scale_out =
      at::empty({}, at::device(at::kCPU).dtype(at::kDouble));
  at::Tensor zero_point_out =
      at::empty({}, at::device(at::kCPU).dtype(at::kLong));

  // Find min and max values in the input tensor
  float min_val = input.min().item<float>();
  float max_val = input.max().item<float>();

  // Extend the [min, max] interval to ensure it contains 0
  min_val = std::min(min_val, 0.f);
  max_val = std::max(max_val, 0.f);

  // Calculate scale
  double scale =
      (static_cast<double>(max_val) - min_val) / (quant_max - quant_min);

  // Handle small scale
  constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f;
  if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) {
    scale = 0.1;
  }

  if (scale < SMALL_SCALE_THRESHOLD) {
    float org_scale = scale;
    scale = SMALL_SCALE_THRESHOLD;
    // Adjust min and max based on new scale
    if (min_val == 0.0f) {
      max_val = SMALL_SCALE_THRESHOLD * (quant_max - quant_min);
    } else if (max_val == 0.0f) {
      min_val = -SMALL_SCALE_THRESHOLD * (quant_max - quant_min);
    } else {
      float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
      min_val *= amplifier;
      max_val *= amplifier;
    }
  }

  // Calculate zero point
  double zero_point_from_min =
      quant_min - min_val / static_cast<double>(scale);
  double zero_point_from_max =
      quant_max - max_val / static_cast<double>(scale);
  double zero_point_from_min_error =
      std::abs(quant_min) - std::abs(min_val / static_cast<double>(scale));
  double zero_point_from_max_error =
      std::abs(quant_max) - std::abs(max_val / static_cast<double>(scale));
  double initial_zero_point =
      zero_point_from_min_error < zero_point_from_max_error
      ? zero_point_from_min
      : zero_point_from_max;

  // Nudge zero point to be an integer
  int64_t nudged_zero_point = 0;
  if (initial_zero_point < quant_min) {
    nudged_zero_point = quant_min;
  } else if (initial_zero_point > quant_max) {
    nudged_zero_point = quant_max;
  } else {
    nudged_zero_point = std::nearbyint(static_cast<float>(initial_zero_point));
  }

  // Set output values on the scalar (0-dim) output tensors
  scale_out.fill_(scale);
  zero_point_out.fill_(nudged_zero_point);

  return std::make_tuple(scale_out, zero_point_out);
}

// Forward declaration of implementation functions
void test_vulkan_choose_qparams_tensor_impl(
    const std::vector<int>& input_sizes,
    int64_t quant_min,
    int64_t quant_max,
    at::ScalarType dtype,
    const vkcompute::utils::StorageType in_storage,
    const vkcompute::utils::StorageType out_storage);

// Wrapper function to test both buffer and texture storage types
void test_vulkan_choose_qparams_tensor(
    const std::vector<int>& input_sizes,
    int64_t quant_min,
    int64_t quant_max,
    at::ScalarType dtype) {
  // Test with buffer storage
  test_vulkan_choose_qparams_tensor_impl(
      input_sizes,
      quant_min,
      quant_max,
      dtype,
      vkcompute::utils::kBuffer,
      vkcompute::utils::kBuffer);

  // Test with texture storage
  test_vulkan_choose_qparams_tensor_impl(
      input_sizes,
      quant_min,
      quant_max,
      dtype,
      vkcompute::utils::kTexture3D,
      vkcompute::utils::kTexture3D);
}

void test_reference_choose_qparams_tensor(
    const std::vector<int>& input_sizes,
    int64_t quant_min,
    int64_t quant_max,
    at::ScalarType dtype) {
  std::vector<int64_t> input_sizes_int64(
      input_sizes.begin(), input_sizes.end());
  at::Tensor input =
      at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat));

  // Get reference output
  auto [reference_scale, reference_zero_point] =
      choose_qparams_tensor_reference_impl(input, quant_min, quant_max);

  // Get implementation output
  auto [impl_scale, impl_zero_point] =
      torch::executor::native::choose_qparams_tensor_aten(
          input, quant_min, quant_max, dtype);

  // Compare outputs
  const bool scale_correct = at::allclose(reference_scale, impl_scale);
  const bool zero_point_correct =
      at::equal(reference_zero_point, impl_zero_point);

  if (!scale_correct || !zero_point_correct) {
    std::cout << "\n"
              << "Failed with parameters: " << std::endl;
    std::cout << "quant_min: " << quant_min << std::endl;
    std::cout << "quant_max: " << quant_max << std::endl;

    std::cout << "input:" << std::endl;
    std::cout << input << std::endl;
    std::cout << "reference scale:" << std::endl;
    std::cout << reference_scale << std::endl;
    std::cout << "implementation scale:" << std::endl;
    std::cout << impl_scale << std::endl;
    std::cout << "reference zero_point:" << std::endl;
    std::cout << reference_zero_point << std::endl;
    std::cout << "implementation zero_point:" << std::endl;
    std::cout << impl_zero_point << std::endl;
  }

  ASSERT_TRUE(scale_correct && zero_point_correct);
}

void test_vulkan_choose_qparams_tensor_impl(
    const std::vector<int>& input_sizes,
    int64_t quant_min,
    int64_t quant_max,
    at::ScalarType dtype,
    const vkcompute::utils::StorageType in_storage,
    const vkcompute::utils::StorageType out_storage) {
  std::vector<int64_t> input_sizes_int64(
      input_sizes.begin(), input_sizes.end());
  at::Tensor input =
      at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat));

  // Get reference output
  auto [reference_scale, reference_zero_point] =
      torch::executor::native::choose_qparams_tensor_aten(
          input, quant_min, quant_max, dtype);

  // Build Vulkan choose_qparams_tensor graph
  using namespace vkcompute;

  GraphConfig config;
  config.set_storage_type_override(in_storage);
  ComputeGraph graph(config);

  IOValueRef r_input = graph.add_input_tensor(
      input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage);

  const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min);
  const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max);

  // Output tensors
  const ValueRef r_scale = graph.add_tensor({}, vkapi::kFloat, out_storage);
  const ValueRef r_zero_point = graph.add_tensor({}, vkapi::kInt, out_storage);

  VK_GET_OP_FN("choose_qparams.tensor")
  (graph,
   {
       r_input.value,
       r_quant_min,
       r_quant_max,
       r_scale,
       r_zero_point,
   });

  ValueRef staging_scale = graph.set_output_tensor(r_scale);
  ValueRef staging_zero_point = graph.set_output_tensor(r_zero_point);

  graph.prepare();
  graph.encode_prepack();
  graph.prepack();
  graph.encode_execute();

  // Run Vulkan choose_qparams_tensor
  graph.copy_into_staging(
      r_input.staging, input.const_data_ptr(), input.numel());

  graph.execute();

  // Create output tensors to hold the results - use types that match GPU output
  at::Tensor vk_scale =
      at::empty({}, at::device(at::kCPU).dtype(at::kFloat)).contiguous();
  at::Tensor vk_zero_point =
      at::empty({}, at::device(at::kCPU).dtype(at::kInt)).contiguous();

  // Copy results from GPU to CPU
  graph.copy_from_staging(
      staging_scale, vk_scale.mutable_data_ptr(), vk_scale.numel());
  graph.copy_from_staging(
      staging_zero_point,
      vk_zero_point.mutable_data_ptr(),
      vk_zero_point.numel());

  // Convert reference values to match Vulkan output types for comparison
  at::Tensor reference_scale_float = reference_scale.to(at::kFloat);
  at::Tensor reference_zero_point_int = reference_zero_point.to(at::kInt);

  // Compare outputs
  const bool scale_correct = at::allclose(reference_scale_float, vk_scale);
  const bool zero_point_correct =
      at::equal(reference_zero_point_int, vk_zero_point);

  if (!scale_correct || !zero_point_correct) {
    std::cout << "\n"
              << "Failed with parameters: " << std::endl;
    std::cout << "quant_min: " << quant_min << std::endl;
    std::cout << "quant_max: " << quant_max << std::endl;
    std::cout << "storage type: "
              << (in_storage == vkcompute::utils::kBuffer ? "buffer"
                                                          : "texture")
              << std::endl;

    // Only dump the tensors when the input doesn't have a ton of elements
    if (input.numel() < 100) {
      std::cout << "input:" << std::endl;
      std::cout << input << "\n" << std::endl;
      std::cout << "reference scale:" << std::endl;
      std::cout << reference_scale << std::endl;
      std::cout << "vulkan scale:" << std::endl;
      std::cout << vk_scale << "\n" << std::endl;
      std::cout << "reference zero_point:" << std::endl;
      std::cout << reference_zero_point << std::endl;
      std::cout << "vulkan zero_point:" << std::endl;
      std::cout << vk_zero_point << std::endl;
    }
  }

  ASSERT_TRUE(scale_correct && zero_point_correct);
}

TEST(VulkanChooseQparamsTest, test_reference_choose_qparams_tensor_int8) {
  test_reference_choose_qparams_tensor(
      {2, 3, 4}, // input sizes
      -128, // quant_min
      127, // quant_max
      at::kChar);
}
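
// Sketch of a companion test that exercises the Vulkan path through the
// buffer/texture wrapper test_vulkan_choose_qparams_tensor defined above.
// The test name and input sizes here are illustrative only.
TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_int8) {
  test_vulkan_choose_qparams_tensor(
      {2, 3, 4}, // input sizes
      -128, // quant_min
      127, // quant_max
      at::kChar);
}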