2121from einops import rearrange , pack , unpack , repeat
2222
2323from optree import tree_flatten , tree_unflatten
24+ import torch .utils ._pytree as pytree
2425
2526from x_transformers .x_transformers import (
2627 Attention ,
@@ -546,7 +547,11 @@ def auto_repeat_tensors_for_time(_, args, kwargs):
546547 first_arg , * rest_args = args
547548
548549 all_rest_args = (rest_args , kwargs )
549- all_rest_args , pytree_spec = tree_flatten (all_rest_args )
550+ if torch .compiler .is_dynamo_compiling ():
551+ all_rest_args , pytree_spec = pytree .tree_flatten (all_rest_args )
552+ else :
553+ all_rest_args , pytree_spec = tree_flatten (all_rest_args )
554+
550555
551556 if not is_tensor (first_arg ) or len (all_rest_args ) == 0 :
552557 return args , kwargs
@@ -566,7 +571,11 @@ def auto_repeat_tensors_for_time(_, args, kwargs):
566571
567572 out_rest_args .append (arg )
568573
569- rest_args , kwargs = tree_unflatten (pytree_spec , out_rest_args )
574+ if torch .compiler .is_dynamo_compiling ():
575+ # reordering of args is deliberate. pytree and optree have different API
576+ rest_args , kwargs = pytree .tree_unflatten (out_rest_args , pytree_spec )
577+ else :
578+ rest_args , kwargs = tree_unflatten (pytree_spec , out_rest_args )
570579 return (first_arg , * rest_args ), kwargs
571580
572581 for module in self .modules ():
0 commit comments