@@ -341,3 +341,20 @@ func.func @torch.aten.view_dynamic_shape(%arg0: !torch.vtensor<[?,384,16],f32>,
341341 %4 = torch.aten.view %arg0 , %3 : !torch.vtensor <[?,384 ,16 ],f32 >, !torch.list <int > -> !torch.vtensor <[?,24 ,16 ,16 ],f32 >
342342 return %4 : !torch.vtensor <[?,24 ,16 ,16 ],f32 >
343343}
344+
345+ // -----
346+
347+ // CHECK-LABEL: func.func @torch.aten.slice_scatter(
348+ // CHECK-DAG: %[[ARG0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,3],f32> -> tensor<1x3xf32>
349+ // CHECK-DAG: %[[ARG1:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,2],f32> -> tensor<1x2xf32>
350+ // CHECK: %[[OUT:.*]] = tcp.custom_op("torch.aten.slice_scatter") %[[ARG0]], %[[ARG1]] {dim = 1 : i64, end = 3 : i64, start = 2 : i64, step = 4 : i64, torch_operand_names = ["self", "src"]} : tensor<1x3xf32>, tensor<1x2xf32> -> tensor<1x3xf32>
351+ // CHECK: %[[RET:.*]] = torch_c.from_builtin_tensor %[[OUT]] : tensor<1x3xf32> -> !torch.vtensor<[1,3],f32>
352+ // CHECK: return %[[RET]]
353+ func.func @torch.aten.slice_scatter (%arg0: !torch.vtensor <[1 ,3 ],f32 >, %arg1: !torch.vtensor <[1 ,2 ],f32 >) -> !torch.vtensor <[1 ,3 ],f32 > {
354+ %dim = torch.constant.int 1
355+ %start = torch.constant.int 2
356+ %end = torch.constant.int 3
357+ %step = torch.constant.int 4
358+ %0 = torch.aten.slice_scatter %arg0 , %arg1 , %dim , %start , %end , %step : !torch.vtensor <[1 ,3 ],f32 >, !torch.vtensor <[1 ,2 ],f32 >, !torch.int , !torch.int , !torch.int , !torch.int -> !torch.vtensor <[1 ,3 ],f32 >
359+ return %0 : !torch.vtensor <[1 ,3 ],f32 >
360+ }
0 commit comments