
PyTorch 1.12 and flash-attn==0.2.8 are not compatible. #27

@heheda12345


Thanks for your great work! I am trying to reproduce the latency tests using the scripts in the Dejavu/benchmarks folder. I installed the recommended PyTorch 1.12 and flash-attn==0.2.8, but the two libraries are not compatible. The error below is raised by this line in flash-attn, which calls torch.distributed.get_global_rank. That function is not available in PyTorch 1.12, only in newer PyTorch releases. Which library versions should I use to reproduce the results?

p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
AttributeError: module 'torch.distributed' has no attribute 'get_global_rank'

Also, the scripts load a weight file called "full.pt", which is not in OPT's Hugging Face repo. How can I obtain this file?
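For anyone hitting the same AttributeError, here is a minimal sketch of a possible workaround. It is not a confirmed fix from the maintainers. It assumes that older PyTorch builds expose a private helper named `_get_global_rank` under `torch.distributed.distributed_c10d` with the same `(group, group_rank)` arguments; please verify this against your installed version. The function name `resolve_get_global_rank` is hypothetical and is not part of PyTorch or flash-attn.

```python
from types import SimpleNamespace


def resolve_get_global_rank(dist):
    """Return a callable (group, group_rank) -> global_rank.

    `dist` is expected to be the `torch.distributed` module, or any object
    with the same attributes. Hypothetical helper, not a PyTorch API.
    """
    # Newer PyTorch: the public function exists, so use it directly.
    public = getattr(dist, "get_global_rank", None)
    if public is not None:
        return public

    # ASSUMPTION: older PyTorch ships a private equivalent under
    # distributed_c10d. Check your own install before relying on it.
    c10d = getattr(dist, "distributed_c10d", None)
    legacy = getattr(c10d, "_get_global_rank", None)
    if legacy is not None:
        return legacy

    raise AttributeError(
        "no get_global_rank equivalent found; upgrading PyTorch or pinning "
        "a different flash-attn release is probably required"
    )


# Stand-in for an older torch.distributed, used only to exercise the helper.
fake_old_dist = SimpleNamespace(
    distributed_c10d=SimpleNamespace(_get_global_rank=lambda group, rank: rank)
)
shim = resolve_get_global_rank(fake_old_dist)
```

With real PyTorch, the intended usage would be to run `import torch.distributed as dist; dist.get_global_rank = resolve_get_global_rank(dist)` before importing flash_attn. Whether this is enough to reproduce the benchmark numbers is untested.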
