-
Notifications
You must be signed in to change notification settings - Fork 11.3k
[BlockInfo] Index to Tensor #11198
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BlockInfo] Index to Tensor #11198
Conversation
|
@kijai could you check if this fixes the torch.compile graph breaks? |
|
Yes, tested on HunyuanVideo 1.5 and went from 54 recompiles on first step to none. |
|
Sweet, I'll merge it in after stable! |
|
@Haoming02 could you make the tensors be on the CPU instead? |
|
@kijai hey, could you retest this to make sure that the CPU tensors work fine with torch.compile? |
While creating it on cpu is fine, actually using a cpu tensor is a bit problematic as with inductor it requires cpu compile support, which then requires more compiler libraries installed than at least the current Triton-windows package includes, and if you don't have them it just errors out. You can of course cast it to gpu before using it, it does then create |
|
damn, that is definitely annoying. the goal would be that the block index is used inside attention code and is only compared. To even do this comparison, would it be better to compare against the GPU tensor, or CPU tensor? if GPU tensor would be easier, then we can edit this PR to use GPU. Alternatively, is there a way to tell torch compile to ignore transformer_options or at least that one key? |
|
I'm also waiting for blockinfo to be merged so I can implement things like RadialAttn more easily. I think it's ok to create a scalar tensor on GPU and do comparison with it. It's not the first time we do it, for example in https://github.com/comfyanonymous/ComfyUI/blob/c5a47a16924e1be96241553a1448b298e57e50a1/comfy/extra_samplers/uni_pc.py#L785 |
block_indexfromintto atorch.Tensor, to supporttorch.compiletorch.uint8input's devicecpu