@@ -191,6 +191,21 @@ def forward(self, x, y):
         return a + b


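+# Module with two linear layers that consume the same input tensor, so under
+# dynamic quantization both linears share a single dequantize chain on x.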
+class SharedDQChain(torch.nn.Module):
+    def __init__(self, input_size, output_size):
+        super().__init__()
+        self.linear1_weight = torch.nn.Parameter(torch.rand(output_size, input_size))
+        self.linear1_bias = torch.nn.Parameter(torch.rand(output_size))
+
+        self.linear2_weight = torch.nn.Parameter(torch.rand(output_size, input_size))
+        self.linear2_bias = torch.nn.Parameter(torch.rand(output_size))
+
+    def forward(self, x):
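+        # Both linears read the same x, so they share its quantize/dequantize nodes.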
+        a = torch.nn.functional.linear(x, self.linear1_weight, self.linear1_bias)
+        b = torch.nn.functional.linear(x, self.linear2_weight, self.linear2_bias)
+        return a + b
+
+
 class TestLinear(unittest.TestCase):
     """
     Test Class for XNNPACK Linear Operators.
@@ -316,6 +331,7 @@ def _test_dqlinear(
         uses_bias=False,
         qconfig: Optional[QuantizationConfig] = None,
         atol=5e-02,  # TODO(T212995726): Investigate right atol for rand[n] inputs
+        no_per_op_mode=False,
     ):
         """
         Helper function to test dynamic quantized linear op with different configurations.
@@ -324,8 +340,9 @@ def _test_dqlinear(
             is_per_channel=is_per_channel,
             is_dynamic=True,
         )
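+        # When no_per_op_mode is set, only per_op_mode=False is exercised;
+        # otherwise both partitioner modes are swept.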
+        per_op_mode_choices = [False] if no_per_op_mode else [True, False]
         for legacy_partitioner in (True, False):
-            for per_op_mode in (True, False):
+            for per_op_mode in per_op_mode_choices:
                 DynamicallyQuantizedPartitioner = XnnpackPartitioner(
                     config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
                     per_op_mode=per_op_mode,
@@ -520,6 +537,24 @@ def get_qnode_checks(quant_node_checks, dialect):
                 #     qtol=bool(quant_config), atol=atol
                 # )

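+    # Check that two dynamically quantized linears which share one dequantize
+    # chain on their common input are partitioned and lowered correctly.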
+    def test_qd8_f32_per_channel_shared_dq_chain(self):
+        for use_bias in (False, True):
+            module = SharedDQChain(
+                input_size=13,
+                output_size=17,
+            )
+            inputs = (torch.randn(1, 2, 13),)
+
+            self._test_dqlinear(
+                module,
+                inputs,
+                dynamic_shapes=None,
+                is_per_channel=True,
+                linear_count=2,
+                uses_bias=use_bias,
+                no_per_op_mode=True,
+            )
+
     def _test_qd8_per_channel_linear(self, dtype: torch.dtype = torch.float):
         for uses_bias in (False, True):
             module = BaseLinear(