Commit a80a963

attention mask & bias (#1)
* add support for attn mask
* add mask operation
* add mask operation
* add mask operation
* add interface
* add mask support
* add mask supprt
* fix up
* add bias
* add template
* add test
* clean
* clean code
* add mask load
* add mask test
* fix forward bugs
* add test
* add mask in backward
* add test case
* add bias
* add mask
* add bias test
* fix test case
* add without mask test
* add kernel test
* add ds save
* fix interface
* add test
* fix dbias
* add bias support
* add mask shape
* add test
* add support
* fix bf16 and mask shape
* fix mask head=1 shape
* add dump
* to fix len 512
* add test
* fix seqlen greater than 256
* fix bias seqlen
* add constexpr
* add const expr for bwd
* add benchmark
* add test tools
* add script
* add cross attention
* add cross attn
* fix bugs
* remove test tools
* clean fmha_api.cpp
* clean fmha_dgrad_fp16_kernel_loop.sm80.cu
* clean fmha_dgrad_kernel_1xN_loop.h
* clean fmha_fprop_fp16_kernel.sm80.cu
* clean fmha_fprop_kernel_1xN.h
* cleangmem_tile.h
* clean softmax.h
* restore test_flash_attn.py
* clean gmem_tile.h
* fix fmha_fprop_kernel_1xN.h
* fix fmha_dgrad_kernel_1xN_loop.h
* rename has_attn to has_attn_mask, has_bias to has_attn_bias
* fix fmha_fprop_kernel_1xN.h
* rename has_attn to has_attn_mask, has_bias to has_attn_bias
* remove useless benchmark code
* add declaration
* remove useless comments
* remove useless comments
* add timeout
* add default timeout for build wheel
* remove timeout
* reduce build worker for workflow oom
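
For orientation, here is a minimal pure-Python sketch of what single-head attention with an additive mask and an additive bias commonly computes. It is not taken from this commit: the function name `reference_attention`, the argument layout, and the assumption that both the mask and the bias are added to the scaled scores before the softmax are all hypothetical, and the CUDA kernels changed here may differ in detail.

```python
import math


def reference_attention(q, k, v, mask=None, bias=None):
    """Hypothetical reference: softmax(q k^T / sqrt(d) + mask + bias) v, one head.

    q: list of Lq rows of length d; k: Lk rows of length d; v: Lk rows of length dv.
    mask and bias are optional Lq x Lk tables of additive terms (assumed semantics).
    Use a large negative value such as -1e9 to mask a position out.
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for i, q_row in enumerate(q):
        scores = []
        for j, k_row in enumerate(k):
            s = scale * sum(a * b for a, b in zip(q_row, k_row))
            if mask is not None:
                s += mask[i][j]
            if bias is not None:
                s += bias[i][j]
            scores.append(s)
        # Subtract the row maximum before exponentiating for numerical stability.
        row_max = max(scores)
        weights = [math.exp(s - row_max) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * v_row[c] for w, v_row in zip(weights, v)) / total
            for c in range(len(v[0]))
        ])
    return out


if __name__ == "__main__":
    q = [[1.0, 0.0]]
    k = [[1.0, 0.0], [0.0, 1.0]]
    v = [[1.0, 2.0], [3.0, 4.0]]
    # Masking out the second key leaves only the first value row.
    print(reference_attention(q, k, v, mask=[[0.0, -1e9]]))
```

A slow reference of this kind is mainly useful as a baseline to compare fused-kernel outputs against in tests, with and without the mask and bias arguments.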
1 parent f515c77 commit a80a963

17 files changed: 1201 additions, 104 deletions

.github/workflows/publish.yml

Lines changed: 2 additions & 2 deletions
@@ -112,7 +112,7 @@ jobs:
 export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
 export CUDA_INSTALL_DIR=/usr/local/cuda-11.3$CUDA_INSTALL_DIR
 pip install wheel
-python setup.py bdist_wheel --dist-dir=dist
+MAX_JOBS=1 python setup.py bdist_wheel --dist-dir=dist
 tmpname=cu${{ matrix.cuda-version }}torch${{ matrix.torch-version }}
 wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
 ls dist/*whl |xargs -I {} mv {} ${wheel_name}
@@ -127,4 +127,4 @@ jobs:
 upload_url: ${{ steps.get_current_release.outputs.upload_url }}
 asset_path: ./${{env.wheel_name}}
 asset_name: ${{env.wheel_name}}
-asset_content_type: application/*
+asset_content_type: application/*

.gitignore

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
+*.pt
+*.tfevents.*
+# JetBrains PyCharm IDE
+.idea/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# macOS dir files
+.DS_Store
+
+# Distribution / packaging
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.args
+*.egg
+
+# Checkpoints
+checkpoints
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# dotenv
+.env
+
+# virtualenv
+.venv
+venv/
+ENV/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mypy
+.mypy_cache/
+
+# VSCODE
+.vscode/ftp-sync.json
+.vscode/settings.json
+
+# too big to git
+*.lmdb
+*.sto
+*.pt
+*.pkl
+
+# pytest
+.pytest_cache
+test/.pytest_cache
+/local*
+/_*
