 torch.backends.cuda.matmul.allow_tf32 = False
 
 
-def create_lora_layers(model):
+def create_lora_layers(model, mock_weights: bool = True):
     lora_attn_procs = {}
     for name in model.attn_processors.keys():
         cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
@@ -57,12 +57,13 @@ def create_lora_layers(model):
         lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
         lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
 
-        # add 1 to weights to mock trained weights
-        with torch.no_grad():
-            lora_attn_procs[name].to_q_lora.up.weight += 1
-            lora_attn_procs[name].to_k_lora.up.weight += 1
-            lora_attn_procs[name].to_v_lora.up.weight += 1
-            lora_attn_procs[name].to_out_lora.up.weight += 1
+        if mock_weights:
+            # add 1 to weights to mock trained weights
+            with torch.no_grad():
+                lora_attn_procs[name].to_q_lora.up.weight += 1
+                lora_attn_procs[name].to_k_lora.up.weight += 1
+                lora_attn_procs[name].to_v_lora.up.weight += 1
+                lora_attn_procs[name].to_out_lora.up.weight += 1
 
     return lora_attn_procs
 
@@ -378,26 +379,7 @@ def test_lora_processors(self):
         with torch.no_grad():
             sample1 = model(**inputs_dict).sample
 
-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-
-            # add 1 to weights to mock trained weights
-            with torch.no_grad():
-                lora_attn_procs[name].to_q_lora.up.weight += 1
-                lora_attn_procs[name].to_k_lora.up.weight += 1
-                lora_attn_procs[name].to_v_lora.up.weight += 1
-                lora_attn_procs[name].to_out_lora.up.weight += 1
+        lora_attn_procs = create_lora_layers(model)
 
         # make sure we can set a list of attention processors
         model.set_attn_processor(lora_attn_procs)
@@ -465,28 +447,7 @@ def test_lora_save_load_safetensors(self):
         with torch.no_grad():
             old_sample = model(**inputs_dict).sample
 
-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
-
-            # add 1 to weights to mock trained weights
-            with torch.no_grad():
-                lora_attn_procs[name].to_q_lora.up.weight += 1
-                lora_attn_procs[name].to_k_lora.up.weight += 1
-                lora_attn_procs[name].to_v_lora.up.weight += 1
-                lora_attn_procs[name].to_out_lora.up.weight += 1
-
+        lora_attn_procs = create_lora_layers(model)
         model.set_attn_processor(lora_attn_procs)
 
         with torch.no_grad():
@@ -518,21 +479,7 @@ def test_lora_save_safetensors_load_torch(self):
         model = self.model_class(**init_dict)
         model.to(torch_device)
 
-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
-
+        lora_attn_procs = create_lora_layers(model, mock_weights=False)
         model.set_attn_processor(lora_attn_procs)
         # Saving as torch, properly reloads with directly filename
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -553,21 +500,7 @@ def test_lora_save_torch_force_load_safetensors_error(self):
         model = self.model_class(**init_dict)
         model.to(torch_device)
 
-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
-
+        lora_attn_procs = create_lora_layers(model, mock_weights=False)
         model.set_attn_processor(lora_attn_procs)
         # Saving as torch, properly reloads with directly filename
         with tempfile.TemporaryDirectory() as tmpdirname:
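For reference, a minimal usage sketch of the refactored helper (illustrative only, not part of the commit; it assumes `model` is an already-constructed attention-bearing model such as the UNet2DConditionModel these tests exercise):

    # Mocked weights: the helper adds 1 to each LoRA up-projection so the model output changes,
    # which is what the output-comparison tests rely on.
    lora_attn_procs = create_lora_layers(model)
    model.set_attn_processor(lora_attn_procs)

    # Freshly initialized LoRA weights, as used by the save/load format tests,
    # where only serialization behaviour matters.
    lora_attn_procs = create_lora_layers(model, mock_weights=False)
    model.set_attn_processor(lora_attn_procs)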