Description
I tried generating samples in Colab, and everything works, except that I had to change one line in /cm/unet.py: I replaced factory_kwargs with an empty dict (the original line is left commented out below).
I'm not sure whether this is a bug or whether I did something wrong. This is the notebook I ran: https://github.com/JonathanFly/consistency_models_colab_notebook/blob/main/Consistency_Models_Make_Samples.ipynb
class QKVFlashAttention(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        batch_first=True,
        attention_dropout=0.0,
        causal=False,
        device=None,
        dtype=None,
        **kwargs,
    ) -> None:
        from einops import rearrange
        from flash_attn.flash_attention import FlashAttention

        assert batch_first
        # factory_kwargs = {"device": device, "dtype": dtype}
        factory_kwargs = {}
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.causal = causal
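
The issue doesn't include a traceback, so the cause is a guess. One plausible reading is that a constructor receiving `**factory_kwargs` in the installed library version doesn't accept `device`/`dtype`. If that is the case, a less invasive workaround than emptying the dict would be to pass only the keyword arguments the callee's signature actually accepts. Here is a minimal sketch of that idea. `filter_supported_kwargs`, `OldStyleAttention`, and `NewStyleAttention` are hypothetical names made up for illustration; they are not part of flash_attn or this repo.

```python
import inspect


def filter_supported_kwargs(callable_obj, candidate_kwargs):
    """Return only the entries of candidate_kwargs that callable_obj accepts."""
    params = inspect.signature(callable_obj).parameters
    # A callable that declares **kwargs accepts any keyword argument.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(candidate_kwargs)
    return {k: v for k, v in candidate_kwargs.items() if k in params}


# Hypothetical stand-ins for two versions of a third-party module's
# constructor: one that takes device/dtype and one that does not.
class OldStyleAttention:
    def __init__(self, attention_dropout=0.0, device=None, dtype=None):
        self.attention_dropout = attention_dropout
        self.device = device
        self.dtype = dtype


class NewStyleAttention:
    def __init__(self, attention_dropout=0.0):
        self.attention_dropout = attention_dropout


factory_kwargs = {"device": "cpu", "dtype": None}

# Both constructions succeed; unsupported keys are dropped rather than
# raising a TypeError for an unexpected keyword argument.
old = OldStyleAttention(
    attention_dropout=0.1,
    **filter_supported_kwargs(OldStyleAttention, factory_kwargs),
)
new = NewStyleAttention(
    attention_dropout=0.1,
    **filter_supported_kwargs(NewStyleAttention, factory_kwargs),
)
```

In `QKVFlashAttention.__init__`, the same helper could wrap whichever call currently receives `**factory_kwargs`, so the code keeps working across library versions without hard-coding an empty dict.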