Skip to content

Commit 48f11a1

Browse files
chunnienccopybara-github
authored andcommitted
Fix jax bridge lowering with composite
PiperOrigin-RevId: 875814513
1 parent e4e5d5e commit 48f11a1

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

litert_torch/backend/jax_bridge/_wrap.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ def wrapped(lctx, *args, **kwargs):
147147
builtin.module(
148148
vhlo-legalize-stablehlo,
149149
reconcile-unrealized-casts,
150+
func.func(stablehlo-legalize-composite-to-call),
151+
inline,
150152
strip-debuginfo
151153
)""")).run(module.operation)
152154

0 commit comments

Comments
 (0)