Skip to content

Commit c41dfff

Browse files
authored
Merge branch 'main' into dreambooth-lora-flux-exploration
2 parents a4429e0 + a3e8d3f commit c41dfff

File tree

5 files changed

+325
-87
lines changed

5 files changed

+325
-87
lines changed

examples/community/hd_painter.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -898,13 +898,16 @@ class GaussianSmoothing(nn.Module):
898898
Apply gaussian smoothing on a
899899
1d, 2d or 3d tensor. Filtering is performed seperately for each channel
900900
in the input using a depthwise convolution.
901-
Arguments:
902-
channels (int, sequence): Number of channels of the input tensors. Output will
903-
have this number of channels as well.
904-
kernel_size (int, sequence): Size of the gaussian kernel.
905-
sigma (float, sequence): Standard deviation of the gaussian kernel.
906-
dim (int, optional): The number of dimensions of the data.
907-
Default value is 2 (spatial).
901+
902+
Args:
903+
channels (`int` or `sequence`):
904+
Number of channels of the input tensors. The output will have this number of channels as well.
905+
kernel_size (`int` or `sequence`):
906+
Size of the Gaussian kernel.
907+
sigma (`float` or `sequence`):
908+
Standard deviation of the Gaussian kernel.
909+
dim (`int`, *optional*, defaults to `2`):
910+
The number of dimensions of the data. Default is 2 (spatial dimensions).
908911
"""
909912

910913
def __init__(self, channels, kernel_size, sigma, dim=2):
@@ -944,10 +947,14 @@ def __init__(self, channels, kernel_size, sigma, dim=2):
944947
def forward(self, input):
945948
"""
946949
Apply gaussian filter to input.
947-
Arguments:
948-
input (torch.Tensor): Input to apply gaussian filter on.
950+
951+
Args:
952+
input (`torch.Tensor` of shape `(N, C, H, W)`):
953+
Input to apply Gaussian filter on.
954+
949955
Returns:
950-
filtered (torch.Tensor): Filtered output.
956+
`torch.Tensor`:
957+
The filtered output tensor with the same shape as the input.
951958
"""
952959
return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups, padding="same")
953960

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -985,7 +985,6 @@ def encode_prompt(
985985
text_input_ids_list=None,
986986
):
987987
prompt = [prompt] if isinstance(prompt, str) else prompt
988-
batch_size = len(prompt)
989988
dtype = text_encoders[0].dtype
990989

991990
pooled_prompt_embeds = _encode_prompt_with_clip(
@@ -1007,8 +1006,7 @@ def encode_prompt(
10071006
text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
10081007
)
10091008

1010-
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
1011-
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
1009+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
10121010

10131011
return prompt_embeds, pooled_prompt_embeds, text_ids
10141012

0 commit comments

Comments
 (0)