@@ -485,7 +485,7 @@ def eval(self):
485485 self .weight .data += (self .lora_B @ self .lora_A ) * self .scaling
486486 self .merged = True
487487
488- def forward (self , x : torch .Tensor ):
488+ def forward (self , x : torch .Tensor , ** kwargs ):
489489 if self .r > 0 and not self .merged and self .is_activated ():
490490 result = nn .Embedding .forward (self , x )
491491 if self .r > 0 :
@@ -572,7 +572,7 @@ def T(w):
572572 self .weight .data += T (self .lora_B @ self .lora_A ) * self .scaling
573573 self .merged = True
574574
575- def forward (self , x : torch .Tensor ):
575+ def forward (self , x : torch .Tensor , ** kwargs ):
576576
577577 def T (w ):
578578 return w .T if self .fan_in_fan_out else w
@@ -692,7 +692,7 @@ def T(w):
692692 self .weight .data += self .zero_pad (T (delta_w * self .scaling ))
693693 self .merged = True
694694
695- def forward (self , x : torch .Tensor ):
695+ def forward (self , x : torch .Tensor , ** kwargs ):
696696
697697 def T (w ):
698698 return w .T if self .fan_in_fan_out else w
@@ -778,7 +778,7 @@ def eval(self):
778778 self .weight .shape ) * self .scaling
779779 self .merged = True
780780
781- def forward (self , x : torch .Tensor ):
781+ def forward (self , x : torch .Tensor , ** kwargs ):
782782 if self .r > 0 and not self .merged and self .is_activated ():
783783 return F .conv2d (
784784 x ,
0 commit comments