@@ -2361,17 +2361,17 @@ def acc_ops_slice_tensor(
 
     ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
     dim = get_positive_dim(cast(int, kwargs["dim"]), ranks)
-
+    dynamic_shape = has_dynamic_shape(input_val.shape)
     if network.has_implicit_batch_dimension:
         if dim == 0:
             raise RuntimeError(
                 f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!"
             )
         dim = dim - 1
     else:
-        raise RuntimeError(
-            "We don't support slice_tensor with explicit batch dimension yet!"
-        )
+        if dynamic_shape:
+            # Check whether slice target dim is dynamic shape dim
+            assert input_val.shape[dim] != -1, "Can't slice on dynamic shape dimension!"
 
     start_int = cast(int, kwargs["start"])
     stop_int = cast(int, kwargs["stop"])
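
Note: the new guard hinges on `has_dynamic_shape`, which isn't shown in this diff. A minimal sketch of what it is assumed to do (a dimension reported as -1 is only known at runtime):

    def has_dynamic_shape(shape) -> bool:
        # -1 marks a dimension whose size is resolved at runtime.
        return any(s == -1 for s in shape)
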
@@ -2383,7 +2383,18 @@ def acc_ops_slice_tensor(
     output_shape = list(input_val.shape)
     output_shape[dim] = (stop_int - start_int) // step_int
 
-    layer = network.add_slice(input_val, start=start, shape=output_shape, stride=stride)
+    if dynamic_shape:
+        output_shape = get_shape_with_dynamic_shape(
+            network, output_shape, input_val, target, name
+        )
+    layer = network.add_slice(
+        input_val,
+        start=start,
+        shape=[] if dynamic_shape else output_shape,
+        stride=stride,
+    )
+    if dynamic_shape:
+        layer.set_input(2, output_shape)
     set_layer_name(layer, target, name)
     return layer.get_output(0)
 
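
This branch is the core of the dynamic-shape support: when any dimension is dynamic, the output shape cannot be computed as static ints at build time, so `get_shape_with_dynamic_shape` builds a runtime shape tensor (presumably substituting the input's actual sizes for the -1 entries), the static `shape` argument is left empty, and the tensor is attached to the slice layer's input index 2, which TensorRT's ISliceLayer reads as its output shape. A minimal sketch of the pattern, assuming `network` is a TensorRT INetworkDefinition and `shape_tensor` has already been built:

    def add_dynamic_slice(network, input_val, shape_tensor, start, stride):
        # shape=[] is a build-time placeholder; the real extents arrive at runtime.
        layer = network.add_slice(input_val, start=start, shape=[], stride=stride)
        # Input index 2 of ISliceLayer carries the output-shape tensor.
        layer.set_input(2, shape_tensor)
        return layer.get_output(0)
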
@@ -2584,11 +2595,14 @@ def acc_ops_split(
     )
 
     dim = cast(int, kwargs["dim"])
+    dynamic_shape = has_dynamic_shape(input_val.shape)
     if network.has_implicit_batch_dimension:
         assert dim != 0, "Can't split on batch dim when it's implicit!"
         dim -= 1
     else:
-        raise RuntimeError("We don't support split with explicit batch dimension yet!")
+        if dynamic_shape:
+            # Check whether split target dim is dynamic shape dim
+            assert input_val.shape[dim] != -1, "Can't split on dynamic shape dimension!"
 
     split_size = cast(int, kwargs["split_size"])
     start = [0] * len(input_val.shape)
@@ -2607,7 +2621,15 @@ def acc_ops_split(
         shape = list(input_val.shape)
         shape[dim] = min(split_size, cast(int, max_offset - offset))
         start[dim] = offset
-        layer = network.add_slice(input_val, start=start, shape=shape, stride=stride)
+        if dynamic_shape:
+            shape = get_shape_with_dynamic_shape(
+                network, shape, input_val, target, f"{name}_shape_{i}"
+            )
+        layer = network.add_slice(
+            input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride
+        )
+        if dynamic_shape:
+            layer.set_input(2, shape)
         offset += split_size
         set_layer_name(layer, target, f"{name}_{i}")
         output.append(layer.get_output(0))
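
The per-chunk geometry is unchanged by this diff: each iteration slices `split_size` elements starting at `offset`, and the last chunk is clamped by `max_offset`. A pure-Python sketch of the extents the loop walks (illustrative names, not from the source):

    def split_extents(dim_len: int, split_size: int):
        extents, offset = [], 0
        while offset < dim_len:
            extents.append((offset, min(split_size, dim_len - offset)))
            offset += split_size
        return extents

    print(split_extents(10, 4))  # [(0, 4), (4, 4), (8, 2)]
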
@@ -2761,7 +2783,7 @@ def acc_ops_getitem(
         slices = (slices,)
 
     dynamic_shape = get_dynamic_dims(input_val.shape)
-    if dynamic_shape:
+    if len(dynamic_shape) > 0:
         for i, s in zip(input_val.shape, slices):
             assert i > 0 or (
                 s in [slice(None, None, None), slice(0, None, None), Ellipsis]
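
Unlike `has_dynamic_shape`, `get_dynamic_dims` returns a list, so `if dynamic_shape:` and `if len(dynamic_shape) > 0:` behave identically here; the rewrite just makes the emptiness check explicit. A plausible sketch of that helper, mirroring the one above:

    def get_dynamic_dims(shape):
        # Indices of dimensions whose size is unknown (-1) at build time.
        return [i for i, s in enumerate(shape) if s == -1]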