@@ -1,7 +1,7 @@
from __future__ import annotations

from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Tuple

import torch
from diffusers import UNet2DConditionModel
@@ -28,97 +28,84 @@ def __init__(
        self._weight = weight

    @contextmanager
-    def patch_unet(
-        self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
-    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
        lora_model = self._node_context.models.load(self._model_id).model
-        modified_cached_weights, modified_weights = self.patch_model(
+        self.patch_model(
            model=unet,
            prefix="lora_unet_",
            lora=lora_model,
            lora_weight=self._weight,
-            cached_weights=cached_weights,
+            original_weights=original_weights,
        )
        del lora_model

-        yield modified_cached_weights, modified_weights
+        yield

    @classmethod
+    @torch.no_grad()
    def patch_model(
        cls,
        model: torch.nn.Module,
        prefix: str,
        lora: LoRAModelRaw,
        lora_weight: float,
-        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
+        original_weights: Dict[str, torch.Tensor],
+    ):
        """
        Apply one or more LoRAs to a model.
        :param model: The model to patch.
        :param lora: LoRA model to patch in.
        :param lora_weight: LoRA patch weight.
        :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
-        :param cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
+        :param original_weights: TODO:
        """
-        if cached_weights is None:
-            cached_weights = {}
-
-        modified_weights: Dict[str, torch.Tensor] = {}
-        modified_cached_weights: Set[str] = set()
-        with torch.no_grad():
-            # assert lora.device.type == "cpu"
-            for layer_key, layer in lora.layers.items():
-                if not layer_key.startswith(prefix):
-                    continue
-
-                # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
-                # should be improved in the following ways:
-                # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
-                #    LoRA model is applied.
-                # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
-                #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
-                #    weights to have valid keys.
-                assert isinstance(model, torch.nn.Module)
-                module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
-
-                # All of the LoRA weight calculations will be done on the same device as the module weight.
-                # (Performance will be best if this is a CUDA device.)
-                device = module.weight.device
-                dtype = module.weight.dtype
-
-                layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
-
-                # We intentionally move to the target device first, then cast. Experimentally, this was found to
-                # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
-                # same thing in a single call to '.to(...)'.
-                layer.to(device=device)
-                layer.to(dtype=torch.float32)
-
-                # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
-                # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
-                for param_name, lora_param_weight in layer.get_parameters(module).items():
-                    param_key = module_key + "." + param_name
-                    module_param = module.get_parameter(param_name)
-
-                    # save original weight
-                    if param_key not in modified_cached_weights and param_key not in modified_weights:
-                        if param_key in cached_weights:
-                            modified_cached_weights.add(param_key)
-                        else:
-                            modified_weights[param_key] = module_param.detach().to(
-                                device=TorchDevice.CPU_DEVICE, copy=True
-                            )
-
-                    if module_param.shape != lora_param_weight.shape:
-                        # TODO: debug on lycoris
-                        lora_param_weight = lora_param_weight.reshape(module_param.shape)
-
-                    lora_param_weight *= lora_weight * layer_scale
-                    module_param += lora_param_weight.to(dtype=dtype)
-
-                layer.to(device=TorchDevice.CPU_DEVICE)
-
-        return modified_cached_weights, modified_weights
+
+        # assert lora.device.type == "cpu"
+        for layer_key, layer in lora.layers.items():
+            if not layer_key.startswith(prefix):
+                continue
+
+            # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
+            # should be improved in the following ways:
+            # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
+            #    LoRA model is applied.
+            # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
+            #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
+            #    weights to have valid keys.
+            assert isinstance(model, torch.nn.Module)
+            module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
+
+            # All of the LoRA weight calculations will be done on the same device as the module weight.
+            # (Performance will be best if this is a CUDA device.)
+            device = module.weight.device
+            dtype = module.weight.dtype
+
+            layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
+
+            # We intentionally move to the target device first, then cast. Experimentally, this was found to
+            # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
+            # same thing in a single call to '.to(...)'.
+            layer.to(device=device)
+            layer.to(dtype=torch.float32)
+
+            # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
+            # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
+            for param_name, lora_param_weight in layer.get_parameters(module).items():
+                param_key = module_key + "." + param_name
+                module_param = module.get_parameter(param_name)
+
+                # save original weight
+                if param_key not in original_weights:
+                    original_weights[param_key] = module_param.detach().to(device=TorchDevice.CPU_DEVICE, copy=True)
+
+                if module_param.shape != lora_param_weight.shape:
+                    # TODO: debug on lycoris
+                    lora_param_weight = lora_param_weight.reshape(module_param.shape)
+
+                lora_param_weight *= lora_weight * layer_scale
+                module_param += lora_param_weight.to(dtype=dtype)
+
+            layer.to(device=TorchDevice.CPU_DEVICE)

    @staticmethod
    def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
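For orientation, here is a minimal usage sketch (not part of this diff) of how the reworked interface might be driven, assuming the caller owns the `original_weights` dict and is responsible for copying the saved tensors back after the context manager exits; the `lora_ext` name and the restore loop are illustrative assumptions, not code from this PR:

original_weights: Dict[str, torch.Tensor] = {}  # filled in by patch_model() with pre-patch CPU copies
with lora_ext.patch_unet(unet, original_weights):  # lora_ext: hypothetical instance of the class patched above
    ...  # run denoising while the LoRA deltas are applied in-place to the UNet parameters
# Undo the in-place patch by copying the saved originals back onto the model.
with torch.no_grad():
    for param_key, original in original_weights.items():
        param = unet.get_parameter(param_key)  # param_key is "<module>.<param>" relative to the UNet
        param.copy_(original.to(device=param.device, dtype=param.dtype))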