146146 "quantized_fully_connected(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
147147 "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
148148)
149-
149+ lib .define (
150+ "quantized_fully_connected.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
151+ "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
152+ )
150153
151154# ------------------------------------ #
152155# Migrated from custom_ops.ymal #
192195 "quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
193196 "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
194197)
198+ lib .define (
199+ "quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
200+ "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
201+ )
195202lib .define (
196203 "quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, "
197204 "Tensor indices, bool pruned_weights=False, *, Tensor(a!) out) -> Tensor(a!)"
@@ -595,6 +602,28 @@ def quantized_fully_connected_meta(
595 602      bias: torch.Tensor,
596 603      in_zero_point: int,
597 604      weight_zero_point: torch.Tensor,
    605  +     out_multiplier: torch.Tensor,
    606  +     out_shift: torch.Tensor,
    607  +     out_zero_point: int,
    608  +     offset: Optional[torch.Tensor],
    609  + ) -> torch.Tensor:
    610  +     # src comes in shape [leading_dims, in_dim]
    611  +     # weight comes in shape [out_dim, in_dim]
    612  +     # output comes in empty with shape [leading_dims, out_dim]
    613  +     out_size = list(src.size())
    614  +     weight_size = list(weight.size())
    615  +     assert len(weight_size) == 2
    616  +     out_size[-1] = weight_size[0]
    617  +     return src.new_empty(out_size, dtype=src.dtype)
    618  +
    619  +
    620  + @register_fake("cadence::quantized_fully_connected.per_tensor")
    621  + def quantized_fully_connected_per_tensor_meta(
    622  +     src: torch.Tensor,
    623  +     weight: torch.Tensor,
    624  +     bias: torch.Tensor,
    625  +     in_zero_point: int,
    626  +     weight_zero_point: int,
598 627      out_multiplier: int,
599 628      out_shift: int,
600 629      out_zero_point: int,
@@ -607,7 +636,7 @@ def quantized_fully_connected_meta(
607 636      weight_size = list(weight.size())
608 637      assert len(weight_size) == 2
609 638      out_size[-1] = weight_size[0]
610      -     return src.new_empty(out_size, dtype=torch.uint8)
    639  +     return src.new_empty(out_size, dtype=src.dtype)
611 640  
612 641  
613 642  @register_fake("cadence::convolution")