
Commit 6c9d96d

Update mmdit runner inputs, small attn reproducer, pad attention flag
1 parent 3495f63

4 files changed: +109 −5 lines changed

4 files changed

+109
-5
lines changed

models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py

Lines changed: 7 additions & 0 deletions
@@ -284,6 +284,13 @@ def is_valid_file(arg):
     default="SD3_output.png",
     help="Path to output file for generated images.",
 )
+p.add_argument(
+    "--attn_repro",
+    default=False,
+    action="store_true",
+    help="Just compile attention reproducer for mmdit.",
+)
+

 ##############################################################################
 # IREE Compiler Options
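With the flag wired up, the reproducer path can be selected straight from the command line. A plausible invocation (the other flag names are assumed to match the existing options in sd3_cmd_opts.py):

    python models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py \
        --attn_repro --precision fp16 --compile_to vmfb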

models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py

Lines changed: 96 additions & 0 deletions
@@ -52,7 +52,87 @@ def forward(
             return_dict=False,
         )[0]
         return noise_pred
+
+class MMDiTAttention(torch.nn.Module):
+    def __init__(
+        self,
+    ):
+        super().__init__()
+
+    def forward(self, q, k, v):
+        return torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, dropout_p=0.0, is_causal=False
+        )
+
+
+@torch.no_grad()
+def export_attn(
+    precision="fp16",
+    device="cpu",
+    target_triple="x86_64-unknown-linux-gnu",
+    ireec_flags="",
+    compile_to="torch",
+    decomp_attn=False,
+    attn_spec=None,
+):
+    dtype = torch.float16 if precision == "fp16" else torch.float32
+    qkv_shape = (2, 24, 4250, 64)
+    attn_module = MMDiTAttention()
+    safe_name = "attn_repro_" + precision + "_" + target_triple
+    if decomp_attn == True:
+        safe_name += "_decomp"
+
+    if dtype == torch.float16:
+        attn_module = attn_module.half()
+
+    example_qkv = [
+        torch.empty(qkv_shape, dtype=dtype),
+        torch.empty(qkv_shape, dtype=dtype),
+        torch.empty(qkv_shape, dtype=dtype),
+    ]
+
+    decomp_list = []
+    if decomp_attn == True:
+        decomp_list = [
+            torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
+            torch.ops.aten._scaled_dot_product_flash_attention.default,
+            torch.ops.aten.scaled_dot_product_attention,
+        ]
+    with decompositions.extend_aot_decompositions(
+        from_current=True,
+        add_ops=decomp_list,
+    ):
+        fxb = FxProgramsBuilder(attn_module)

+        @fxb.export_program(
+            args=(example_qkv,),
+        )
+        def _forward(
+            module,
+            inputs,
+        ):
+            return module.forward(*inputs)
+
+    class CompiledAttn(CompiledModule):
+        run_forward = _forward
+
+    inst = CompiledAttn(context=Context(), import_to="IMPORT")
+
+    module_str = str(CompiledModule.get_mlir_module(inst))
+
+    if compile_to != "vmfb":
+        return module_str
+    else:
+        vmfb_path = utils.compile_to_vmfb(
+            module_str,
+            device,
+            target_triple,
+            ireec_flags,
+            safe_name,
+            return_path=True,
+            attn_spec=attn_spec,
+        )
+        return vmfb_path

 @torch.no_grad()
 def export_mmdit_model(
@@ -183,6 +263,22 @@ class CompiledMmdit(CompiledModule):
     logging.basicConfig(level=logging.DEBUG)
     from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args

+    if args.attn_repro:
+        mod_str = export_attn(
+            args.precision,
+            args.device,
+            args.iree_target_triple,
+            args.ireec_flags,
+            args.compile_to,
+            args.decomp_attn,
+            attn_spec=args.attn_spec,
+        )
+        if args.compile_to != "vmfb":
+            safe_name = "attn_repro_" + args.precision
+            with open(f"{safe_name}.mlir", "w+") as f:
+                f.write(mod_str)
+            print("Saved to", safe_name + ".mlir")
+        exit()
     if args.input_mlir:
         mmdit_model = None
     else:
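For a quick eager-mode sanity check of the reproducer module before involving the export path, the same shapes can be run through plain PyTorch. This is a sketch, not part of the commit; it uses fp32 to sidestep any CPU half-precision kernel concerns:

    import torch

    class MMDiTAttention(torch.nn.Module):
        def forward(self, q, k, v):
            return torch.nn.functional.scaled_dot_product_attention(
                q, k, v, dropout_p=0.0, is_causal=False
            )

    # Same (batch, heads, seq_len, head_dim) shape the exporter traces with.
    q = k = v = torch.randn(2, 24, 4250, 64)
    out = MMDiTAttention()(q, k, v)
    print(out.shape)  # torch.Size([2, 24, 4250, 64])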

models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py

Lines changed: 5 additions & 4 deletions
@@ -69,14 +69,15 @@ def run_diffusers_mmdit(
         dtype = torch.float16
     else:
         dtype = torch.float32
-
+
+    batch_size = args.batch_size * 2  # do classifier-free guidance
     hidden_states = torch.randn(
-        (args.batch_size, 16, args.height // 8, args.width // 8), dtype=dtype
+        (batch_size, 16, args.height // 8, args.width // 8), dtype=dtype
     )
     encoder_hidden_states = torch.randn(
-        (args.batch_size, args.max_length * 2, 4096), dtype=dtype
+        (batch_size, args.max_length * 2, 4096), dtype=dtype
     )
-    pooled_projections = torch.randn((args.batch_size, 2048), dtype=dtype)
+    pooled_projections = torch.randn((batch_size, 2048), dtype=dtype)
     timestep = torch.tensor([0], dtype=dtype)

     turbine_output = run_mmdit_turbine(
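Doubling the batch here mirrors how classifier-free guidance is usually executed: the unconditional and conditional inputs are stacked into a single batch, run through the model once, and the two noise predictions are blended afterward. A minimal sketch of that blend (the guidance scale and tensor shapes are illustrative, not taken from this repo):

    import torch

    guidance_scale = 7.0  # illustrative value
    # One forward pass over a stacked [uncond, cond] batch of size 2.
    noise_pred = torch.randn(2, 16, 128, 128)
    noise_uncond, noise_cond = noise_pred.chunk(2)
    guided = noise_uncond + guidance_scale * (noise_cond - noise_uncond)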

models/turbine_models/custom_models/sd_inference/utils.py

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@
         "--iree-codegen-gpu-native-math-precision=true",
         "--iree-codegen-llvmgpu-use-vector-distribution=true",
         "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
-        "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics))",
+        "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))",
     ],
     "unet": [""],
     "clip": [""],
