Skip to content

Commit e46a2a2

Browse files
monorimetIanNoddan-garveyaviator19941saienduri
authored
Unifies SD pipeline APIs, adds sd3 support, punet integration (#706)
- Introduces a new sd_pipeline.py that handles inference for sd1.5, sd2.1, sdxl, sdxl-turbo, sd3. The pipeline is a child of the new pipeline_base.py that provides a comprehensive starting point to bringing up new pipelines. - Generally moves SDXL away from the "scheduled unet" approach, instead compiling small scheduler models that fit around a standalone unet module. - Reworks pipeline API to enable deployment / compatibility APIs - Adds multi-device pipelining support to SD pipeline - Carries flag updates for key targets - file management improvements - integrates sharktank int8 partitioned unet. Signed-off-by: aviator19941 <[email protected]> Signed-off-by: monorimet <[email protected]> Co-authored-by: Ian <[email protected]> Co-authored-by: dan <[email protected]> Co-authored-by: IanNod <[email protected]> Co-authored-by: aviator19941 <[email protected]> Co-authored-by: saienduri <[email protected]>
1 parent 4f5f31f commit e46a2a2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+10347
-3135
lines changed

.github/workflows/lint.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ jobs:
2121
run: git fetch --no-tags --prune --depth=1 origin "${GITHUB_BASE_REF?}:${GITHUB_BASE_REF?}"
2222
- name: Install black
2323
run: |
24-
python3 -m pip install black==23.3
24+
python3 -m pip install black
2525
- name: Check if modified files are formatted
2626
run: |
2727
# The filter lowercase `d` means to exclude deleted files.

.github/workflows/test_models.yml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ jobs:
5050
# from non default locations first. Installing the PyTorch CPU
5151
# wheels saves multiple minutes and a lot of bandwidth on runner setup.
5252
pip install --no-compile -r ${{ github.workspace }}/iree-turbine/pytorch-cpu-requirements.txt
53-
pip install --no-compile --pre --upgrade -r ${{ github.workspace }}/iree-turbine/requirements.txt
53+
pip install --pre --upgrade -r ${{ github.workspace }}/iree-turbine/requirements.txt
5454
pip install --no-compile --pre -e ${{ github.workspace }}/iree-turbine[testing]
5555
pip install --upgrade --pre --no-cache-dir iree-compiler iree-runtime -f https://iree.dev/pip-release-links.html
5656
pip install --no-compile --pre --upgrade -e models -r models/requirements.txt
@@ -69,7 +69,8 @@ jobs:
6969
source turbine_venv/bin/activate
7070
7171
pytest -v models/turbine_models/tests/sd_test.py
72-
pytest -v models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu
72+
pytest -v models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 5
7373
pytest -v models/turbine_models/tests/sdxl_test.py --device vulkan --rt_device vulkan --iree_target_triple rdna3-unknown-linux
74-
pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16
75-
74+
pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default
75+
pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default --batch_size 2
76+
pytest -v models/turbine_models/tests/sd3_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 5

.github/workflows/test_shark.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
strategy:
2121
matrix:
2222
version: [3.11]
23-
os: [nodai-ubuntu-builder-large]
23+
os: [nodai-amdgpu-mi250-x86-64]
2424

2525
runs-on: ${{matrix.os}}
2626
steps:
@@ -49,7 +49,6 @@ jobs:
4949
cd $GITHUB_WORKSPACE/SHARK
5050
python${{ matrix.version }} -m venv shark.venv
5151
source shark.venv/bin/activate
52-
sed -i 's/SHARK-Turbine#/SHARK-Turbine.git@${{github.sha}}#/g' requirements.txt
5352
pip install -r requirements.txt --no-cache-dir
5453
pip install -e .
5554
python apps/shark_studio/tests/api_test.py

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,7 @@ wheelhouse
2828
*.safetensors
2929
*.gguf
3030
*.vmfb
31+
*.mlir
32+
*.npy
33+
*.png
34+
*tmp*

models/requirements.txt

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
protobuf
2-
sentencepiece
3-
shark_turbine
2+
gguf
43
transformers==4.37.1
4+
torchsde
55
accelerate
6-
diffusers @ git+https://github.com/nod-ai/[email protected]
6+
peft
7+
diffusers @ git+https://github.com/nod-ai/[email protected]
78
brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b
89
# turbine tank downloading/uploading
910
azure-storage-blob
1011
# microsoft/phi model
1112
einops
13+
pytest
14+
scipy
15+
shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main
16+
-e git+https://github.com/nod-ai/sharktank.git@main#egg=sharktank&subdirectory=sharktank

models/setup.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,11 @@ def load_version_info():
5555
),
5656
install_requires=[
5757
"Shark-Turbine",
58-
"brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b",
5958
"protobuf",
6059
"sentencepiece",
61-
"transformers==4.37.1",
60+
"transformers>=4.37.1",
6261
"accelerate",
63-
"diffusers==0.24.0",
62+
"diffusers==0.29.0.dev0",
6463
"azure-storage-blob",
6564
"einops",
6665
],

models/turbine_models/custom_models/llama_argmax_td_spec.mlir

Lines changed: 169 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)