Skip to content

Commit 36418dc

Browse files
authored
add scatter to tcp custom op (#95)
1 parent 78c1c25 commit 36418dc

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

lib/Conversion/TorchToTcp/TcpCustomOp.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,28 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
349349
}
350350
};
351351

352+
class ConvertAtenSliceScatterOp
353+
: public OpConversionPattern<AtenSliceScatterOp> {
354+
using OpConversionPattern::OpConversionPattern;
355+
356+
LogicalResult
357+
matchAndRewrite(AtenSliceScatterOp op, OpAdaptor adaptor,
358+
ConversionPatternRewriter &rewriter) const override {
359+
// this should really have some tcp op to reduce to. So going to CustomOp
360+
// is more of a placeholder than a serious implementation
361+
torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
362+
getTypeConverter()};
363+
helper.addOperand("self", adaptor.getSelf());
364+
helper.addOperand("src", adaptor.getSrc());
365+
helper.addIntAttr("dim", op.getDim());
366+
helper.addIntAttr("start", op.getStart());
367+
helper.addIntAttr("end", op.getEnd());
368+
helper.addIntAttr("step", op.getStep());
369+
370+
return helper.replace();
371+
}
372+
};
373+
352374
} // namespace
353375

354376
void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
@@ -369,6 +391,7 @@ void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
369391
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenSortOp);
370392
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenCumsumOp);
371393
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenMinDimOp);
394+
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenSliceScatterOp);
372395
// AtenViewOp can still live after torch-to-tcp conversion
373396
patterns.add<ConvertAtenViewOp>(typeConverter, patterns.getContext());
374397
#undef INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN

test/Conversion/TorchToTcp/tcp_custom_ops.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,3 +341,20 @@ func.func @torch.aten.view_dynamic_shape(%arg0: !torch.vtensor<[?,384,16],f32>,
341341
%4 = torch.aten.view %arg0, %3 : !torch.vtensor<[?,384,16],f32>, !torch.list<int> -> !torch.vtensor<[?,24,16,16],f32>
342342
return %4 : !torch.vtensor<[?,24,16,16],f32>
343343
}
344+
345+
// -----
346+
347+
// CHECK-LABEL: func.func @torch.aten.slice_scatter(
348+
// CHECK-DAG: %[[ARG0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,3],f32> -> tensor<1x3xf32>
349+
// CHECK-DAG: %[[ARG1:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,2],f32> -> tensor<1x2xf32>
350+
// CHECK: %[[OUT:.*]] = tcp.custom_op("torch.aten.slice_scatter") %[[ARG0]], %[[ARG1]] {dim = 1 : i64, end = 3 : i64, start = 2 : i64, step = 4 : i64, torch_operand_names = ["self", "src"]} : tensor<1x3xf32>, tensor<1x2xf32> -> tensor<1x3xf32>
351+
// CHECK: %[[RET:.*]] = torch_c.from_builtin_tensor %[[OUT]] : tensor<1x3xf32> -> !torch.vtensor<[1,3],f32>
352+
// CHECK: return %[[RET]]
353+
func.func @torch.aten.slice_scatter(%arg0: !torch.vtensor<[1,3],f32>, %arg1: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,3],f32> {
354+
%dim = torch.constant.int 1
355+
%start = torch.constant.int 2
356+
%end = torch.constant.int 3
357+
%step = torch.constant.int 4
358+
%0 = torch.aten.slice_scatter %arg0, %arg1, %dim, %start, %end, %step : !torch.vtensor<[1,3],f32>, !torch.vtensor<[1,2],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3],f32>
359+
return %0 : !torch.vtensor<[1,3],f32>
360+
}

0 commit comments

Comments
 (0)