intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules
1 file changed: +18 −15

@@ -372,14 +372,15 @@ def forward(
         batch_size = hidden_states.shape[0]

         if self.use_layer_norm:
-            # norm_hidden_states = self.norm1(hidden_states)
-            norm_hidden_states = torch.ops.torch_ipex.fast_layer_norm(
-                hidden_states,
-                self.norm1.normalized_shape,
-                self.norm1.weight,
-                self.norm1.bias,
-                self.norm1.eps,
-            )
+            # native layernorm performs better than `fast_layer_norm` in stable diffusion
+            norm_hidden_states = self.norm1(hidden_states)
+            # norm_hidden_states = torch.ops.torch_ipex.fast_layer_norm(
+            #     hidden_states,
+            #     self.norm1.normalized_shape,
+            #     self.norm1.weight,
+            #     self.norm1.bias,
+            #     self.norm1.eps,
+            # )
         else:
             raise ValueError("Incorrect norm used")

@@ -419,13 +420,15 @@ def forward(
         # 3. Cross-Attention
         if self.attn2 is not None:
             if self.use_layer_norm:
-                norm_hidden_states = torch.ops.torch_ipex.fast_layer_norm(
-                    hidden_states,
-                    self.norm2.normalized_shape,
-                    self.norm2.weight,
-                    self.norm2.bias,
-                    self.norm2.eps,
-                )
+                # native layernorm performs better than `fast_layer_norm` in stable diffusion
+                norm_hidden_states = self.norm2(hidden_states)
+                # norm_hidden_states = torch.ops.torch_ipex.fast_layer_norm(
+                #     hidden_states,
+                #     self.norm2.normalized_shape,
+                #     self.norm2.weight,
+                #     self.norm2.bias,
+                #     self.norm2.eps,
+                # )
             else:
                 raise ValueError("Incorrect norm")

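For context, here is a minimal standalone sketch (not part of the patch) of how one might check that the native `nn.LayerNorm` path and the `torch.ops.torch_ipex.fast_layer_norm` kernel agree numerically and compare their latency. It assumes Intel Extension for PyTorch is installed with an XPU device available, that `fast_layer_norm` keeps the argument order used in the diff above, and that `torch.xpu.synchronize()` is available; the tensor shape and dtype are illustrative, not taken from the patch.

```python
# Hypothetical equivalence/latency check for the two layernorm paths above.
# Assumes an XPU device and that torch.ops.torch_ipex.fast_layer_norm takes
# (input, normalized_shape, weight, bias, eps), as called in the diff.
import time

import torch
import intel_extension_for_pytorch  # noqa: F401  (registers torch.ops.torch_ipex.*)

device = "xpu"
# Illustrative stable-diffusion-like activation shape: (batch, tokens, channels).
hidden = torch.randn(2, 4096, 320, device=device, dtype=torch.float16)
norm = torch.nn.LayerNorm(320, eps=1e-5).to(device=device, dtype=torch.float16)

# Native LayerNorm (the path this patch switches to).
ref = norm(hidden)

# fast_layer_norm kernel, called with the same arguments as in the diff.
fast = torch.ops.torch_ipex.fast_layer_norm(
    hidden,
    norm.normalized_shape,
    norm.weight,
    norm.bias,
    norm.eps,
)
print("max abs diff:", (ref - fast).abs().max().item())


def bench(fn, iters=100):
    # Warm up, then time; torch.xpu.synchronize mirrors torch.cuda.synchronize
    # so device work is finished before reading the clock.
    for _ in range(10):
        fn()
    torch.xpu.synchronize()
    start = time.time()
    for _ in range(iters):
        fn()
    torch.xpu.synchronize()
    return (time.time() - start) / iters


print("native layernorm :", bench(lambda: norm(hidden)))
print("fast_layer_norm  :", bench(lambda: torch.ops.torch_ipex.fast_layer_norm(
    hidden, norm.normalized_shape, norm.weight, norm.bias, norm.eps)))
```

A sketch like this makes the motivation of the patch measurable: if the native path reports lower average latency for the shapes used by the model, switching back to `self.norm1(...)` / `self.norm2(...)` is the right call even though a fused kernel exists.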