 
 from einops import rearrange, pack, unpack, repeat
 
-from optree import tree_flatten, tree_unflatten
-import torch.utils._pytree as pytree
+from torch.utils._pytree import tree_flatten, tree_unflatten
 
 from x_transformers.x_transformers import (
     Attention,
@@ -547,11 +546,8 @@ def auto_repeat_tensors_for_time(_, args, kwargs):
             first_arg, *rest_args = args
 
             all_rest_args = (rest_args, kwargs)
-            if torch.compiler.is_dynamo_compiling():
-                all_rest_args, pytree_spec = pytree.tree_flatten(all_rest_args)
-            else:
-                all_rest_args, pytree_spec = tree_flatten(all_rest_args)
 
+            all_rest_args, pytree_spec = tree_flatten(all_rest_args)
 
             if not is_tensor(first_arg) or len(all_rest_args) == 0:
                 return args, kwargs
@@ -571,11 +567,8 @@ def auto_repeat_tensors_for_time(_, args, kwargs):
 
                 out_rest_args.append(arg)
 
-            if torch.compiler.is_dynamo_compiling():
-                # reordering of args is deliberate. pytree and optree have different API
-                rest_args, kwargs = pytree.tree_unflatten(out_rest_args, pytree_spec)
-            else:
-                rest_args, kwargs = tree_unflatten(pytree_spec, out_rest_args)
+            rest_args, kwargs = tree_unflatten(out_rest_args, pytree_spec)
+
             return (first_arg, *rest_args), kwargs
 
         for module in self.modules():
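
A minimal sketch (not part of the commit) of the torch.utils._pytree round-trip the simplified code relies on; the tensor shapes and the 'mask' key are illustrative only. tree_flatten returns (leaves, spec) and tree_unflatten takes the leaves first and the spec second, the reverse of optree's tree_unflatten(spec, leaves) noted in the removed comment.

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# example container shaped like (args, kwargs); values are hypothetical
args = (torch.randn(2, 3), 5)
kwargs = {'mask': torch.ones(2, 3)}

# flatten into a flat list of leaves plus a spec describing the structure
leaves, spec = tree_flatten((args, kwargs))

# transform only the tensor leaves, pass everything else through unchanged
leaves = [leaf * 2 if torch.is_tensor(leaf) else leaf for leaf in leaves]

# rebuild the original structure; note the (leaves, spec) argument order,
# whereas optree.tree_unflatten expects (spec, leaves)
new_args, new_kwargs = tree_unflatten(leaves, spec)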