@@ -398,16 +398,8 @@ at::Tensor woq_linear_pack_weight(
   // Note that weight is already compressed
   int64_t K_int4_compressed = K / 2;
   int64_t N_int4 = N % block_n ? N / block_n * block_n + block_n : N;
-  at::Tensor weight_int4 = at::empty(
-      {N_int4, K_int4_compressed}, device(c10::kCPU).dtype(c10::kByte));
-  int64_t weight_size_bytes = weight.numel();
-  int64_t weight_int4_size_bytes = weight_int4.numel();
-  int64_t pad_size_bytes = weight_int4_size_bytes - weight_size_bytes;
-  std::memcpy(weight_int4.data_ptr(), weight.data_ptr(), weight_size_bytes);
-  std::fill_n(
-      (uint8_t*)weight_int4.data_ptr() + weight_size_bytes,
-      pad_size_bytes,
-      0);
+  at::Tensor weight_int4 =
+      at::pad(weight, {0, 0, 0, N_int4 - N}, "constant", 0);
   return woq_tpp_gemm_packB_stub(
       kCPU, weight_int4, weight_dtype, block_n, block_k, lowp_mode);
 }
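The removed at::empty + memcpy + fill_n sequence zero-padded the compressed weight along N by hand; the new code gets the same result from a single pad call. A minimal standalone sketch of the same padding semantics, using torch::constant_pad_nd (the constant-mode primitive that at::pad lowers to) and made-up shapes:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Hypothetical shapes: 5 rows of int4-compressed weight, block_n = 4,
  // so N is rounded up to the next multiple of block_n, i.e. 8.
  int64_t N = 5, K_compressed = 16, block_n = 4;
  int64_t N_pad = N % block_n ? N / block_n * block_n + block_n : N;
  auto weight = torch::randint(0, 256, {N, K_compressed}, torch::kByte);
  // The pad spec {0, 0, 0, N_pad - N} pads the last dim by (0, 0) and
  // the row dim by (0, N_pad - N), i.e. it appends rows of zeros,
  // matching the removed memcpy + fill_n logic in one call.
  auto padded = torch::constant_pad_nd(weight, {0, 0, 0, N_pad - N}, 0);
  std::cout << padded.sizes() << std::endl; // prints [8, 16]
}
```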
@@ -491,7 +483,9 @@ at::Tensor woq_linear_kernel(
     int64_t lowp_mode,
     int64_t act_quant_mode,
     const c10::optional<at::Tensor>& compensation) {
-  int64_t quant_w_mode = group_size > 0 ? 1 : 0;
+  int64_t quant_w_mode = zps_list[0].defined()
+      ? (group_size > 0 ? QUANT_W_PER_K_BLOCK : QUANT_W_PER_CHANNEL)
+      : (group_size > 0 ? QUANT_W_PER_K_BLOCK_SYM : QUANT_W_PER_CHANNEL_SYM);
   auto K = self.size(-1);
   auto M = self.numel() / K;
   auto in = self;
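This mode-selection change recurs below in woq_linear_unary_kernel and woq_linear_binary_kernel: instead of a binary group-size check, the mode now also encodes whether the weights carry zero points, so symmetric weights get dedicated *_SYM modes. A sketch of the selection logic, with enum values that are assumptions for illustration (the real constants are defined in IPEX's WOQ headers):

```cpp
#include <cstdint>

// Assumed values; only the four-way structure is taken from the patch.
enum QuantWMode : int64_t {
  QUANT_W_PER_CHANNEL = 0,     // asymmetric, one scale/zp per channel
  QUANT_W_PER_K_BLOCK = 1,     // asymmetric, one scale/zp per K block
  QUANT_W_PER_CHANNEL_SYM = 2, // symmetric, scales only
  QUANT_W_PER_K_BLOCK_SYM = 3, // symmetric, scales only
};

int64_t select_quant_w_mode(bool has_zero_points, int64_t group_size) {
  // Zero points present -> asymmetric modes (the old 0/1 mapping);
  // no zero points -> the new symmetric modes.
  return has_zero_points
      ? (group_size > 0 ? QUANT_W_PER_K_BLOCK : QUANT_W_PER_CHANNEL)
      : (group_size > 0 ? QUANT_W_PER_K_BLOCK_SYM : QUANT_W_PER_CHANNEL_SYM);
}
```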
@@ -533,6 +527,63 @@ at::Tensor woq_linear_forward(
       ->run(input);
 }
 
+at::Tensor woq_linear_forward_v2(
+    const at::Tensor& input,
+    const at::Tensor& qweight,
+    const c10::string_view& weight_dtype,
+    const std::vector<int64_t>& weight_shape,
+    const std::vector<at::Tensor>& weight_scales,
+    const c10::optional<std::vector<at::Tensor>>& weight_zeros,
+    const c10::optional<std::vector<at::Tensor>>& bias,
+    const c10::optional<at::Tensor>& g_idx,
+    int64_t group_size,
+    int64_t lowp_mode,
+    int64_t act_quant_mode,
+    const c10::optional<at::Tensor>& compensation) {
+  static const std::map<c10::string_view, int64_t> WOQ_DTYPE_MAP = {
+      {"int8", WOQ_DTYPE_INT8},
+      {"int4", WOQ_DTYPE_INT4},
+      {"nf4", WOQ_DTYPE_NF4},
+  };
+  TORCH_CHECK(
+      WOQ_DTYPE_MAP.find(weight_dtype) != WOQ_DTYPE_MAP.end(),
+      "Unsupported weight dtype: ",
+      weight_dtype);
+  if (WOQ_DTYPE_MAP.at(weight_dtype) == WOQ_DTYPE_INT8 && lowp_mode == 3) {
+    TORCH_CHECK(compensation.has_value() && compensation.value().defined());
+  }
+  static const at::Tensor empty_tensor = at::Tensor();
+  // zp list of all dtypes = {fp32, fp16, bf16, int8}
+  static const std::vector<at::Tensor> empty_zp_list = {
+      empty_tensor, empty_tensor, empty_tensor, empty_tensor};
+  // bias list of all dtypes = {fp32, fp16, bf16}
+  static const std::vector<at::Tensor> empty_bias_list = {
+      empty_tensor, empty_tensor, empty_tensor};
+  if (weight_zeros.has_value()) {
+    TORCH_CHECK(
+        weight_zeros.value().size() == 4,
+        "IPEX WOQ: expect list of zeros has length 4");
+  }
+  auto& zeros_list =
+      weight_zeros.has_value() ? weight_zeros.value() : empty_zp_list;
+  if (bias.has_value()) {
+    TORCH_CHECK(
+        bias.value().size() == 3, "IPEX WOQ: expect list of bias has length 3");
+  }
+  auto& bias_list = bias.has_value() ? bias.value() : empty_bias_list;
+  return woq_linear_kernel(
+      input,
+      qweight,
+      WOQ_DTYPE_MAP.at(weight_dtype),
+      weight_scales,
+      zeros_list,
+      bias_list,
+      group_size,
+      lowp_mode,
+      act_quant_mode,
+      compensation);
+}
+
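Because weight_zeros and bias are optional, a symmetric-quantized call can omit both; the function then substitutes the 4-slot and 3-slot lists of undefined tensors before forwarding to woq_linear_kernel. A call-site sketch (tensor names and parameter values are illustrative, not taken from this patch):

```cpp
// Assumed to be defined elsewhere: activation, qweight, scales_list,
// and the original weight shape {N, K}.
auto output = torch_ipex::cpu::woq_linear_forward_v2(
    activation,      // fp32/bf16 activation of shape [M, K]
    qweight,         // packed quantized weight
    "int4",          // must be one of "int8" / "int4" / "nf4"
    {N, K},          // original (unpacked) weight shape
    scales_list,     // weight scales
    c10::nullopt,    // weight_zeros absent -> symmetric *_SYM modes
    c10::nullopt,    // bias absent -> empty 3-slot list
    c10::nullopt,    // g_idx
    /*group_size=*/128,
    /*lowp_mode=*/2,
    /*act_quant_mode=*/0,
    /*compensation=*/c10::nullopt);
```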
 at::Tensor woq_linear_unary_kernel(
     const at::Tensor& self,
     const at::Tensor& weight,
@@ -559,7 +610,9 @@ at::Tensor woq_linear_unary_kernel(
   } else if (post_op == "silu") {
     post_op_fusion_type = WOQ_FUSE_SILU;
   }
-  int64_t quant_w_mode = group_size > 0 ? 1 : 0;
+  int64_t quant_w_mode = zps_list[0].defined()
+      ? (group_size > 0 ? QUANT_W_PER_K_BLOCK : QUANT_W_PER_CHANNEL)
+      : (group_size > 0 ? QUANT_W_PER_K_BLOCK_SYM : QUANT_W_PER_CHANNEL_SYM);
   auto K = self.size(-1);
   auto M = self.numel() / K;
   auto in = self;
@@ -648,7 +701,9 @@ at::Tensor woq_linear_binary_kernel(
   } else if (post_op == "mul") {
     post_op_fusion_type = WOQ_FUSE_MUL;
   }
-  int64_t quant_w_mode = group_size > 0 ? 1 : 0;
+  int64_t quant_w_mode = zps_list[0].defined()
+      ? (group_size > 0 ? QUANT_W_PER_K_BLOCK : QUANT_W_PER_CHANNEL)
+      : (group_size > 0 ? QUANT_W_PER_K_BLOCK_SYM : QUANT_W_PER_CHANNEL_SYM);
   auto K = self.size(-1);
   auto M = self.numel() / K;
   auto in = self;
@@ -782,6 +837,39 @@ at::Tensor woq_linear_forward(
   return op.call(cpu_cached_cast(target_type, input), op_context);
 }
 
+at::Tensor woq_linear_forward_v2(
+    const at::Tensor& input,
+    const at::Tensor& qweight,
+    const c10::string_view& weight_dtype,
+    const std::vector<int64_t>& weight_shape,
+    const std::vector<at::Tensor>& weight_scales,
+    const c10::optional<std::vector<at::Tensor>>& weight_zeros,
+    const c10::optional<std::vector<at::Tensor>>& bias,
+    const c10::optional<at::Tensor>& g_idx,
+    int64_t group_size,
+    int64_t lowp_mode,
+    int64_t act_quant_mode,
+    const c10::optional<at::Tensor>& compensation) {
+  c10::impl::ExcludeDispatchKeyGuard no_autocastCPU(DispatchKey::AutocastCPU);
+  static auto op = torch::Dispatcher::singleton()
+                       .findSchemaOrThrow("torch_ipex::woq_linear", "")
+                       .typed<decltype(woq_linear_forward_v2)>();
+  auto target_type = get_autocast_dtype();
+  return op.call(
+      cpu_cached_cast(target_type, input),
+      qweight,
+      weight_dtype,
+      weight_shape,
+      weight_scales,
+      weight_zeros,
+      bias,
+      g_idx,
+      group_size,
+      lowp_mode,
+      act_quant_mode,
+      compensation);
+}
+
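The AutocastCPU wrapper follows PyTorch's standard fallthrough pattern: exclude the autocast key so the redispatch reaches the CPU kernel, and cast only the activation while the quantized weight, scales, and zero points pass through untouched. A generic sketch of that pattern for a hypothetical op (helper names are illustrative; note that the dtype getter is at::autocast::get_autocast_dtype(at::kCPU) in recent PyTorch and get_autocast_cpu_dtype() in older releases):

```cpp
#include <ATen/ATen.h>
#include <ATen/autocast_mode.h>
#include <ATen/core/dispatch/Dispatcher.h>

// Hypothetical op "myns::my_linear", assumed registered elsewhere.
at::Tensor my_linear_autocast(const at::Tensor& input, const at::Tensor& w) {
  // Exclude AutocastCPU so the call below redispatches straight to the
  // CPU kernel instead of re-entering this wrapper.
  c10::impl::ExcludeDispatchKeyGuard no_autocast(
      c10::DispatchKey::AutocastCPU);
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("myns::my_linear", "")
          .typed<at::Tensor(const at::Tensor&, const at::Tensor&)>();
  // Cast only the floating-point activation; the quantized weight must
  // not be rewritten by autocast.
  return op.call(input.to(at::autocast::get_autocast_dtype(at::kCPU)), w);
}
```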
 at::Tensor woq_linear_gelu_forward(
     const at::Tensor& input,
     const at::Tensor& op_context) {
@@ -964,6 +1052,19 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
9641052 " woq_linear_mul" ,
9651053 c10::DispatchKey::AutocastCPU,
9661054 torch_ipex::autocast::woq_linear_mul_forward);
1055+ // the version without op_context
1056+ m.def (
1057+ " woq_linear(Tensor input, Tensor qweight, str weight_dtype, int[] weight_shape, Tensor[] weight_scales, "
1058+ " Tensor[]? weight_zeros, Tensor[]? bias, Tensor? g_idx, int group_size, int lowp_mode, int act_quant_mode, "
1059+ " Tensor? compensation = None) -> Tensor" );
1060+ m.impl (
1061+ " woq_linear" ,
1062+ c10::DispatchKey::CPU,
1063+ torch_ipex::cpu::woq_linear_forward_v2);
1064+ m.impl (
1065+ " woq_linear" ,
1066+ c10::DispatchKey::AutocastCPU,
1067+ torch_ipex::autocast::woq_linear_forward_v2);
9671068#endif
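With the schema and both dispatch-key implementations registered under TORCH_LIBRARY_FRAGMENT(torch_ipex, m), the new op-context-free entry point is reachable through the dispatcher, e.g. from Python as torch.ops.torch_ipex.woq_linear(...), in addition to the C++ path shown in the autocast wrapper above.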
   // fuse eltwise
   m.def(