Skip to content

Commit 5308c89

Browse files
committed
torch.compile friendly pytree
1 parent c5953c8 commit 5308c89

File tree

2 files changed

+5
-12
lines changed

2 files changed

+5
-12
lines changed

lumiere_pytorch/lumiere.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020

2121
from einops import rearrange, pack, unpack, repeat
2222

23-
from optree import tree_flatten, tree_unflatten
24-
import torch.utils._pytree as pytree
23+
from torch.utils._pytree import tree_flatten, tree_unflatten
2524

2625
from x_transformers.x_transformers import (
2726
Attention,
@@ -547,11 +546,8 @@ def auto_repeat_tensors_for_time(_, args, kwargs):
547546
first_arg, *rest_args = args
548547

549548
all_rest_args = (rest_args, kwargs)
550-
if torch.compiler.is_dynamo_compiling():
551-
all_rest_args, pytree_spec = pytree.tree_flatten(all_rest_args)
552-
else:
553-
all_rest_args, pytree_spec = tree_flatten(all_rest_args)
554549

550+
all_rest_args, pytree_spec = tree_flatten(all_rest_args)
555551

556552
if not is_tensor(first_arg) or len(all_rest_args) == 0:
557553
return args, kwargs
@@ -571,11 +567,8 @@ def auto_repeat_tensors_for_time(_, args, kwargs):
571567

572568
out_rest_args.append(arg)
573569

574-
if torch.compiler.is_dynamo_compiling():
575-
# reordering of args is deliberate. pytree and optree have different API
576-
rest_args, kwargs = pytree.tree_unflatten(out_rest_args, pytree_spec)
577-
else:
578-
rest_args, kwargs = tree_unflatten(pytree_spec, out_rest_args)
570+
rest_args, kwargs = tree_unflatten(out_rest_args, pytree_spec)
571+
579572
return (first_arg, *rest_args), kwargs
580573

581574
for module in self.modules():

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'lumiere-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.23',
6+
version = '0.0.24',
77
license='MIT',
88
description = 'Lumiere',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)