3838 %res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (tensor<{batch_size}x4x{lw}x{lh}x{precision}>) {{
3939 %step_64 = arith.index_cast %arg0 : index to i64
4040 %this_step = tensor.from_elements %step_64 : tensor<1xi64>
41- %inner = func.call @compiled_scheduled_unet.run_forward(%arg, %p_embeds, %t_embeds, %time_ids, %guidance_scale, %this_step ) : (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<{ bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>, tensor<1xi64 >) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}>
41+ %inner = func.call @compiled_scheduled_unet.run_forward(%arg, %this_step, % p_embeds, %t_embeds, %time_ids, %guidance_scale) : (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<1x{precision}>, tensor<{ bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}>
4242 scf.yield %inner : tensor<{batch_size}x4x{lw}x{lh}x{precision}>
4343 }}
4444 return %res : tensor<{batch_size}x4x{lw}x{lh}x{precision}>
4848
# Template for an MLIR module that chains separately compiled scheduler, UNet
# and VAE-decode functions into a single `@produce_image_latents` entry point,
# using builtin `tensor<...>` types throughout.
#
# Fill it with `str.format`; literal MLIR braces are escaped as `{{` / `}}`.
# Placeholders: scheduler_module, unet_module, unet_function, vae_module
# (symbol names of the external functions) and batch_size, bd, lh, lw, height,
# width, max_length, num_steps, precision (spliced into tensor types;
# num_steps is also the loop trip count).
#
# The result types written on the `run_initialize` call must match its
# declaration exactly, so the third result is spelled `tensor<1xf16>` in both
# places; that value (%delete) is never read afterwards.
# NOTE(review): the declaration is taken as the source of truth here - confirm
# against the compiled scheduler's actual `run_initialize` signature.
produce_img_split = r"""
module @sdxl_compiled_pipeline {{
  func.func private @{scheduler_module}.run_initialize(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<{bd}x6x{precision}>, tensor<1xf16>, tensor<{num_steps}xf32>) attributes {{torch.assume_strict_symbolic_shapes}}
  func.func private @{scheduler_module}.run_scale(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>, %arg1: tensor<1xi64>, %arg2: tensor<{num_steps}xf32>) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>) attributes {{torch.assume_strict_symbolic_shapes}}
  func.func private @{scheduler_module}.run_step(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>, %arg1: tensor<1x{precision}>, %arg2: tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}> attributes {{torch.assume_strict_symbolic_shapes}}
  func.func private @{unet_module}.{unet_function}(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>, %arg1: tensor<1x{precision}>, %arg2: tensor<{bd}x{max_length}x2048x{precision}>, %arg3: tensor<{bd}x1280x{precision}>, %arg4: tensor<{bd}x6x{precision}>, %arg5: tensor<1x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}> attributes {{torch.assume_strict_symbolic_shapes}}
  func.func private @{vae_module}.decode(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x3x{height}x{width}x{precision}> attributes {{torch.assume_strict_symbolic_shapes}}

  func.func @produce_image_latents(%sample: tensor<{batch_size}x4x{lh}x{lw}x{precision}>, %p_embeds: tensor<{bd}x{max_length}x2048x{precision}>, %t_embeds: tensor<{bd}x1280x{precision}>, %guidance_scale: tensor<1x{precision}>) -> tensor<{batch_size}x3x{height}x{width}x{precision}> {{
    %noisy_sample, %time_ids, %delete, %timesteps = func.call @{scheduler_module}.run_initialize(%sample) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<{bd}x6x{precision}>, tensor<1xf16>, tensor<{num_steps}xf32>)
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %n_steps = arith.constant {num_steps} : index
    %res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>) {{
      %step_64 = arith.index_cast %arg0 : index to i64
      %this_step = tensor.from_elements %step_64 : tensor<1xi64>
      %scaled, %timestep = func.call @{scheduler_module}.run_scale(%arg, %this_step, %timesteps) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1xi64>, tensor<{num_steps}xf32>) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>)
      %inner = func.call @{unet_module}.{unet_function}(%scaled, %timestep, %p_embeds, %t_embeds, %time_ids, %guidance_scale) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>, tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}>
      %pred = func.call @{scheduler_module}.run_step(%inner, %timestep, %arg) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>, tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}>
      scf.yield %pred : tensor<{batch_size}x4x{lh}x{lw}x{precision}>
    }}
    %image = func.call @{vae_module}.decode(%res): (tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x3x{height}x{width}x{precision}>
    return %image : tensor<{batch_size}x3x{height}x{width}x{precision}>
  }}
}}
"""
@@ -128,4 +127,4 @@ def get_pipeline_ir(
128127 scheduler_module = scheduler_module_name ,
129128 vae_module = vae_module_name ,
130129 num_steps = num_steps ,
131- )
130+ )
0 commit comments