@@ -140,15 +140,15 @@ def _encode_tensor(
140
140
# this creates a copy of the tensor
141
141
obj = obj .contiguous () if not obj .is_contiguous () else obj
142
142
# view the tensor as a 1D array of bytes
143
- arr = obj .view ([ obj .numel ()] ).view (torch .uint8 ).numpy ()
143
+ arr = obj .view (( obj .numel (),) ).view (torch .uint8 ).numpy ()
144
144
if obj .nbytes < self .size_threshold :
145
145
data = msgpack .Ext (CUSTOM_TYPE_RAW_VIEW , arr .data )
146
146
else :
147
147
# Otherwise encode index of backing buffer to avoid copy.
148
148
data = len (self .aux_buffers )
149
149
self .aux_buffers .append (arr .data )
150
- dt = str (obj .dtype )[6 :] # remove 'torch.' prefix
151
- return dt , obj .shape , data
150
+ dtype = str (obj .dtype )[6 :] # remove 'torch.' prefix
151
+ return dtype , obj .shape , data
152
152
153
153
def _encode_nested_tensors (self , nt : NestedTensors ) -> Any :
154
154
if isinstance (nt , torch .Tensor ):
@@ -228,8 +228,10 @@ def _decode_tensor(self, arr: Any) -> torch.Tensor:
228
228
# the returned memory is non-writeable.
229
229
buffer = self .aux_buffers [data ] if isinstance (data , int ) \
230
230
else bytearray (data )
231
- arr = np .ndarray (buffer = buffer , dtype = np .uint8 , shape = [len (buffer )])
232
- return torch .from_numpy (arr ).view (getattr (torch , dtype )).view (shape )
231
+ arr = np .ndarray (buffer = buffer , dtype = np .uint8 , shape = (len (buffer ),))
232
+ torch_dtype = getattr (torch , dtype )
233
+ assert isinstance (torch_dtype , torch .dtype )
234
+ return torch .from_numpy (arr ).view (torch_dtype ).view (shape )
233
235
234
236
def _decode_mm_items (self , obj : list ) -> list [MultiModalKwargsItem ]:
235
237
decoded_items = []
0 commit comments