Skip to content

Commit 6199246

Browse files
committed
[torch.compile] Use pytree with torch.compile
1 parent 3dcd22b commit 6199246

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

lumiere_pytorch/lumiere.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from einops import rearrange, pack, unpack, repeat
2222

2323
from optree import tree_flatten, tree_unflatten
24+
import torch.utils._pytree as pytree
2425

2526
from x_transformers.x_transformers import (
2627
Attention,
@@ -546,7 +547,11 @@ def auto_repeat_tensors_for_time(_, args, kwargs):
546547
first_arg, *rest_args = args
547548

548549
all_rest_args = (rest_args, kwargs)
549-
all_rest_args, pytree_spec = tree_flatten(all_rest_args)
550+
if torch.compiler.is_dynamo_compiling():
551+
all_rest_args, pytree_spec = pytree.tree_flatten(all_rest_args)
552+
else:
553+
all_rest_args, pytree_spec = tree_flatten(all_rest_args)
554+
550555

551556
if not is_tensor(first_arg) or len(all_rest_args) == 0:
552557
return args, kwargs
@@ -566,7 +571,11 @@ def auto_repeat_tensors_for_time(_, args, kwargs):
566571

567572
out_rest_args.append(arg)
568573

569-
rest_args, kwargs = tree_unflatten(pytree_spec, out_rest_args)
574+
if torch.compiler.is_dynamo_compiling():
575+
# reordering of args is deliberate. pytree and optree have different API
576+
rest_args, kwargs = pytree.tree_unflatten(out_rest_args, pytree_spec)
577+
else:
578+
rest_args, kwargs = tree_unflatten(pytree_spec, out_rest_args)
570579
return (first_arg, *rest_args), kwargs
571580

572581
for module in self.modules():

0 commit comments

Comments
 (0)