@@ -33,7 +33,7 @@ def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module)
3333 "Please install fast-hadamard-transform: pip install fast-hadamard-transform"
3434 )
3535
36- class FeedForwardCustom (nn .Module ):
36+ class FeedForwardCudaCustom (nn .Module ):
3737 def __init__ (self , w1 , w2 , w3 ):
3838 super ().__init__ ()
3939 self .w1 = w1
@@ -47,7 +47,7 @@ def forward(self, x):
4747
4848 for name , child in module .named_children ():
4949 if isinstance (child , FeedForward ):
50- setattr (module , name , FeedForwardCustom (child .w1 , child .w2 , child .w3 ))
50+ setattr (module , name , FeedForwardCudaCustom (child .w1 , child .w2 , child .w3 ))
5151 else :
5252 _inject_fast_hadamard_transform_cuda_for_spin_quant (child )
5353
@@ -59,6 +59,38 @@ def inject_fast_hadamard_transform_cuda_for_spin_quant(
5959 return module
6060
6161
62+ def _inject_fast_hadamard_transform_native_for_spin_quant (module : torch .nn .Module ):
63+ """
64+ SpinQuant needs two Hadmard matrixes: R3 and R4. Here we are only injecting R4 in the feed forward layer.
65+ R3 needs to be injected as well when KV cache quantization is enabled.
66+ """
67+
68+ class FeedForwardNativeCustom (nn .Module ):
69+ def __init__ (self , w1 , w2 , w3 ):
70+ super ().__init__ ()
71+ self .w1 = w1
72+ self .w2 = w2
73+ self .w3 = w3
74+
75+ def forward (self , x ):
76+ return self .w2 (
77+ torch .ops .llama .fast_hadamard_transform (F .silu (self .w1 (x )) * self .w3 (x ))
78+ )
79+
80+ for name , child in module .named_children ():
81+ if isinstance (child , FeedForward ):
82+ setattr (module , name , FeedForwardNativeCustom (child .w1 , child .w2 , child .w3 ))
83+ else :
84+ _inject_fast_hadamard_transform_native_for_spin_quant (child )
85+
86+
87+ def inject_fast_hadamard_transform_native_for_spin_quant (
88+ module : torch .nn .Module ,
89+ ) -> torch .nn .Module :
90+ _inject_fast_hadamard_transform_native_for_spin_quant (module )
91+ return module
92+
93+
6294def _replace_linear_with_linear_8da4w_for_spin_quant (
6395 module : torch .nn .Module ,
6496 checkpoint : Any ,
0 commit comments