|
1 | 1 | tokens_to_image = r""" |
2 | 2 | module @sdxl_compiled_pipeline {{ |
3 | 3 | func.func private @compiled_scheduled_unet.run_initialize(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<{bd}x6x{precision}>, tensor<i64>) attributes {{torch.args_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}, {{\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}}]}}]", torch.return_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}]"}} |
4 | | - func.func private @compiled_scheduled_unet.run_forward(%arg0: tensor<{batch_size}x4x{lw}x{lh}x{precision}>, %arg1: tensor<{bd}x{max_length}x2048x{precision}>, %arg2: tensor<{bd}x1280x{precision}>, %arg3: tensor<{bd}x6x{precision}>, %arg4: tensor<1x{precision}>, %arg5: tensor<1xi64>) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}> attributes {{torch.args_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}, {{\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}}]}}]", torch.return_schema = "[1, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]"}} |
| 4 | + func.func private @compiled_scheduled_unet.run_forward(%arg0: tensor<{batch_size}x4x{lw}x{lh}x{precision}>, %arg1: tensor<1x{precision}>, %arg2: tensor<{bd}x{max_length}x2048x{precision}>, %arg3: tensor<{bd}x1280x{precision}>, %arg4: tensor<{bd}x6x{precision}>, %arg5: tensor<1x{precision}>) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}> attributes {{torch.args_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}, {{\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}}]}}]", torch.return_schema = "[1, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]"}} |
5 | 5 | func.func private @compiled_clip.encode_prompts(%arg0: tensor<{batch_size}x{max_length}xi64>, %arg1: tensor<{batch_size}x{max_length}xi64>, %arg2: tensor<{batch_size}x{max_length}xi64>, %arg3: tensor<{batch_size}x{max_length}xi64>) -> (tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>) attributes {{torch.args_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}, {{\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}}]}}]", torch.return_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}]"}} |
6 | 6 | func.func private @{vae_module}.main(%arg0: tensor<{batch_size}x4x{lw}x{lh}x{precision}>) -> tensor<{batch_size}x3x{width}x{height}x{precision}> attributes {{torch.args_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}, {{\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}}]}}]", torch.return_schema = "[1, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]"}} |
7 | 7 |
|
|
15 | 15 | %res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (tensor<{batch_size}x4x{lw}x{lh}x{precision}>) {{ |
16 | 16 | %step_64 = arith.index_cast %arg0 : index to i64 |
17 | 17 | %this_step = tensor.from_elements %step_64 : tensor<1xi64> |
18 | | - %inner = func.call @compiled_scheduled_unet.run_forward(%arg, %p_embeds, %t_embeds, %time_ids, %guidance_scale, %this_step) : (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>, tensor<1xi64>) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}> |
19 | | - scf.yield %inner : tensor<{batch_size}x4x{lw}x{lh}x{precision}> |
| 18 | + %scaled, %timestep = func.call @{scheduler_module}.run_scale(%arg, %this_step, %steps) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1xi64>, tensor<{num_steps}xf32>) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>) |
| 19 | + %inner = func.call @compiled_scheduled_unet.run_forward(%scaled, %timestep, %p_embeds, %t_embeds, %time_ids, %guidance_scale) : (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<1xi64>, tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}> |
| 20 | + %pred = func.call @{scheduler_module}.run_step(%inner, %timestep, %arg) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>, tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}> |
| 21 | + scf.yield %pred : tensor<{batch_size}x4x{lh}x{lw}x{precision}> |
20 | 22 | }} |
21 | 23 | %image = func.call @{vae_module}.main(%res): (tensor<{batch_size}x4x{lw}x{lh}x{precision}>) -> tensor<{batch_size}x3x{width}x{height}x{precision}> |
22 | 24 | return %image : tensor<{batch_size}x3x{width}x{height}x{precision}> |
|
27 | 29 | unet_loop = r""" |
28 | 30 | module @sdxl_compiled_pipeline {{ |
29 | 31 | func.func private @compiled_scheduled_unet.run_initialize(%arg0: tensor<{batch_size}x4x{lw}x{lh}x{precision}>) -> (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<{bd}x6x{precision}>, tensor<i64>) attributes {{torch.args_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}, {{\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}}]}}]", torch.return_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}]"}} |
30 | | - func.func private @compiled_scheduled_unet.run_forward(%arg0: tensor<{batch_size}x4x{lw}x{lh}x{precision}>, %arg1: tensor<{bd}x{max_length}x2048x{precision}>, %arg2: tensor<{bd}x1280x{precision}>, %arg3: tensor<{bd}x6x{precision}>, %arg4: tensor<1x{precision}>, %arg5: tensor<1xi64>) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}> attributes {{torch.args_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}, {{\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}}]}}]", torch.return_schema = "[1, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]"}} |
31 | | - |
| 32 | + func.func private @compiled_scheduled_unet.run_forward(%arg0: tensor<{batch_size}x4x{lw}x{lh}x{precision}>, %arg1: tensor<1x{precision}>, %arg2: tensor<{bd}x{max_length}x2048x{precision}>, %arg3: tensor<{bd}x1280x{precision}>, %arg4: tensor<{bd}x6x{precision}>, %arg5: tensor<1x{precision}>) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}> attributes {{torch.args_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}, {{\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}}]}}]", torch.return_schema = "[1, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]"}} |
| 33 | +
|
32 | 34 | func.func @produce_image_latents(%sample: tensor<{batch_size}x4x{lw}x{lh}x{precision}>, %p_embeds: tensor<{bd}x{max_length}x2048x{precision}>, %t_embeds: tensor<{bd}x1280x{precision}>, %guidance_scale: tensor<1x{precision}>) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}> {{ |
33 | 35 | %noisy_sample, %time_ids, %steps = func.call @compiled_scheduled_unet.run_initialize(%sample) : (tensor<{batch_size}x4x{lw}x{lh}x{precision}>) -> (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<{bd}x6x{precision}>, tensor<i64>) |
34 | 36 | %c0 = arith.constant 0 : index |
|
38 | 40 | %res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (tensor<{batch_size}x4x{lw}x{lh}x{precision}>) {{ |
39 | 41 | %step_64 = arith.index_cast %arg0 : index to i64 |
40 | 42 | %this_step = tensor.from_elements %step_64 : tensor<1xi64> |
41 | | - %inner = func.call @compiled_scheduled_unet.run_forward(%arg, %p_embeds, %t_embeds, %time_ids, %guidance_scale, %this_step) : (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>, tensor<1xi64>) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}> |
| 43 | + %inner = func.call @compiled_scheduled_unet.run_forward(%arg, %this_step, %p_embeds, %t_embeds, %time_ids, %guidance_scale) : (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<1x{precision}>, tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}> |
42 | 44 | scf.yield %inner : tensor<{batch_size}x4x{lw}x{lh}x{precision}> |
43 | 45 | }} |
44 | 46 | return %res : tensor<{batch_size}x4x{lw}x{lh}x{precision}> |
|
49 | 51 | produce_img_split = r""" |
50 | 52 | module @sdxl_compiled_pipeline {{ |
51 | 53 | func.func private @{scheduler_module}.run_initialize(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<{bd}x6x{precision}>, tensor<1xf16>, tensor<{num_steps}xf32>) attributes {{torch.assume_strict_symbolic_shapes}} |
52 | | - func.func private @{scheduler_module}.run_scale(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>, %arg1: tensor<1xi64>, %arg2: tensor<{num_steps}xf32>) -> (tensor<{bd}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>) attributes {{torch.assume_strict_symbolic_shapes}} |
| 54 | + func.func private @{scheduler_module}.run_scale(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>, %arg1: tensor<1xi64>, %arg2: tensor<{num_steps}xf32>) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>) attributes {{torch.assume_strict_symbolic_shapes}} |
53 | 55 | func.func private @{scheduler_module}.run_step(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>, %arg1: tensor<1x{precision}>, %arg2: tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}> attributes {{torch.assume_strict_symbolic_shapes}} |
54 | | - func.func private @{unet_module}.{unet_function}(%arg0: tensor<{bd}x4x{lh}x{lw}x{precision}>, %arg1: tensor<1x{precision}>, %arg2: tensor<{bd}x{max_length}x2048x{precision}>, %arg3: tensor<{bd}x1280x{precision}>, %arg4: tensor<{bd}x6x{precision}>, %arg5: tensor<1x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}> attributes {{torch.assume_strict_symbolic_shapes}} |
| 56 | + func.func private @{unet_module}.{unet_function}(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>, %arg1: tensor<1x{precision}>, %arg2: tensor<{bd}x{max_length}x2048x{precision}>, %arg3: tensor<{bd}x1280x{precision}>, %arg4: tensor<{bd}x6x{precision}>, %arg5: tensor<1x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}> attributes {{torch.assume_strict_symbolic_shapes}} |
55 | 57 | func.func private @{vae_module}.decode(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x3x{height}x{width}x{precision}> attributes {{torch.assume_strict_symbolic_shapes}} |
56 | | - |
| 58 | +
|
57 | 59 | func.func @produce_image_latents(%sample: tensor<{batch_size}x4x{lh}x{lw}x{precision}>, %p_embeds: tensor<{bd}x{max_length}x2048x{precision}>, %t_embeds: tensor<{bd}x1280x{precision}>, %guidance_scale: tensor<1x{precision}>) -> tensor<{batch_size}x3x{height}x{width}x{precision}> {{ |
58 | 60 | %noisy_sample, %time_ids, %delete, %timesteps = func.call @{scheduler_module}.run_initialize(%sample) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>, tensor<{num_steps}xf32>) |
59 | 61 | %c0 = arith.constant 0 : index |
|
62 | 64 | %res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>) {{ |
63 | 65 | %step_64 = arith.index_cast %arg0 : index to i64 |
64 | 66 | %this_step = tensor.from_elements %step_64 : tensor<1xi64> |
65 | | - %scaled, %timestep = func.call @{scheduler_module}.run_scale(%arg, %this_step, %timesteps) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1xi64>, tensor<{num_steps}xf32>) -> (tensor<{bd}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>) |
66 | | - %inner = func.call @{unet_module}.{unet_function}(%scaled, %timestep, %p_embeds, %t_embeds, %time_ids, %guidance_scale) : (tensor<{bd}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>, tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}> |
| 67 | + %scaled, %timestep = func.call @{scheduler_module}.run_scale(%arg, %this_step, %timesteps) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1xi64>, tensor<{num_steps}xf32>) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>) |
| 68 | + %inner = func.call @{unet_module}.{unet_function}(%scaled, %timestep, %p_embeds, %t_embeds, %time_ids, %guidance_scale) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>, tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}> |
67 | 69 | %pred = func.call @{scheduler_module}.run_step(%inner, %timestep, %arg) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>, tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}> |
68 | 70 | scf.yield %pred : tensor<{batch_size}x4x{lh}x{lw}x{precision}> |
69 | 71 | }} |
|
0 commit comments