@@ -234,62 +234,27 @@ Tensor qnnpack_add(Tensor qa, Tensor qb, double scale, int64_t zero_point) {
234234
235235#ifdef USE_XNNPACK
236236C10_ALWAYS_INLINE
237- enum xnn_status xnnp_create_add_nd (
238- int8_t azp,
239- float ascale,
240- int8_t bzp,
241- float bscale,
242- int8_t czp,
243- float cscale,
244- int8_t output_min,
245- int8_t output_max,
246- uint32_t flags,
247- xnn_operator_t * op) {
248- return xnn_create_add_nd_qs8 (
249- azp, /* int8_t input1_zero_point */
250- ascale, /* float input1_scale */
251- bzp, /* int8_t input2_zero_point */
252- bscale, /* float input2_scale */
253- czp, /* int8_t output_zero_point */
254- cscale, /* float output_scale */
255- output_min, /* int8_t output_min */
256- output_max, /* int8_t output_max */
257- flags, /* uint32_t flags */
258- op); /* xnn_operator_t* add_op_out */
259- }
260-
261- C10_ALWAYS_INLINE
262- enum xnn_status xnnp_reshape_add_nd (
263- xnn_operator_t op,
264- const std::vector<size_t >& a_shape,
265- const std::vector<size_t >& b_shape,
266- pthreadpool_t pt_pool) {
267- return xnn_reshape_add_nd_qs8 (
268- op, /* xnn_operator_t add_op */
269- a_shape.size (), /* size_t num_input1_dims */
270- a_shape.data (), /* const size_t* input1_shape */
271- b_shape.size (), /* size_t num_input2_dims */
272- b_shape.data (), /* const size_t* input2_shape */
273- pt_pool); /* pthreadpool_t threadpool */
274- }
275-
276- C10_ALWAYS_INLINE
277- enum xnn_status xnnp_setup_add_nd (
278- xnn_operator_t op,
279- const int8_t * da,
280- const int8_t * db,
281- int8_t * dc,
282- pthreadpool_t pt_pool) {
283- return xnn_setup_add_nd_qs8 (
284- op, /* xnn_operator_t add_op */
285- da, /* const int8_t* input1 */
286- db, /* const int8_t* input2 */
287- dc); /* int8_t* output */
237+ enum xnn_status xnnp_define_q_tensor (const Tensor& tensor, MemoryFormat format, uint32_t & id, xnn_subgraph_t subgraph_ptr, uint32_t external_id, uint32_t flags){
238+ Tensor contig_tensor = tensor.contiguous (format);
239+ const auto tensor_shape = xnnp_utils::get_mem_format_aware_shape (contig_tensor);
240+ const int32_t zero_point = static_cast <int32_t >(contig_tensor.q_zero_point ());
241+ const float scale = static_cast <float >(contig_tensor.q_scale ());
242+
243+ return xnn_define_quantized_tensor_value (
244+ subgraph_ptr,
245+ xnn_datatype_qint8,
246+ zero_point,
247+ scale,
248+ tensor.ndimension (),
249+ tensor_shape.data (),
250+ nullptr ,
251+ external_id,
252+ flags,
253+ &id);
288254}
289255
290256template <typename scalar_t , bool ReLUFused = false >
291257Tensor xnnp_add (Tensor qa, Tensor qb, double scale, int64_t zero_point) {
292- using underlying_t = typename scalar_t ::underlying;
293258 const string func_name = " xnnp_add()" ;
294259 TORCH_CHECK (qa.ndimension () > 0 , func_name, " : Got empty input tensor." );
295260 TORCH_CHECK (at::native::xnnpack::available (), func_name, " : XNNPACK is not available" )
@@ -299,12 +264,6 @@ Tensor xnnp_add(Tensor qa, Tensor qb, double scale, int64_t zero_point) {
299264 auto qa_mem_format = qa.suggest_memory_format ();
300265 Tensor qa_contig = qa.contiguous (qa_mem_format);
301266 Tensor qb_contig = qb.contiguous (qa_mem_format);
302-
303- const auto a_zero_point = qa_contig.q_zero_point ();
304- const auto b_zero_point = qb_contig.q_zero_point ();
305- const auto a_scale = qa_contig.q_scale ();
306- const auto b_scale = qb_contig.q_scale ();
307-
308267 Tensor qy = at::native::empty_affine_quantized (
309268 at::infer_size_dimvector (qa_contig.sizes (), qb_contig.sizes ()),
310269 qa.scalar_type (),
@@ -319,72 +278,108 @@ Tensor xnnp_add(Tensor qa, Tensor qb, double scale, int64_t zero_point) {
319278 return qy;
320279 }
321280
322- xnn_operator_t xnnp_op = nullptr ;
323- xnnpack_operator xnnp_add_operator;
324281
325- auto output_max = std::numeric_limits<underlying_t >::max ();
326- auto output_min = std::numeric_limits<underlying_t >::min ();
282+ auto output_max = std::numeric_limits<float >::infinity ();
283+ auto output_min = - std::numeric_limits<float >::infinity ();
327284 if (ReLUFused) {
328- /*
329- * FIXME: use activationLimits<T>()
330- * With <T>, MSVC runs into "error C3862: identifier activationLimits not found".
331- */
332- constexpr int64_t qmin = std::numeric_limits<underlying_t >::min ();
333- constexpr int64_t qmax = std::numeric_limits<underlying_t >::max ();
334- int64_t qvalue = static_cast <int64_t >(zero_point);
335- qvalue = std::max<int64_t >(qvalue, qmin);
336- output_min = static_cast <underlying_t >(std::min<int64_t >(qvalue, qmax));
285+ output_min = 0 ;
337286 }
338287
339- // Create an operator
340- auto status = xnnp_create_add_nd (
341- a_zero_point,
342- a_scale,
343- b_zero_point,
344- b_scale,
345- static_cast <underlying_t >(zero_point),
346- static_cast <float >(scale),
347- output_min,
348- output_max,
349- 0 ,
350- &xnnp_op);
351- xnnp_add_operator = xnnpack_operator (xnnp_op);
288+ // Create XNNPACK Subgraph
289+ xnn_subgraph_t subgraph_ptr = nullptr ;
290+ auto status = xnn_create_subgraph (
291+ /* external_value_ids=*/ 3 ,
292+ /* flags=*/ 0 ,
293+ &subgraph_ptr);
352294 TORCH_CHECK (
353295 status == xnn_status_success,
354- func_name, " : xnn create operator failed(" , status," )!" );
355-
356- const auto qa_shape = xnnp_utils::get_mem_format_aware_shape (qa_contig);
357- const auto qb_shape = xnnp_utils::get_mem_format_aware_shape (qb_contig);
358-
359- // Reshape the operator
360- status = xnnp_reshape_add_nd (
361- xnnp_add_operator.get (),
362- qa_shape,
363- qb_shape,
364- caffe2::pthreadpool_ ());
296+ func_name, " : xnn create subgraph failed(" , status," )!" );
297+ std::unique_ptr<xnn_subgraph, decltype (&xnn_delete_subgraph)> subgraph (
298+ subgraph_ptr, &xnn_delete_subgraph);
299+
300+ uint32_t input0_id = XNN_INVALID_VALUE_ID, input1_id = XNN_INVALID_VALUE_ID, output_id = XNN_INVALID_VALUE_ID;
301+
302+ // Defining the quantized input 0
303+ status = xnnp_define_q_tensor (
304+ qa,
305+ qa_mem_format,
306+ input0_id,
307+ subgraph_ptr,
308+ 0 ,
309+ XNN_VALUE_FLAG_EXTERNAL_INPUT
310+ );
311+ TORCH_CHECK (
312+ status == xnn_status_success && input0_id != XNN_INVALID_VALUE_ID,
313+ func_name, " : xnn define input 0 failed(" , status," )!" );
314+
315+ // Defining the quantized input 1
316+ status = xnnp_define_q_tensor (
317+ qb,
318+ qa_mem_format,
319+ input1_id,
320+ subgraph_ptr,
321+ 1 ,
322+ XNN_VALUE_FLAG_EXTERNAL_INPUT
323+ );
324+ TORCH_CHECK (
325+ status == xnn_status_success && input1_id != XNN_INVALID_VALUE_ID,
326+ func_name, " : xnn define input 1 failed(" , status," )!" );
327+
328+ // Defining the quantized output
329+ status = xnnp_define_q_tensor (
330+ qy,
331+ qa_mem_format,
332+ output_id,
333+ subgraph_ptr,
334+ 2 ,
335+ XNN_VALUE_FLAG_EXTERNAL_OUTPUT
336+ );
337+ TORCH_CHECK (
338+ status == xnn_status_success && output_id != XNN_INVALID_VALUE_ID,
339+ func_name, " : xnn define output failed(" , status," )!" );
340+
341+ const struct xnn_binary_params binary_params = {output_min, output_max};
342+ status = xnn_define_binary (
343+ subgraph_ptr,
344+ xnn_binary_add,
345+ &binary_params,
346+ input0_id,
347+ input1_id,
348+ output_id,
349+ 0 );
350+ TORCH_CHECK (
351+ status == xnn_status_success,
352+ func_name, " : xnn define binary add failed(" , status," )!" );
365353
354+ // create runtime
355+ xnn_runtime_t runtime_ptr = nullptr ;
356+ status = xnn_create_runtime_v2 (subgraph_ptr, caffe2::pthreadpool_ (), 0 , &runtime_ptr);
366357 TORCH_CHECK (
367358 status == xnn_status_success,
368- func_name, " : xnn reshape operator failed(" , status," )!" );
369-
370- // Setup the operator
371- status = xnnp_setup_add_nd (
372- xnnp_add_operator.get (),
373- reinterpret_cast <const underlying_t *>(qa_contig.data_ptr <scalar_t >()),
374- reinterpret_cast <const underlying_t *>(qb_contig.data_ptr <scalar_t >()),
375- reinterpret_cast <underlying_t *>(qy.data_ptr <scalar_t >()),
376- caffe2::pthreadpool_ ());
359+ func_name, " : xnn create runtime failed(" , status," )!" );
360+ TORCH_CHECK (
361+ runtime_ptr != nullptr ,
362+ func_name, " : xnn create runtime failed because runtime_ptr is null" );
363+ std::unique_ptr<xnn_runtime, decltype (&xnn_delete_runtime)> auto_runtime (
364+ runtime_ptr, &xnn_delete_runtime);
365+
366+ std::array<xnn_external_value, 3 > external = {
367+ xnn_external_value{input0_id, reinterpret_cast <void *>(qa_contig.data_ptr <scalar_t >())},
368+ xnn_external_value{input1_id, reinterpret_cast <void *>(qb_contig.data_ptr <scalar_t >())},
369+ xnn_external_value{output_id, reinterpret_cast <void *>(qy.data_ptr <scalar_t >())}};
370+
371+ status = xnn_setup_runtime (
372+ runtime_ptr,
373+ external.size (),
374+ external.data ());
377375 TORCH_CHECK (
378376 status == xnn_status_success,
379- func_name, " : xnn setup operator failed(" , status," )!" );
380-
381- // Run the operator
382- status = xnn_run_operator (
383- xnnp_add_operator.get (), /* xnn_operator_t op */
384- caffe2::pthreadpool_ ()); /* pthreadpool_t threadpool */
377+ func_name, " : xnn setup runtime failed(" , status," )!" );
378+ status = xnn_invoke_runtime (runtime_ptr);
385379 TORCH_CHECK (
386380 status == xnn_status_success,
387- func_name, " : xnn run operator failed(" , status," )" );
381+ func_name, " : xnn invoke runtime failed(" , status," )!" );
382+
388383 return qy;
389384}
390385#endif // USE_XNNPACK
0 commit comments