@@ -33,7 +33,7 @@ def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module)
3333            "Please install fast-hadamard-transform: pip install fast-hadamard-transform" 
3434        )
3535
36-     class  FeedForwardCustom (nn .Module ):
36+     class  FeedForwardCudaCustom (nn .Module ):
3737        def  __init__ (self , w1 , w2 , w3 ):
3838            super ().__init__ ()
3939            self .w1  =  w1 
@@ -47,7 +47,7 @@ def forward(self, x):
4747
4848    for  name , child  in  module .named_children ():
4949        if  isinstance (child , FeedForward ):
50-             setattr (module , name , FeedForwardCustom (child .w1 , child .w2 , child .w3 ))
50+             setattr (module , name , FeedForwardCudaCustom (child .w1 , child .w2 , child .w3 ))
5151        else :
5252            _inject_fast_hadamard_transform_cuda_for_spin_quant (child )
5353
@@ -59,6 +59,38 @@ def inject_fast_hadamard_transform_cuda_for_spin_quant(
5959    return  module 
6060
6161
def _inject_fast_hadamard_transform_native_for_spin_quant(module: torch.nn.Module):
    """
    Recursively replace every ``FeedForward`` child of ``module`` with a
    variant whose forward pass applies the native (``torch.ops.llama``)
    fast Hadamard transform before the down projection.

    SpinQuant needs two Hadamard matrices: R3 and R4. Here we are only
    injecting R4 in the feed forward layer. R3 needs to be injected as well
    when KV cache quantization is enabled.
    """

    class FeedForwardNativeCustom(nn.Module):
        def __init__(self, w1, w2, w3):
            super().__init__()
            # w1/w3 are the gate/up projections, w2 the down projection
            # (SwiGLU feed-forward layout) — taken from the replaced module.
            self.w1 = w1
            self.w2 = w2
            self.w3 = w3

        def forward(self, x):
            # SwiGLU activation, then the R4 Hadamard rotation via the
            # native llama op, then the down projection.
            return self.w2(
                torch.ops.llama.fast_hadamard_transform(F.silu(self.w1(x)) * self.w3(x))
            )

    for name, child in module.named_children():
        if isinstance(child, FeedForward):
            setattr(module, name, FeedForwardNativeCustom(child.w1, child.w2, child.w3))
        else:
            # Not a FeedForward: descend into this submodule's children.
            _inject_fast_hadamard_transform_native_for_spin_quant(child)
85+ 
86+ 
def inject_fast_hadamard_transform_native_for_spin_quant(
    module: torch.nn.Module,
) -> torch.nn.Module:
    """Inject the native fast Hadamard transform (R4) into ``module``'s
    feed-forward layers in place, returning the same module for chaining."""
    _inject_fast_hadamard_transform_native_for_spin_quant(module)
    return module
92+ 
93+ 
6294def  _replace_linear_with_linear_8da4w_for_spin_quant (
6395    module : torch .nn .Module ,
6496    checkpoint : Any ,
0 commit comments