
Conversation

@leisuzz (Contributor) commented Oct 10, 2024

What does this PR do?

Improves training performance (FPS) and makes the script suitable for NPU computing:

- Select the free-memory (cache-clearing) call based on whether the device is CUDA or NPU
- Add FlashAttention for NPU in the attention processor
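The device-dependent cache clearing described above could be sketched roughly as follows. This is a hedged illustration, not the PR's actual code: `make_free_memory` and the dispatch dict are hypothetical names, and in the real training script the callables would be `torch.cuda.empty_cache` and `torch_npu.npu.empty_cache`.

```python
# Illustrative sketch of per-backend cache clearing; not the PR's actual code.
# In the real script the callables would be torch.cuda.empty_cache (CUDA)
# and torch_npu.npu.empty_cache (Ascend NPU).

def make_free_memory(cache_clearers):
    """Build a free_memory() function that dispatches on device type.

    cache_clearers: dict mapping a device-type string ("cuda", "npu")
    to a zero-argument cache-clearing callable for that backend.
    """
    def free_memory(device_type):
        clearer = cache_clearers.get(device_type)
        if clearer is None:
            return False  # nothing to clear for this backend (e.g. "cpu")
        clearer()
        return True
    return free_memory
```

In a training script this might be wired up as, say, `free_memory = make_free_memory({"cuda": torch.cuda.empty_cache, "npu": torch_npu.npu.empty_cache})` and called between memory-heavy stages.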

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

  model_input = vae.encode(pixel_values).latent_dist.sample()
  model_input = model_input * vae.config.scaling_factor
- return {"model_input": model_input.cpu()}
+ return {"model_input": accelerator.gather(model_input)}
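For context on this diff: `accelerator.gather` in Hugging Face Accelerate is an all-gather collective, so every process ends up with the concatenation of all ranks' tensors. A minimal conceptual stand-in (plain lists instead of tensors; this `gather` is an illustrative model, not Accelerate's implementation):

```python
def gather(per_rank_values):
    """Conceptual model of an all-gather collective: each rank contributes
    its local values, and every rank receives the full concatenation.
    Lists stand in for tensors; real code would call accelerator.gather."""
    combined = [v for rank in per_rank_values for v in rank]
    # Every rank receives an identical copy of the combined result.
    return [list(combined) for _ in per_rank_values]
```

Because it is a collective, every rank must reach the call before any rank can proceed, whereas `.cpu()` is a purely local device-to-host copy.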
Contributor

Sorry if my question is stupid, but why do we need to gather here? Doesn't this cause a sync across all ranks, as opposed to an NPU-to-CPU memory copy, making it slower overall?

  add_time_ids = list(original_size + crops_coords_top_left + target_size)
- add_time_ids = torch.tensor([add_time_ids])
- add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+ add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype)
Contributor

Nice!


- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype)
Contributor

Why do we need this change? For improvements to the library, feel free to open a separate PR :)

@a-r-r-o-w (Contributor)

cc @sayakpaul for training scripts and as original author for the sdxl script

@sayakpaul (Member)

@leisuzz any reason for closing the PR?

@leisuzz (Contributor, Author) commented Oct 11, 2024

@sayakpaul Sorry, there were some conflicts with the commit, so I created two new PRs: #9642 and #9640
