5353 help = "Specify vulkan target triple or rocm/cuda target device." ,
5454)
5555parser .add_argument ("--vulkan_max_allocation" , type = str , default = "4294967296" )
56+ parser .add_argument ('--controlled' , dest = 'controlled' , action = 'store_true' , help = "Whether or not to use controlled unet (for use with controlnet)" )
57+ parser .add_argument ('--no-controlled' , dest = 'controlled' , action = 'store_false' , help = "Whether or not to use controlled unet (for use with controlnet)" )
58+ parser .set_defaults (controlled = False )
5659
5760
5861class UnetModel (torch .nn .Module ):
59- def __init__ (self , hf_model_name , hf_auth_token ):
62+ def __init__ (self , hf_model_name , hf_auth_token , is_controlled ):
6063 super ().__init__ ()
6164 self .unet = UNet2DConditionModel .from_pretrained (
6265 hf_model_name ,
6366 subfolder = "unet" ,
6467 token = hf_auth_token ,
6568 )
6669 self .guidance_scale = 7.5
70+ if is_controlled :
71+ self .forward = self .forward_controlled
72+ else :
73+ self .forward = self .forward_default
6774
68- def forward (self , sample , timestep , encoder_hidden_states ):
75+ def forward_default (self , sample , timestep , encoder_hidden_states ):
6976 samples = torch .cat ([sample ] * 2 )
7077 unet_out = self .unet .forward (
7178 samples , timestep , encoder_hidden_states , return_dict = False
@@ -76,6 +83,65 @@ def forward(self, sample, timestep, encoder_hidden_states):
7683 )
7784 return noise_pred
7885
86+ def forward_controlled (
87+ self ,
88+ sample ,
89+ timestep ,
90+ encoder_hidden_states ,
91+ control1 ,
92+ control2 ,
93+ control3 ,
94+ control4 ,
95+ control5 ,
96+ control6 ,
97+ control7 ,
98+ control8 ,
99+ control9 ,
100+ control10 ,
101+ control11 ,
102+ control12 ,
103+ control13 ,
104+ scale1 ,
105+ scale2 ,
106+ scale3 ,
107+ scale4 ,
108+ scale5 ,
109+ scale6 ,
110+ scale7 ,
111+ scale8 ,
112+ scale9 ,
113+ scale10 ,
114+ scale11 ,
115+ scale12 ,
116+ scale13 ,
117+ ):
118+ db_res_samples = tuple (
119+ [
120+ control1 * scale1 ,
121+ control2 * scale2 ,
122+ control3 * scale3 ,
123+ control4 * scale4 ,
124+ control5 * scale5 ,
125+ control6 * scale6 ,
126+ control7 * scale7 ,
127+ control8 * scale8 ,
128+ control9 * scale9 ,
129+ control10 * scale10 ,
130+ control11 * scale11 ,
131+ control12 * scale12 ,
132+ ]
133+ )
134+ mb_res_samples = control13 * scale13
135+ samples = torch .cat ([sample ] * 2 )
136+ unet_out = self .unet .forward (
137+ samples , timestep , encoder_hidden_states , down_block_additional_residuals = db_res_samples , mid_block_additional_residual = mb_res_samples , return_dict = False
138+ )[0 ]
139+ noise_pred_uncond , noise_pred_text = unet_out .chunk (2 )
140+ noise_pred = noise_pred_uncond + self .guidance_scale * (
141+ noise_pred_text - noise_pred_uncond
142+ )
143+ return noise_pred
144+
79145
80146def export_unet_model (
81147 unet_model ,
@@ -90,6 +156,7 @@ def export_unet_model(
90156 device = None ,
91157 target_triple = None ,
92158 max_alloc = None ,
159+ is_controlled = False ,
93160):
94161 mapper = {}
95162 utils .save_external_weights (
@@ -100,7 +167,7 @@ def export_unet_model(
100167 if hf_model_name == "stabilityai/stable-diffusion-2-1-base" :
101168 encoder_hidden_states_sizes = (2 , 77 , 1024 )
102169
103- sample = (batch_size , unet_model .unet .config .in_channels , height // 8 , width // 8 )
170+ sample = (batch_size , unet_model .unet .config .in_channels , height , width )
104171
105172 class CompiledUnet (CompiledModule ):
106173 if external_weights :
@@ -120,8 +187,85 @@ def main(
120187 ):
121188 return jittable (unet_model .forward )(sample , timestep , encoder_hidden_states )
122189
190+ class CompiledControlledUnet (CompiledModule ):
191+ if external_weights :
192+ params = export_parameters (
193+ unet_model , external = True , external_scope = "" , name_mapper = mapper .get
194+ )
195+ else :
196+ params = export_parameters (unet_model )
197+
198+ def main (
199+ self ,
200+ sample = AbstractTensor (* sample , dtype = torch .float32 ),
201+ timestep = AbstractTensor (1 , dtype = torch .float32 ),
202+ encoder_hidden_states = AbstractTensor (
203+ * encoder_hidden_states_sizes , dtype = torch .float32
204+ ),
205+ control1 = AbstractTensor (2 , 320 , height , width , dtype = torch .float32 ),
206+ control2 = AbstractTensor (2 , 320 , height , width , dtype = torch .float32 ),
207+ control3 = AbstractTensor (2 , 320 , height , width , dtype = torch .float32 ),
208+ control4 = AbstractTensor (2 , 320 , height // 2 , width // 2 , dtype = torch .float32 ),
209+ control5 = AbstractTensor (2 , 640 , height // 2 , width // 2 , dtype = torch .float32 ),
210+ control6 = AbstractTensor (2 , 640 , height // 2 , width // 2 , dtype = torch .float32 ),
211+ control7 = AbstractTensor (2 , 640 , height // 4 , width // 4 , dtype = torch .float32 ),
212+ control8 = AbstractTensor (2 , 1280 , height // 4 , width // 4 , dtype = torch .float32 ),
213+ control9 = AbstractTensor (2 , 1280 , height // 4 , width // 4 , dtype = torch .float32 ),
214+ control10 = AbstractTensor (2 , 1280 , height // 8 , width // 8 , dtype = torch .float32 ),
215+ control11 = AbstractTensor (2 , 1280 , height // 8 , width // 8 , dtype = torch .float32 ),
216+ control12 = AbstractTensor (2 , 1280 , height // 8 , width // 8 , dtype = torch .float32 ),
217+ control13 = AbstractTensor (2 , 1280 , height // 8 , width // 8 , dtype = torch .float32 ),
218+ scale1 = AbstractTensor (1 , dtype = torch .float32 ),
219+ scale2 = AbstractTensor (1 , dtype = torch .float32 ),
220+ scale3 = AbstractTensor (1 , dtype = torch .float32 ),
221+ scale4 = AbstractTensor (1 , dtype = torch .float32 ),
222+ scale5 = AbstractTensor (1 , dtype = torch .float32 ),
223+ scale6 = AbstractTensor (1 , dtype = torch .float32 ),
224+ scale7 = AbstractTensor (1 , dtype = torch .float32 ),
225+ scale8 = AbstractTensor (1 , dtype = torch .float32 ),
226+ scale9 = AbstractTensor (1 , dtype = torch .float32 ),
227+ scale10 = AbstractTensor (1 , dtype = torch .float32 ),
228+ scale11 = AbstractTensor (1 , dtype = torch .float32 ),
229+ scale12 = AbstractTensor (1 , dtype = torch .float32 ),
230+ scale13 = AbstractTensor (1 , dtype = torch .float32 ),
231+ ):
232+ return jittable (unet_model .forward )(
233+ sample ,
234+ timestep ,
235+ encoder_hidden_states ,
236+ control1 ,
237+ control2 ,
238+ control3 ,
239+ control4 ,
240+ control5 ,
241+ control6 ,
242+ control7 ,
243+ control8 ,
244+ control9 ,
245+ control10 ,
246+ control11 ,
247+ control12 ,
248+ control13 ,
249+ scale1 ,
250+ scale2 ,
251+ scale3 ,
252+ scale4 ,
253+ scale5 ,
254+ scale6 ,
255+ scale7 ,
256+ scale8 ,
257+ scale9 ,
258+ scale10 ,
259+ scale11 ,
260+ scale12 ,
261+ scale13 ,
262+ )
263+
123264 import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
124- inst = CompiledUnet (context = Context (), import_to = import_to )
265+ if is_controlled :
266+ inst = CompiledControlledUnet (context = Context (), import_to = import_to )
267+ else :
268+ inst = CompiledUnet (context = Context (), import_to = import_to )
125269
126270 module_str = str (CompiledModule .get_mlir_module (inst ))
127271 safe_name = utils .create_safe_name (hf_model_name , "-unet" )
@@ -134,8 +278,9 @@ def main(
134278if __name__ == "__main__" :
135279 args = parser .parse_args ()
136280 unet_model = UnetModel (
137- args .hf_model_name ,
281+ args .hf_model_name if not args . controlled else "CompVis/stable-diffusion-v1-4" ,
138282 args .hf_auth_token ,
283+ args .controlled ,
139284 )
140285 mod_str = export_unet_model (
141286 unet_model ,
@@ -150,6 +295,7 @@ def main(
150295 args .device ,
151296 args .iree_target_triple ,
152297 args .vulkan_max_allocation ,
298+ args .controlled ,
153299 )
154300 safe_name = utils .create_safe_name (args .hf_model_name , "-unet" )
155301 with open (f"{ safe_name } .mlir" , "w+" ) as f :
0 commit comments