@@ -230,22 +230,37 @@ def scatter_update(inputs, indices, updates):
230
230
231
231
def slice (inputs , start_indices , shape ):
232
232
inputs = convert_to_tensor (inputs )
233
-
233
+ if not isinstance (shape , list ):
234
+ shape = convert_to_tensor (shape , dtype = "int32" ).tolist ()
235
+ else :
236
+ shape = [i if isinstance (i , int ) else i .item () for i in shape ]
237
+ if not isinstance (start_indices , list ):
238
+ start_indices = convert_to_tensor (start_indices , dtype = "int32" ).tolist ()
239
+ else :
240
+ start_indices = [
241
+ i if isinstance (i , int ) else i .item () for i in start_indices
242
+ ]
234
243
python_slice = __builtins__ ["slice" ]
235
244
slices = tuple (
236
- python_slice (int ( start_index ), int ( start_index + length ) )
245
+ python_slice (start_index , start_index + length )
237
246
for start_index , length in zip (start_indices , shape )
238
247
)
239
248
return inputs [slices ]
240
249
241
250
242
251
def slice_update (inputs , start_indices , updates ):
243
252
inputs = convert_to_tensor (inputs )
253
+ if not isinstance (start_indices , list ):
254
+ start_indices = convert_to_tensor (start_indices , dtype = "int32" ).tolist ()
255
+ else :
256
+ start_indices = [
257
+ i if isinstance (i , int ) else i .item () for i in start_indices
258
+ ]
244
259
updates = convert_to_tensor (updates )
245
260
246
261
python_slice = __builtins__ ["slice" ]
247
262
slices = tuple (
248
- python_slice (int ( start_index ), int ( start_index + update_length ) )
263
+ python_slice (start_index , start_index + update_length )
249
264
for start_index , update_length in zip (start_indices , updates .shape )
250
265
)
251
266
inputs [slices ] = updates
0 commit comments