@@ -80,7 +80,7 @@ def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
80
80
81
81
def enc_hook (self , obj : Any ) -> Any :
82
82
if isinstance (obj , torch .Tensor ):
83
- return self ._encode_ndarray (obj . numpy () )
83
+ return self ._encode_tensor (obj )
84
84
85
85
# Fall back to pickle for object or void kind ndarrays.
86
86
if isinstance (obj , np .ndarray ) and obj .dtype .kind not in ('O' , 'V' ):
@@ -133,9 +133,26 @@ def _encode_ndarray(
133
133
# backing buffers that we've stashed in `aux_buffers`.
134
134
return obj .dtype .str , obj .shape , data
135
135
136
+ def _encode_tensor (
137
+ self , obj : torch .Tensor
138
+ ) -> tuple [str , tuple [int , ...], Union [int , memoryview ]]:
139
+ assert self .aux_buffers is not None
140
+ # this creates a copy of the tensor
141
+ obj = obj .contiguous () if not obj .is_contiguous () else obj
142
+ # view the tensor as a 1D array of bytes
143
+ arr = obj .view ([obj .numel ()]).view (torch .uint8 ).numpy ()
144
+ if obj .nbytes < self .size_threshold :
145
+ data = msgpack .Ext (CUSTOM_TYPE_RAW_VIEW , arr .data )
146
+ else :
147
+ # Otherwise encode index of backing buffer to avoid copy.
148
+ data = len (self .aux_buffers )
149
+ self .aux_buffers .append (arr .data )
150
+ dt = str (obj .dtype )[6 :] # remove 'torch.' prefix
151
+ return dt , obj .shape , data
152
+
136
153
def _encode_nested_tensors (self , nt : NestedTensors ) -> Any :
137
154
if isinstance (nt , torch .Tensor ):
138
- return self ._encode_ndarray (nt . numpy () )
155
+ return self ._encode_tensor (nt )
139
156
if isinstance (nt , (int , float )):
140
157
# Although it violates NestedTensors type, MultiModalKwargs
141
158
# values are sometimes floats.
@@ -186,7 +203,7 @@ def dec_hook(self, t: type, obj: Any) -> Any:
186
203
if issubclass (t , np .ndarray ):
187
204
return self ._decode_ndarray (obj )
188
205
if issubclass (t , torch .Tensor ):
189
- return torch . from_numpy ( self ._decode_ndarray (obj ) )
206
+ return self ._decode_tensor (obj )
190
207
if issubclass (t , MultiModalKwargs ):
191
208
if isinstance (obj , list ):
192
209
return MultiModalKwargs .from_items (
@@ -205,6 +222,15 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray:
205
222
else bytearray (data )
206
223
return np .ndarray (buffer = buffer , dtype = np .dtype (dtype ), shape = shape )
207
224
225
+ def _decode_tensor (self , arr : Any ) -> torch .Tensor :
226
+ dtype , shape , data = arr
227
+ # Copy from inline representation, otherwise Torch is unhappy since
228
+ # the returned memory is non-writeable.
229
+ buffer = self .aux_buffers [data ] if isinstance (data , int ) \
230
+ else bytearray (data )
231
+ arr = np .ndarray (buffer = buffer , dtype = np .uint8 , shape = [len (buffer )])
232
+ return torch .from_numpy (arr ).view (getattr (torch , dtype )).view (shape )
233
+
208
234
def _decode_mm_items (self , obj : list ) -> list [MultiModalKwargsItem ]:
209
235
decoded_items = []
210
236
for item in obj :
@@ -228,7 +254,7 @@ def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
228
254
if not isinstance (obj , list ):
229
255
raise TypeError (f"Unexpected NestedTensors contents: { type (obj )} " )
230
256
if obj and isinstance (obj [0 ], str ):
231
- return torch . from_numpy ( self ._decode_ndarray (obj ) )
257
+ return self ._decode_tensor (obj )
232
258
return [self ._decode_nested_tensors (x ) for x in obj ]
233
259
234
260
def ext_hook (self , code : int , data : memoryview ) -> Any :
0 commit comments