     BoolTypeExpr,
     IsOfType,
     SignatureDispatcher,
+    unwrap_if_possible,
 )
 from .shape import (
     broadcast_dims,
@@ -65,64 +66,65 @@ def assert_on_same_devices(*tensors: Tuple[ShardedTensor]) -> None:
         raise ValueError("All tensors must be placed on the same devices.")
 
 
-def sharded_wrap_override():
-    def transfer_n_pin(f):
+def transfer_n_pin(f):
+    """
+    Wrapper for each NON-TRANSFERRING op defined in this file.
+    """
+
+    def func_wrapper(*args: Tuple, **kwargs: Dict[str, Any]):
         """
-        Wrapper for each NON-TRANSFERRING op defined in this file.
+        Wraps each NON-TRANSFERRING operation, f, to ensure that all incoming tensors are on the same device and that the result has the devices correctly labelled.
+
+        If no ShardedTensors are present in the input, then no changes are made to input/output.
         """
+        sharded_tensors = []
+        for value in itertools.chain(args, kwargs.values()):
+            if isinstance(value, ShardedTensor):
+                sharded_tensors.append(value)
+                continue
+            if isinstance(
+                value,
+                (
+                    InferenceTensor,
+                    Tensor,
+                ),
+            ):
+                continue
+            if isinstance(value, Iterable):
+                for val in value:
+                    if isinstance(val, ShardedTensor):
+                        sharded_tensors.append(val)
+
+        assert_on_same_devices(*sharded_tensors)
+        res = f(*args, **kwargs)
+        if len(sharded_tensors) > 0:
+            if isinstance(res, ShardedTensor):
+                res = res.clone(devices=sharded_tensors[0].devices)
+            elif isinstance(res, Iterable) and all(
+                isinstance(r, ShardedTensor) for r in res
+            ):
+                res = type(res)(
+                    r.clone(devices=sharded_tensors[0].devices) for r in res
+                )
+        return res
 
-        def func_wrapper(*args: Tuple, **kwargs: Dict[str, Any]):
-            """
-            Wraps each NON-TRANSFERRING operation, f, to ensure that all incoming tensors are on the same device and that the result has the devices correctly labelled.
-
-            If no ShardedTensors are present in the input, then no changes are made to input/output.
-            """
-            sharded_tensors = []
-            for value in itertools.chain(args, kwargs.values()):
-                if isinstance(value, ShardedTensor):
-                    sharded_tensors.append(value)
-                    continue
-                if isinstance(
-                    value,
-                    (
-                        InferenceTensor,
-                        torch.Tensor,
-                    ),
-                ):
-                    continue
-                if isinstance(value, Iterable):
-                    for val in value:
-                        if isinstance(val, ShardedTensor):
-                            sharded_tensors.append(val)
-
-            assert_on_same_devices(*sharded_tensors)
-            res = f(*args, **kwargs)
-            if len(sharded_tensors) > 0:
-                if isinstance(res, ShardedTensor):
-                    res = res.clone(devices=sharded_tensors[0].devices)
-                elif isinstance(res, Iterable) and all(
-                    isinstance(r, ShardedTensor) for r in res
-                ):
-                    res = type(res)(
-                        r.clone(devices=sharded_tensors[0].devices) for r in res
-                    )
-            return res
-
-        func_wrapper._impl_name = getattr(f, "_impl_name", None)  # For impl selection
-
-        if hasattr(f, "_trivially_replicable_wrapper"):
-            # If wrapping a trivially replicable function, we do not know what underlying op will be called on each shard,
-            # since we don't dispatch based on shards.
-            # Instead label this wrapper as a trivially replicable wrapper so that
-            # _TEST_LAST_OP_DISPATCH tracking can handle it correctly.
-            # _TEST_LAST_OP_DISPATCH will not update for this wrapper, but instead allow the last inner call to set it.
-            func_wrapper._trivially_replicable_wrapper = f._trivially_replicable_wrapper
-        else:
-            # We know the underlying op will be called, set _unwrapped to the original op
-            # so that _TEST_LAST_OP_DISPATCH tracking can handle it correctly.
-            func_wrapper._unwrapped = f
-        return func_wrapper
+    func_wrapper._impl_name = getattr(f, "_impl_name", None)  # For impl selection
 
+    if hasattr(f, "_trivially_replicable_wrapper"):
+        # If wrapping a trivially replicable function, we do not know what underlying op will be called on each shard,
+        # since we don't dispatch based on shards.
+        # Instead label this wrapper as a trivially replicable wrapper so that
+        # _TEST_LAST_OP_DISPATCH tracking can handle it correctly.
+        # _TEST_LAST_OP_DISPATCH will not update for this wrapper, but instead allow the last inner call to set it.
+        func_wrapper._trivially_replicable_wrapper = f._trivially_replicable_wrapper
+    else:
+        # We know which underlying op will be called, set _unwrapped to the original op
+        # so that _TEST_LAST_OP_DISPATCH tracking can handle it correctly.
+        func_wrapper._unwrapped = unwrap_if_possible(f)
+    return func_wrapper
+
+
+def sharded_wrap_override():
     def wrap_override(signature_dispatcher_override):
         """
         Wrap [op].override's result so that the transfer_n_pin(f) becomes the target in _TargetOverride rather than f itself.
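
For context on the pattern being moved to module level here, the listing below is a standalone sketch of what transfer_n_pin does, written against toy stand-ins rather than sharktank's real types: ToyShardedTensor, elementwise_add, and transfer_n_pin_sketch are names invented for this example and are not part of the codebase. It mirrors the three steps visible in the diff: collect ShardedTensor-like inputs from args and kwargs, assert that they share a device placement, then relabel the result with those devices.

import itertools
from dataclasses import dataclass
from typing import Any, Dict, Tuple


@dataclass
class ToyShardedTensor:
    """Illustrative stand-in for sharktank's ShardedTensor."""

    shards: Tuple[float, ...]
    devices: Tuple[int, ...]

    def clone(self, devices: Tuple[int, ...]) -> "ToyShardedTensor":
        return ToyShardedTensor(shards=self.shards, devices=devices)


def assert_on_same_devices(*tensors: ToyShardedTensor) -> None:
    # Same check as in the diff: every sharded input must share a placement.
    if len(tensors) > 1 and any(t.devices != tensors[0].devices for t in tensors[1:]):
        raise ValueError("All tensors must be placed on the same devices.")


def transfer_n_pin_sketch(f):
    """Collect sharded inputs, check their placement, and relabel the result."""

    def func_wrapper(*args: Tuple, **kwargs: Dict[str, Any]):
        sharded = [
            v
            for v in itertools.chain(args, kwargs.values())
            if isinstance(v, ToyShardedTensor)
        ]
        assert_on_same_devices(*sharded)
        res = f(*args, **kwargs)
        if sharded and isinstance(res, ToyShardedTensor):
            # Pin the result to the devices of the (validated) inputs.
            res = res.clone(devices=sharded[0].devices)
        return res

    return func_wrapper


@transfer_n_pin_sketch
def elementwise_add(a: ToyShardedTensor, b: ToyShardedTensor) -> ToyShardedTensor:
    return ToyShardedTensor(
        shards=tuple(x + y for x, y in zip(a.shards, b.shards)),
        devices=(),  # The wrapper fills in the device labels.
    )


a = ToyShardedTensor(shards=(1.0, 2.0), devices=(0, 1))
b = ToyShardedTensor(shards=(3.0, 4.0), devices=(0, 1))
print(elementwise_add(a, b).devices)  # prints (0, 1)

Running the snippet prints (0, 1): the wrapper pins the result to the devices of the inputs it validated, which is the behaviour the diff's func_wrapper implements for real ShardedTensors.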
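The behavioural change in this hunk, beyond the move to module level, is the switch from func_wrapper._unwrapped = f to func_wrapper._unwrapped = unwrap_if_possible(f). The helper's implementation is not shown in this excerpt, so the sketch below is only a guess at its intent (unwrap_if_possible_sketch, base_op, and outer_wrapper are hypothetical names): following an existing _unwrapped attribute keeps _TEST_LAST_OP_DISPATCH-style tracking pointed at the innermost op even when wrappers are stacked.

from typing import Callable


def unwrap_if_possible_sketch(f: Callable) -> Callable:
    # Hypothetical stand-in for the imported unwrap_if_possible helper:
    # if f is itself a wrapper that recorded the op it wraps, return that op.
    return getattr(f, "_unwrapped", f)


def base_op(x):
    return x + 1


def outer_wrapper(f):
    def wrapped(*args, **kwargs):
        return f(*args, **kwargs)

    # Without unwrapping, a second wrapper around `wrapped` would record
    # `wrapped` itself; with unwrapping it records `base_op`.
    wrapped._unwrapped = unwrap_if_possible_sketch(f)
    return wrapped


once = outer_wrapper(base_op)
twice = outer_wrapper(once)
assert once._unwrapped is base_op
assert twice._unwrapped is base_op  # tracking still points at the innermost op

Under this reading, assigning unwrap_if_possible(f) rather than f simply prevents an intermediate wrapper from being reported as the dispatched op during tests.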