@@ -483,6 +483,7 @@ class LoRAConfig(AdapterConfig):
483483 Place a trainable gating module besides the added parameter module to control module activation. This is
484484 e.g. used for UniPELT. Defaults to False. Note that modules with use_gating=True cannot be merged using
485485 `merge_adapter()`.
486+ dtype (str, optional): torch dtype for reparametrization tensors. Defaults to None.
486487 """
487488
488489 architecture : Optional [str ] = "lora"
@@ -499,6 +500,7 @@ class LoRAConfig(AdapterConfig):
499500 composition_mode : str = "add"
500501 init_weights : str = "lora"
501502 use_gating : bool = False
503+ dtype : Optional [str ] = None
502504
503505
504506@dataclass (eq = False )
@@ -521,6 +523,7 @@ class IA3Config(LoRAConfig):
521523 composition_mode : str = "scale"
522524 init_weights : str = "ia3"
523525 use_gating : bool = False
526+ dtype : Optional [str ] = None
524527
525528
526529@dataclass (eq = False )
@@ -540,6 +543,7 @@ class ReftConfig(AdapterConfig):
540543 subtract_projection (bool): If True, subtract the projection of the input.
541544 dropout (float): The dropout rate used in the intervention layer.
542545 non_linearity (str): The activation function used in the intervention layer.
546+ dtype (str, optional): torch dtype for intervention tensors. Defaults to None.
543547 """
544548
545549 layers : Union [Literal ["all" ], List [int ]]
@@ -551,6 +555,7 @@ class ReftConfig(AdapterConfig):
551555 subtract_projection = True
552556 dropout : float = 0.05
553557 non_linearity : Optional [str ] = None
558+ dtype : Optional [str ] = None
554559
555560 architecture : str = "reft"
556561
@@ -569,6 +574,7 @@ class LoReftConfig(ReftConfig):
569574 r : int = 1
570575 orthogonality : bool = True
571576 tied_weights : bool = False
577+ dtype : Optional [str ] = None
572578
573579
574580@dataclass (eq = False )
@@ -583,6 +589,7 @@ class NoReftConfig(ReftConfig):
583589 r : int = 1
584590 orthogonality : bool = False
585591 tied_weights : bool = False
592+ dtype : Optional [str ] = None
586593
587594
588595@dataclass (eq = False )
@@ -598,6 +605,7 @@ class DiReftConfig(ReftConfig):
598605 orthogonality : bool = False
599606 tied_weights : bool = False
600607 subtract_projection = False
608+ dtype : Optional [str ] = None
601609
602610
603611class ConfigUnion (AdapterConfig ):
0 commit comments